"deploy/cpp_infer/src_rec/preprocess_op.cpp" did not exist on "028ccc45a00204910f9fb74b06d3a3e6576d7f5c"
Unverified Commit adc62fcd authored by topduke's avatar topduke Committed by GitHub
Browse files

Merge branch 'dygraph' into dygraph

parents 8227ad1b a81b88a0
...@@ -19,154 +19,119 @@ __dir__ = os.path.dirname(__file__) ...@@ -19,154 +19,119 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, '')) sys.path.append(os.path.join(__dir__, ''))
import cv2 import cv2
import logging
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import tarfile
import requests
from tqdm import tqdm
from tools.infer import predict_system from tools.infer import predict_system
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr, init_args, str2bool from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, str2bool
from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
model_urls = { model_urls = {
'det': { 'det': {
'ch': 'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en': 'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar' 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
}, },
'rec': { 'rec': {
'ch': { 'ch': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt' 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}, },
'en': { 'en': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt' 'dict_path': './ppocr/utils/en_dict.txt'
}, },
'french': { 'french': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt' 'dict_path': './ppocr/utils/dict/french_dict.txt'
}, },
'german': { 'german': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt' 'dict_path': './ppocr/utils/dict/german_dict.txt'
}, },
'korean': { 'korean': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt' 'dict_path': './ppocr/utils/dict/korean_dict.txt'
}, },
'japan': { 'japan': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt' 'dict_path': './ppocr/utils/dict/japan_dict.txt'
}, },
'chinese_cht': { 'chinese_cht': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
}, },
'ta': { 'ta': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt' 'dict_path': './ppocr/utils/dict/ta_dict.txt'
}, },
'te': { 'te': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt' 'dict_path': './ppocr/utils/dict/te_dict.txt'
}, },
'ka': { 'ka': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt' 'dict_path': './ppocr/utils/dict/ka_dict.txt'
}, },
'latin': { 'latin': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt' 'dict_path': './ppocr/utils/dict/latin_dict.txt'
}, },
'arabic': { 'arabic': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt' 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
}, },
'cyrillic': { 'cyrillic': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
}, },
'devanagari': { 'devanagari': {
'url': 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt' 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
},
'structure': {
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'dict_path': 'ppocr/utils/dict/table_dict.txt'
} }
}, },
'cls': 'cls': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar' 'table': {
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.1' VERSION = '2.2'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
def download_with_progressbar(url, save_path):
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("Something went wrong while downloading models")
sys.exit(0)
def maybe_download(model_storage_directory, url):
# using custom model
tar_file_name_list = [
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
]
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
os.makedirs(model_storage_directory, exist_ok=True)
download_with_progressbar(url, tmp_path)
with tarfile.open(tmp_path, 'r') as tarObj:
for member in tarObj.getmembers():
filename = None
for tar_file_name in tar_file_name_list:
if tar_file_name in member.name:
filename = tar_file_name
if filename is None:
continue
file = tarObj.extractfile(member)
with open(
os.path.join(model_storage_directory, filename),
'wb') as f:
f.write(file.read())
os.remove(tmp_path)
def parse_args(mMain=True): def parse_args(mMain=True):
import argparse import argparse
parser = init_args() parser = init_args()
...@@ -174,9 +139,10 @@ def parse_args(mMain=True): ...@@ -174,9 +139,10 @@ def parse_args(mMain=True):
parser.add_argument("--lang", type=str, default='ch') parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--type", type=str, default='ocr')
for action in parser._actions: for action in parser._actions:
if action.dest == 'rec_char_dict_path': if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
action.default = None action.default = None
if mMain: if mMain:
return parser.parse_args() return parser.parse_args()
...@@ -187,6 +153,42 @@ def parse_args(mMain=True): ...@@ -187,6 +153,42 @@ def parse_args(mMain=True):
return argparse.Namespace(**inference_args_dict) return argparse.Namespace(**inference_args_dict)
def parse_lang(lang):
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
'gom', 'sa', 'bgc'
]
if lang in latin_lang:
lang = "latin"
elif lang in arabic_lang:
lang = "arabic"
elif lang in cyrillic_lang:
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
elif lang == 'structure':
det_lang = 'structure'
else:
det_lang = "en"
return lang, det_lang
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
...@@ -194,75 +196,41 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -194,75 +196,41 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args(mMain=False) params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls if not params.show_log:
lang = postprocess_params.lang logger.setLevel(logging.INFO)
latin_lang = [ self.use_angle_cls = params.use_angle_cls
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', lang, det_lang = parse_lang(params.lang)
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
'gom', 'sa', 'bgc'
]
if lang in latin_lang:
lang = "latin"
elif lang in arabic_lang:
lang = "arabic"
elif lang in cyrillic_lang:
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
else:
det_lang = "en"
use_inner_dict = False
if postprocess_params.rec_char_dict_path is None:
use_inner_dict = True
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
'dict_path']
# init model dir # init model dir
if postprocess_params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
'det', det_lang) model_urls['det'][det_lang])
if postprocess_params.rec_model_dir is None: params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
'rec', lang) model_urls['rec'][lang]['url'])
if postprocess_params.cls_model_dir is None: params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
print(postprocess_params) model_urls['cls'])
# download model # download model
maybe_download(postprocess_params.det_model_dir, maybe_download(params.det_model_dir, det_url)
model_urls['det'][det_lang]) maybe_download(params.rec_model_dir, rec_url)
maybe_download(postprocess_params.rec_model_dir, maybe_download(params.cls_model_dir, cls_url)
model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
sys.exit(0) sys.exit(0)
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: if params.rec_algorithm not in SUPPORT_REC_MODEL:
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
sys.exit(0) sys.exit(0)
if use_inner_dict:
postprocess_params.rec_char_dict_path = str(
Path(__file__).parent / postprocess_params.rec_char_dict_path)
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
print(params)
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(params)
def ocr(self, img, det=True, rec=True, cls=True): def ocr(self, img, det=True, rec=True, cls=True):
""" """
...@@ -316,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -316,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem):
return rec_res return rec_res
class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if not params.show_log:
logger.setLevel(logging.INFO)
lang, det_lang = parse_lang(params.lang)
# init model dir
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
model_urls['det'][det_lang])
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
model_urls['rec'][lang]['url'])
params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
model_urls['table']['url'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.table_model_dir, table_url)
if params.rec_char_dict_path is None:
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
if params.table_char_dict_path is None:
params.table_char_dict_path = str(Path(__file__).parent / model_urls['table']['dict_path'])
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main(): def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg'] image_file_list = ['tmp.jpg']
else: else:
...@@ -328,14 +349,29 @@ def main(): ...@@ -328,14 +349,29 @@ def main():
if len(image_file_list) == 0: if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir)) logger.error('no images find in {}'.format(args.image_dir))
return return
if args.type == 'ocr':
engine = PaddleOCR(**(args.__dict__))
elif args.type == 'structure':
engine = PPStructure(**(args.__dict__))
else:
raise NotImplementedError
ocr_engine = PaddleOCR(**(args.__dict__))
for img_path in image_file_list: for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10)) logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = ocr_engine.ocr(img_path, if args.type == 'ocr':
result = engine.ocr(img_path,
det=args.det, det=args.det,
rec=args.rec, rec=args.rec,
cls=args.use_angle_cls) cls=args.use_angle_cls)
if result is not None: if result is not None:
for line in result: for line in result:
logger.info(line) logger.info(line)
elif args.type == 'structure':
result = engine(img_path)
save_structure_res(result, args.output, img_name)
for item in result:
item.pop('img')
logger.info(item)
...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators ...@@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp) ...@@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
......
...@@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop ...@@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import *
def transform(data, ops=None): def transform(data, ops=None):
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import cv2
import random
import numpy as np
from PIL import Image
from shapely.geometry import Polygon
from ppocr.data.imaug.iaa_augment import IaaAugment
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
from tools.infer.utility import get_rotate_crop_image
class CopyPaste(object):
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
self.ext_data_num = 1
self.objects_paste_ratio = objects_paste_ratio
self.limit_paste = limit_paste
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
self.aug = IaaAugment(augmenter_args)
def __call__(self, data):
src_img = data['image']
src_polys = data['polys'].tolist()
src_ignores = data['ignore_tags'].tolist()
ext_data = data['ext_data'][0]
ext_image = ext_data['image']
ext_polys = ext_data['polys']
ext_ignores = ext_data['ignore_tags']
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
select_num = max(
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
random.shuffle(indexs)
select_idxs = indexs[:select_num]
select_polys = ext_polys[select_idxs]
select_ignores = ext_ignores[select_idxs]
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
src_img = Image.fromarray(src_img).convert('RGBA')
for poly, tag in zip(select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys)
if box is not None:
src_polys.append(box)
src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
h, w = src_img.shape[:2]
src_polys = np.array(src_polys)
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
data['image'] = src_img
data['polys'] = src_polys
data['ignore_tags'] = np.array(src_ignores)
return data
def paste_img(self, src_img, box_img, src_polys):
box_img_pil = Image.fromarray(box_img).convert('RGBA')
src_w, src_h = src_img.size
box_w, box_h = box_img_pil.size
angle = np.random.randint(0, 360)
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
box = rotate_bbox(box_img, box, angle)[0]
box_img_pil = box_img_pil.rotate(angle, expand=1)
box_w, box_h = box_img_pil.width, box_img_pil.height
if src_w - box_w < 0 or src_h - box_h < 0:
return src_img, None
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
src_h - box_h)
if paste_x is None:
return src_img, None
box[:, 0] += paste_x
box[:, 1] += paste_y
r, g, b, A = box_img_pil.split()
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
return src_img, box
def select_coord(self, src_polys, box, endx, endy):
if self.limit_paste:
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
), box[:, 0].max(), box[:, 1].max()
for _ in range(50):
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
xmin1 = xmin + paste_x
xmax1 = xmax + paste_x
ymin1 = ymin + paste_y
ymax1 = ymax + paste_y
num_poly_in_rect = 0
for poly in src_polys:
if not is_poly_outside_rect(poly, xmin1, ymin1,
xmax1 - xmin1, ymax1 - ymin1):
num_poly_in_rect += 1
break
if num_poly_in_rect == 0:
return paste_x, paste_y
return None, None
else:
paste_x = random.randint(0, endx)
paste_y = random.randint(0, endy)
return paste_x, paste_y
def get_union(pD, pG):
return Polygon(pD).union(Polygon(pG)).area
def get_intersection_over_union(pD, pG):
return get_intersection(pD, pG) / get_union(pD, pG)
def get_intersection(pD, pG):
return Polygon(pD).intersection(Polygon(pG)).area
def rotate_bbox(img, text_polys, angle, scale=1):
"""
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
Args:
img: np.ndarray
text_polys: np.ndarray N*4*2
angle: int
scale: int
Returns:
"""
w = img.shape[1]
h = img.shape[0]
rangle = np.deg2rad(angle)
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
rot_mat[0, 2] += rot_move[0]
rot_mat[1, 2] += rot_move[1]
# ---------------------- rotate box ----------------------
rot_text_polys = list()
for bbox in text_polys:
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
rot_text_polys.append([point1, point2, point3, point4])
return np.array(rot_text_polys, dtype=np.float32)
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import six
import cv2
import numpy as np
class GenTableMask(object):
""" gen table mask """
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
return box_list, projection_map
def projection_cx(self, box_img):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
erode = cv2.erode(thresh1, kernel, iterations=1)
else:
erode = thresh1
# 水平膨胀
kernel = np.ones((1, 5), np.uint8)
erosion = cv2.dilate(erode, kernel, iterations=1)
# 水平投影
projection_map = np.ones_like(erosion)
project_val_array = [0 for _ in range(0, h)]
for j in range(0, h):
for i in range(0, w):
if erosion[j, i] == 255:
project_val_array[j] += 1
# 根据数组,获取切割点
start_idx = 0 # 记录进入字符区的索引
end_idx = 0 # 记录进入空白区域的索引
in_text = False # 是否遍历到了字符区内
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
continue
box_list.append((start_idx, end_idx + 1))
if in_text:
box_list.append((start_idx, h - 1))
# 绘制投影直方图
for j in range(0, h):
for i in range(0, project_val_array[j]):
projection_map[j, i] = 0
split_bbox_list = []
if len(box_list) > 1:
for i, (h_start, h_end) in enumerate(box_list):
if i == 0:
h_start = 0
if i == len(box_list):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
h_end += 1
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
split_bbox_list.append([w_start, h_start, w_end, h_end])
else:
split_bbox_list.append([0, 0, w, h])
return split_bbox_list
def shrink_bbox(self, bbox):
left, top, right, bottom = bbox
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
left_new = left + sh_w
right_new = right - sh_w
top_new = top + sh_h
bottom_new = bottom - sh_h
if left_new >= right_new:
left_new = left
right_new = right
if top_new >= bottom_new:
top_new = top
bottom_new = bottom
return [left_new, top_new, right_new, bottom_new]
def __call__(self, data):
img = data['image']
cells = data['cells']
height, width = img.shape[0:2]
if self.mask_type == 1:
mask_img = np.zeros((height, width), dtype=np.float32)
else:
mask_img = np.zeros((height, width, 3), dtype=np.float32)
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
left, top, right, bottom = bbox
box_img = img[top:bottom, left:right, :].copy()
split_bbox_list = self.projection_cx(box_img)
for sno in range(len(split_bbox_list)):
split_bbox_list[sno][0] += left
split_bbox_list[sno][1] += top
split_bbox_list[sno][2] += left
split_bbox_list[sno][3] += top
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
super(PaddingTableImage, self).__init__()
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
return data
\ No newline at end of file
This diff is collapsed.
...@@ -113,7 +113,7 @@ class NormalizeImage(object): ...@@ -113,7 +113,7 @@ class NormalizeImage(object):
assert isinstance(img, assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage" np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = ( data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std img.astype('float32') * self.scale - self.mean) / self.std
return data return data
...@@ -195,7 +195,7 @@ class DetResizeForTest(object): ...@@ -195,7 +195,7 @@ class DetResizeForTest(object):
img, (ratio_h, ratio_w) img, (ratio_h, ratio_w)
""" """
limit_side_len = self.limit_side_len limit_side_len = self.limit_side_len
h, w, _ = img.shape h, w, c = img.shape
# limit the max side # limit the max side
if self.limit_type == 'max': if self.limit_type == 'max':
...@@ -206,7 +206,7 @@ class DetResizeForTest(object): ...@@ -206,7 +206,7 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
else: elif self.limit_type == 'min':
if min(h, w) < limit_side_len: if min(h, w) < limit_side_len:
if h < w: if h < w:
ratio = float(limit_side_len) / h ratio = float(limit_side_len) / h
...@@ -214,6 +214,10 @@ class DetResizeForTest(object): ...@@ -214,6 +214,10 @@ class DetResizeForTest(object):
ratio = float(limit_side_len) / w ratio = float(limit_side_len) / w
else: else:
ratio = 1. ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
......
This diff is collapsed.
...@@ -69,12 +69,42 @@ class SimpleDataSet(Dataset): ...@@ -69,12 +69,42 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:2]
ext_data = []
while len(ext_data) < ext_data_num:
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
))]
data_line = self.data_lines[file_idx]
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
continue
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
data = transform(data, load_data_ops)
if data is None:
continue
ext_data.append(data)
return ext_data
def __getitem__(self, idx): def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx] data_line = self.data_lines[file_idx]
try: try:
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").strip("\r").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0] file_name = substr[0]
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
...@@ -84,6 +114,7 @@ class SimpleDataSet(Dataset): ...@@ -84,6 +114,7 @@ class SimpleDataSet(Dataset):
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops) outs = transform(data, self.ops)
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment