"vscode:/vscode.git/clone" did not exist on "c8aa934673a5ddd8d870aef673f91e6910df7c62"
Commit 6c7ff9c7 authored by LDOUBLEV's avatar LDOUBLEV
Browse files

fix conflict

parents ac91a9e1 9b8f587e
...@@ -19,27 +19,29 @@ __dir__ = os.path.dirname(__file__) ...@@ -19,27 +19,29 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, '')) sys.path.append(os.path.join(__dir__, ''))
import cv2 import cv2
import logging
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system from tools.infer import predict_system
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, str2bool
from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
model_urls = { model_urls = {
'det': { 'det': {
'ch': 'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en': 'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
}, },
'rec': { 'rec': {
'ch': { 'ch': {
...@@ -111,175 +113,47 @@ model_urls = { ...@@ -111,175 +113,47 @@ model_urls = {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt' 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
},
'structure': {
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'dict_path': 'ppocr/utils/dict/table_dict.txt'
} }
}, },
'cls': 'cls': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' 'table': {
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.1' VERSION = '2.2'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path): def parse_args(mMain=True):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True, add_help=True):
import argparse import argparse
parser = init_args()
def str2bool(v): parser.add_help = mMain
return v.lower() in ("true", "t", "1")
if mMain:
parser = argparse.ArgumentParser(add_help=add_help)
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch') parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--type", type=str, default='ocr')
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args() return parser.parse_args()
else: else:
return argparse.Namespace( inference_args_dict = {}
use_gpu=True, for action in parser._actions:
ir_optim=True, inference_args_dict[action.dest] = action.default
use_tensorrt=False, return argparse.Namespace(**inference_args_dict)
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=6,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=6,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False)
class PaddleOCR(predict_system.TextSystem): def parse_lang(lang):
def __init__(self, **kwargs):
"""
paddleocr package
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_params = parse_args(mMain=False, add_help=False)
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
latin_lang = [ latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
...@@ -308,45 +182,57 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -308,45 +182,57 @@ class PaddleOCR(predict_system.TextSystem):
model_urls['rec'].keys(), lang) model_urls['rec'].keys(), lang)
if lang == "ch": if lang == "ch":
det_lang = "ch" det_lang = "ch"
elif lang == 'structure':
det_lang = 'structure'
else: else:
det_lang = "en" det_lang = "en"
use_inner_dict = False return lang, det_lang
if postprocess_params.rec_char_dict_path is None:
use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ class PaddleOCR(predict_system.TextSystem):
'dict_path'] def __init__(self, **kwargs):
"""
paddleocr package
args:
**kwargs: other params show in paddleocr --help
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if not params.show_log:
logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls
lang, det_lang = parse_lang(params.lang)
# init model dir # init model dir
if postprocess_params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
# download model
maybe_download(postprocess_params.det_model_dir,
model_urls['det'][det_lang]) model_urls['det'][det_lang])
maybe_download(postprocess_params.rec_model_dir, params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
model_urls['rec'][lang]['url']) model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
model_urls['cls'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.cls_model_dir, cls_url)
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0) sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0) sys.exit(0)
if use_inner_dict:
postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path)
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
print(params)
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=False): def ocr(self, img, det=True, rec=True, cls=True):
""" """
ocr with paddleocr ocr with paddleocr
args: args:
...@@ -358,9 +244,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -358,9 +244,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false') logger.error('When input a list of images, det must be false')
exit(0) exit(0)
if cls == False: if cls == True and self.use_angle_cls == False:
self.use_angle_cls = False
elif cls == True and self.use_angle_cls == False:
logger.warning( logger.warning(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
) )
...@@ -382,7 +266,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -382,7 +266,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2: if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img, cls)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec: elif det and not rec:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
...@@ -392,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -392,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls: if self.use_angle_cls and cls:
img, cls_res, elapse = self.text_classifier(img) img, cls_res, elapse = self.text_classifier(img)
if not rec: if not rec:
return cls_res return cls_res
...@@ -400,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -400,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem):
return rec_res return rec_res
class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if not params.show_log:
logger.setLevel(logging.INFO)
lang, det_lang = parse_lang(params.lang)
# init model dir
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
model_urls['det'][det_lang])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
model_urls['rec'][lang]['url'])
params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
model_urls['table']['url'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.table_model_dir, table_url)
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
if params.table_char_dict_path is None:
params.table_char_dict_path = str(Path(__file__).parent / model_urls['table']['dict_path'])
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main(): def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg'] image_file_list = ['tmp.jpg']
else: else:
...@@ -412,14 +349,29 @@ def main(): ...@@ -412,14 +349,29 @@ def main():
if len(image_file_list) == 0: if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir)) logger.error('no images find in {}'.format(args.image_dir))
return return
if args.type == 'ocr':
engine = PaddleOCR(**(args.__dict__))
elif args.type == 'structure':
engine = PPStructure(**(args.__dict__))
else:
raise NotImplementedError
ocr_engine = PaddleOCR(**(args.__dict__))
for img_path in image_file_list: for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10)) logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = ocr_engine.ocr(img_path, if args.type == 'ocr':
result = engine.ocr(img_path,
det=args.det, det=args.det,
rec=args.rec, rec=args.rec,
cls=args.use_angle_cls) cls=args.use_angle_cls)
if result is not None: if result is not None:
for line in result: for line in result:
logger.info(line) logger.info(line)
elif args.type == 'structure':
result = engine(img_path)
save_structure_res(result, args.output, img_name)
for item in result:
item.pop('img')
logger.info(item)
...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators ...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
......
...@@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop ...@@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import *
def transform(data, ops=None): def transform(data, ops=None):
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import cv2
import random
import numpy as np
from PIL import Image
from shapely.geometry import Polygon
from ppocr.data.imaug.iaa_augment import IaaAugment
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
from tools.infer.utility import get_rotate_crop_image
class CopyPaste(object):
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
self.ext_data_num = 1
self.objects_paste_ratio = objects_paste_ratio
self.limit_paste = limit_paste
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
self.aug = IaaAugment(augmenter_args)
def __call__(self, data):
src_img = data['image']
src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist()
ext_data = data['ext_data'][0]
ext_image = ext_data['image']
ext_polys = ext_data['polys']
ext_ignores = ext_data['ignore_tags']
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
select_num = max(
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
random.shuffle(indexs)
select_idxs = indexs[:select_num]
select_polys = ext_polys[select_idxs]
select_ignores = ext_ignores[select_idxs]
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
src_img = Image.fromarray(src_img).convert('RGBA')
for poly, tag in zip(select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None:
src_polys.append(box)
src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
h, w = src_img.shape[:2]
src_polys = np.array(src_polys)
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
data['image'] = src_img
data['polys'] = src_polys
data['ignore_tags'] = np.array(src_ignores)
return data
def paste_img(self, src_img, box_img, src_polys):
box_img_pil = Image.fromarray(box_img).convert('RGBA')
src_w, src_h = src_img.size
box_w, box_h = box_img_pil.size
angle = np.random.randint(0, 360)
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
box = rotate_bbox(box_img, box, angle)[0]
box_img_pil = box_img_pil.rotate(angle, expand=1)
box_w, box_h = box_img_pil.width, box_img_pil.height
if src_w - box_w < 0 or src_h - box_h < 0:
return src_img, None
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
src_h - box_h)
if paste_x is None:
return src_img, None
box[:, 0] += paste_x
box[:, 1] += paste_y
r, g, b, A = box_img_pil.split()
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
return src_img, box
def select_coord(self, src_polys, box, endx, endy):
if self.limit_paste:
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
), box[:, 0].max(), box[:, 1].max()
for _ in range(50):
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
xmin1 = xmin + paste_x
xmax1 = xmax + paste_x
ymin1 = ymin + paste_y
ymax1 = ymax + paste_y
num_poly_in_rect = 0
for poly in src_polys:
if not is_poly_outside_rect(poly, xmin1, ymin1,
xmax1 - xmin1, ymax1 - ymin1):
num_poly_in_rect += 1
break
if num_poly_in_rect == 0:
return paste_x, paste_y
return None, None
else:
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
return paste_x, paste_y
def get_union(pD, pG):
return Polygon(pD).union(Polygon(pG)).area
def get_intersection_over_union(pD, pG):
return get_intersection(pD, pG) / get_union(pD, pG)
def get_intersection(pD, pG):
return Polygon(pD).intersection(Polygon(pG)).area
def rotate_bbox(img, text_polys, angle, scale=1):
"""
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
Args:
img: np.ndarray
text_polys: np.ndarray N*4*2
angle: int
scale: int
Returns:
"""
w = img.shape[1]
h = img.shape[0]
rangle = np.deg2rad(angle)
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
# ---------------------- rotate box ----------------------
rot_text_polys = list()
for bbox in text_polys:
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
rot_text_polys.append([point1, point2, point3, point4])
return np.array(rot_text_polys, dtype=np.float32)
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data
\ No newline at end of file
...@@ -19,6 +19,7 @@ from __future__ import unicode_literals ...@@ -19,6 +19,7 @@ from __future__ import unicode_literals
import numpy as np import numpy as np
import string import string
import json
class ClsLabelEncode(object): class ClsLabelEncode(object):
...@@ -39,7 +40,6 @@ class DetLabelEncode(object): ...@@ -39,7 +40,6 @@ class DetLabelEncode(object):
pass pass
def __call__(self, data): def __call__(self, data):
import json
label = data['label'] label = data['label']
label = json.loads(label) label = json.loads(label)
nBox = len(label) nBox = len(label)
...@@ -53,6 +53,8 @@ class DetLabelEncode(object): ...@@ -53,6 +53,8 @@ class DetLabelEncode(object):
txt_tags.append(True) txt_tags.append(True)
else: else:
txt_tags.append(False) txt_tags.append(False)
if len(boxes) == 0:
return None
boxes = self.expand_points_num(boxes) boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32) boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool) txt_tags = np.array(txt_tags, dtype=np.bool)
...@@ -351,3 +353,171 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -351,3 +353,171 @@ class SRNLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class TableLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight=1.0,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
structure = self.encode(structure, 'elem')
if structure is None:
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
)
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones(
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
return text_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_character[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
...@@ -163,7 +163,7 @@ class DetResizeForTest(object): ...@@ -163,7 +163,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w) img, (ratio_h, ratio_w)
""" """
limit_side_len = self.limit_side_len limit_side_len = self.limit_side_len
h, w, _ = img.shape h, w, c = img.shape
# limit the max side # limit the max side
if self.limit_type == 'max': if self.limit_type == 'max':
...@@ -174,7 +174,7 @@ class DetResizeForTest(object): ...@@ -174,7 +174,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
else: elif self.limit_type == 'min':
if min(h, w) < limit_side_len: if min(h, w) < limit_side_len:
if h < w: if h < w:
ratio = float(limit_side_len) / h ratio = float(limit_side_len) / h
...@@ -182,6 +182,10 @@ class DetResizeForTest(object): ...@@ -182,6 +182,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import random
from paddle.io import Dataset
import json
from .imaug import transform, create_operators
class PubTabDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PubTabDataSet, self).__init__()
self.logger = logger
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __getitem__(self, idx):
try:
data_line = self.data_lines[idx]
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
...@@ -69,12 +69,42 @@ class SimpleDataSet(Dataset): ...@@ -69,12 +69,42 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:2]
ext_data = []
while len(ext_data) < ext_data_num:
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
))]
data_line = self.data_lines[file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
continue
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
if data is None:
continue
ext_data.append(data)
return ext_data
def __getitem__(self, idx): def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx] data_line = self.data_lines[file_idx]
try: try:
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").strip("\r").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0] file_name = substr[0]
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
...@@ -84,6 +114,7 @@ class SimpleDataSet(Dataset): ...@@ -84,6 +114,7 @@ class SimpleDataSet(Dataset):
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops) outs = transform(data, self.ops)
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
......
...@@ -13,28 +13,39 @@ ...@@ -13,28 +13,39 @@
# limitations under the License. # limitations under the License.
import copy import copy
import paddle
import paddle.nn as nn
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
def build_loss(config): # rec loss
# det loss from .rec_ctc_loss import CTCLoss
from .det_db_loss import DBLoss from .rec_att_loss import AttentionLoss
from .det_east_loss import EASTLoss from .rec_srn_loss import SRNLoss
from .det_sast_loss import SASTLoss
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
# rec loss # basic loss function
from .rec_ctc_loss import CTCLoss from .basic_loss import DistanceLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
# cls loss # combined loss function
from .cls_loss import ClsLoss from .combined_loss import CombinedLoss
# e2e loss # table loss
from .e2e_pg_loss import PGLoss from .table_att_loss import TableAttentionLoss
def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss'] 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format( assert module_name in support_dict, Exception('loss only support {}'.format(
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer):
def __init__(self, epsilon=None):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, x, label):
loss_dict = {}
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
return loss
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2])
elif reduction=="none" or reduction is None:
return loss
else:
loss = paddle.sum(loss, axis=[1,2])
return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
def __init__(self, act=None):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
if act == "softmax":
self.act = nn.Softmax(axis=-1)
elif act == "sigmoid":
self.act = nn.Sigmoid()
else:
self.act = None
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
if len(out1.shape) < 2:
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
else:
loss = self.jskl_loss(out1, out2)
return loss
class DistanceLoss(nn.Layer):
"""
DistanceLoss:
mode: loss mode
"""
def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l2":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
def forward(self, x, y):
return self.loss_func(x, y)
...@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer): ...@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
super(ClsLoss, self).__init__() super(ClsLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean') self.loss_func = nn.CrossEntropyLoss(reduction='mean')
def __call__(self, predicts, batch): def forward(self, predicts, batch):
label = batch[1] label = batch[1]
loss = self.loss_func(input=predicts, label=label) loss = self.loss_func(input=predicts, label=label)
return {'loss': loss} return {'loss': loss}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment