"tools/vscode:/vscode.git/clone" did not exist on "7c7972f3f1997b4cda8b17a4e887e4d7451ada52"
Unverified commit 006d84bf, authored by 崔浩, committed by GitHub

Merge branch 'PaddlePaddle:dygraph' into dygraph

parents 302ca30c 8beeb84c
doc/table/1.png: image replaced (263 KB → 758 KB)
doc/table/table.jpg: image replaced (24.1 KB → 58 KB)
@@ -33,104 +33,141 @@ from tools.infer.utility import draw_ocr, str2bool
from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res

__all__ = [
    'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result',
    'save_structure_res', 'download_with_progressbar'
]

SUPPORT_DET_MODEL = ['DB']
VERSION = '2.2.1'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")

DEFAULT_MODEL_VERSION = '2.0'
MODEL_URLS = {
    '2.1': {
        'det': {
            'ch': {
                'url':
                'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar',
            },
        },
        'rec': {
            'ch': {
                'url':
                'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
                'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
            }
        }
    },
    '2.0': {
        'det': {
            'ch': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
            },
            'en': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
            },
            'structure': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
            }
        },
        'rec': {
            'ch': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
            },
            'en': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/en_dict.txt'
            },
            'french': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/french_dict.txt'
            },
            'german': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/german_dict.txt'
            },
            'korean': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/korean_dict.txt'
            },
            'japan': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/japan_dict.txt'
            },
            'chinese_cht': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
            },
            'ta': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/ta_dict.txt'
            },
            'te': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/te_dict.txt'
            },
            'ka': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/ka_dict.txt'
            },
            'latin': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/latin_dict.txt'
            },
            'arabic': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/arabic_dict.txt'
            },
            'cyrillic': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
            },
            'devanagari': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
                'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
            },
            'structure': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
                'dict_path': 'ppocr/utils/dict/table_dict.txt'
            }
        },
        'cls': {
            'ch': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
            }
        },
        'table': {
            'en': {
                'url':
                'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
                'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
            }
        }
    }
}
def parse_args(mMain=True):
    import argparse
@@ -140,6 +177,7 @@ def parse_args(mMain=True):
    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    parser.add_argument("--type", type=str, default='ocr')
    parser.add_argument("--version", type=str, default='2.1')
    for action in parser._actions:
        if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
@@ -155,19 +193,19 @@ def parse_args(mMain=True):
def parse_lang(lang):
    latin_lang = [
        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
        'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
        'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
        'sw', 'tl', 'tr', 'uz', 'vi'
    ]
    arabic_lang = ['ar', 'fa', 'ug', 'ur']
    cyrillic_lang = [
        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
        'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
    ]
    devanagari_lang = [
        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
        'sa', 'bgc'
    ]
    if lang in latin_lang:
        lang = "latin"
@@ -177,9 +215,9 @@ def parse_lang(lang):
        lang = "cyrillic"
    elif lang in devanagari_lang:
        lang = "devanagari"
    assert lang in MODEL_URLS[DEFAULT_MODEL_VERSION][
        'rec'], 'param lang must in {}, but got {}'.format(
            MODEL_URLS[DEFAULT_MODEL_VERSION]['rec'].keys(), lang)
    if lang == "ch":
        det_lang = "ch"
    elif lang == 'structure':
@@ -189,6 +227,35 @@ def parse_lang(lang):
    return lang, det_lang
def get_model_config(version, model_type, lang):
    if version not in MODEL_URLS:
        logger.warning('version {} not in {}, use version {} instead'.format(
            version, MODEL_URLS.keys(), DEFAULT_MODEL_VERSION))
        version = DEFAULT_MODEL_VERSION
    if model_type not in MODEL_URLS[version]:
        if model_type in MODEL_URLS[DEFAULT_MODEL_VERSION]:
            logger.warning(
                'version {} not support {} models, use version {} instead'.
                format(version, model_type, DEFAULT_MODEL_VERSION))
            version = DEFAULT_MODEL_VERSION
        else:
            logger.error('{} models is not support, we only support {}'.format(
                model_type, MODEL_URLS[DEFAULT_MODEL_VERSION].keys()))
            sys.exit(-1)
    if lang not in MODEL_URLS[version][model_type]:
        if lang in MODEL_URLS[DEFAULT_MODEL_VERSION][model_type]:
            logger.warning('lang {} is not support in {}, use {} instead'.
                           format(lang, version, DEFAULT_MODEL_VERSION))
            version = DEFAULT_MODEL_VERSION
        else:
            logger.error(
                'lang {} is not support, we only support {} for {} models'.
                format(lang, MODEL_URLS[DEFAULT_MODEL_VERSION][model_type].keys(
                ), model_type))
            sys.exit(-1)
    return MODEL_URLS[version][model_type][lang]
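For readers of the patch, a minimal standalone sketch of the two-level fallback that get_model_config implements (the toy dict, values, and function name below are hypothetical, not part of the patch): an unknown version falls back to DEFAULT_MODEL_VERSION, and a known version that lacks the requested lang falls back again.

TOY_MODEL_URLS = {
    '2.1': {'rec': {'ch': {'url': 'https://example.com/v2.1_ch_rec.tar'}}},
    '2.0': {'rec': {'ch': {'url': 'https://example.com/v2.0_ch_rec.tar'},
                    'en': {'url': 'https://example.com/v2.0_en_rec.tar'}}},
}
TOY_DEFAULT = '2.0'


def toy_get_model_config(version, model_type, lang):
    # unknown version -> default version
    if version not in TOY_MODEL_URLS:
        version = TOY_DEFAULT
    # known version but unsupported lang -> default version
    if lang not in TOY_MODEL_URLS[version][model_type]:
        version = TOY_DEFAULT
    return TOY_MODEL_URLS[version][model_type][lang]


# '9.9' is unknown, so it resolves exactly like '2.0'
assert toy_get_model_config('9.9', 'rec', 'ch') == \
    toy_get_model_config('2.0', 'rec', 'ch')
# '2.1' exists but has no 'en' recognizer, so '2.0' serves it
assert toy_get_model_config('2.1', 'rec', 'en')['url'].endswith('v2.0_en_rec.tar')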
class PaddleOCR(predict_system.TextSystem):
    def __init__(self, **kwargs):
        """
@@ -204,15 +271,21 @@ class PaddleOCR(predict_system.TextSystem):
        lang, det_lang = parse_lang(params.lang)

        # init model dir
        det_model_config = get_model_config(params.version, 'det', det_lang)
        params.det_model_dir, det_url = confirm_model_dir_url(
            params.det_model_dir,
            os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
            det_model_config['url'])
        rec_model_config = get_model_config(params.version, 'rec', lang)
        params.rec_model_dir, rec_url = confirm_model_dir_url(
            params.rec_model_dir,
            os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
            rec_model_config['url'])
        cls_model_config = get_model_config(params.version, 'cls', 'ch')
        params.cls_model_dir, cls_url = confirm_model_dir_url(
            params.cls_model_dir,
            os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
            cls_model_config['url'])
        # download model
        maybe_download(params.det_model_dir, det_url)
        maybe_download(params.rec_model_dir, rec_url)
@@ -226,7 +299,8 @@ class PaddleOCR(predict_system.TextSystem):
            sys.exit(0)

        if params.rec_char_dict_path is None:
            params.rec_char_dict_path = str(
                Path(__file__).parent / rec_model_config['dict_path'])

        print(params)
        # init det_model and rec_model
@@ -293,24 +367,32 @@ class PPStructure(OCRSystem):
        lang, det_lang = parse_lang(params.lang)

        # init model dir
        det_model_config = get_model_config(params.version, 'det', det_lang)
        params.det_model_dir, det_url = confirm_model_dir_url(
            params.det_model_dir,
            os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
            det_model_config['url'])
        rec_model_config = get_model_config(params.version, 'rec', lang)
        params.rec_model_dir, rec_url = confirm_model_dir_url(
            params.rec_model_dir,
            os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
            rec_model_config['url'])
        table_model_config = get_model_config(params.version, 'table', 'en')
        params.table_model_dir, table_url = confirm_model_dir_url(
            params.table_model_dir,
            os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
            table_model_config['url'])
        # download model
        maybe_download(params.det_model_dir, det_url)
        maybe_download(params.rec_model_dir, rec_url)
        maybe_download(params.table_model_dir, table_url)

        if params.rec_char_dict_path is None:
            params.rec_char_dict_path = str(
                Path(__file__).parent / rec_model_config['dict_path'])
        if params.table_char_dict_path is None:
            params.table_char_dict_path = str(
                Path(__file__).parent / table_model_config['dict_path'])

        print(params)
        super().__init__(params)
@@ -374,4 +456,3 @@ def main():
        for item in result:
            item.pop('img')
            logger.info(item)
@@ -49,14 +49,12 @@ def term_mp(sig_num, frame):
        os.killpg(pgid, signal.SIGKILL)


def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = [
        'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
        'DataSet only support {}'.format(support_dict))
@@ -96,4 +94,8 @@ def build_dataloader(config, mode, device, logger, seed=None):
        return_list=True,
        use_shared_memory=use_shared_memory)

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
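As context for the relocated signal handlers: the pattern kills the whole process group so DataLoader worker subprocesses exit together with the main process on ctrl+c. A minimal standalone sketch of that pattern, assuming a POSIX system (the handler body here is hypothetical, not part of the patch):

import os
import signal


def toy_term_mp(sig_num, frame):
    # kill the whole process group so forked workers die with the parent
    pgid = os.getpgid(os.getpid())
    os.killpg(pgid, signal.SIGKILL)


# registered after the loader is built, as the patch does above
signal.signal(signal.SIGINT, toy_term_mp)
signal.signal(signal.SIGTERM, toy_term_mp)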
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
__all__ = ['ColorJitter']
class ColorJitter(object):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
        self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, data):
        image = data['image']
        image = self.aug(image)
        data['image'] = image
        return data
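A hypothetical usage sketch of the new op (the sample data and jitter strengths are invented here; paddle's ColorJitter accepts PIL images, so the sample is converted first):

import numpy as np
from PIL import Image

jitter = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.4)
arr = np.random.randint(0, 255, (48, 320, 3), dtype=np.uint8)
sample = {'image': Image.fromarray(arr)}
sample = jitter(sample)  # same spatial size, randomly shifted colors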
@@ -19,11 +19,13 @@ from __future__ import unicode_literals
from .iaa_augment import IaaAugment
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter
from .operators import *
from .label_ops import *
...
@@ -21,6 +21,8 @@ import numpy as np
import string
import json

from ppocr.utils.logging import get_logger


class ClsLabelEncode(object):
    def __init__(self, label_list, **kwargs):
@@ -92,31 +94,23 @@ class BaseRecLabelEncode(object):
    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False):
        self.max_text_len = max_text_length
        self.beg_str = "sos"
        self.end_str = "eos"
        self.lower = False

        if character_dict_path is None:
            logger = get_logger()
            logger.warning(
                "The character_dict_path is None, model can only recognize number and lower letters"
            )
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
            dict_character = list(self.character_str)
            self.lower = True
        else:
            self.character_str = ""
            with open(character_dict_path, "rb") as fin:
                lines = fin.readlines()
                for line in lines:
@@ -125,7 +119,6 @@ class BaseRecLabelEncode(object):
            if use_space_char:
                self.character_str += " "
            dict_character = list(self.character_str)
        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
@@ -147,7 +140,7 @@ class BaseRecLabelEncode(object):
        """
        if len(text) == 0 or len(text) > self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        text_list = []
        for char in text:
@@ -161,18 +154,47 @@ class BaseRecLabelEncode(object):
        return text_list
class NRTRLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(NRTRLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len - 1:
            return None
        data['length'] = np.array(len(text))
        text.insert(0, 2)
        text.append(3)
        text = text + [0] * (self.max_text_len - len(text))
        data['label'] = np.array(text)
        return data

    def add_special_char(self, dict_character):
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character
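A worked toy example of the label layout NRTRLabelEncode produces (the dictionary and text are hypothetical): the four special tokens occupy indices 0-3, so 2 and 3 mark start and end while 0 doubles as padding.

toy_dict = ['blank', '<unk>', '<s>', '</s>', 'a', 'b', 'c']
char2idx = {c: i for i, c in enumerate(toy_dict)}

max_text_len = 8
label = [char2idx[ch] for ch in "cab"]  # [6, 4, 5]
label.insert(0, 2)                      # prepend <s>
label.append(3)                         # append </s>
label = label + [0] * (max_text_len - len(label))
assert label == [2, 6, 4, 5, 3, 0, 0, 0]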
class CTCLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(CTCLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        text = data['label']
@@ -182,6 +204,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
        data['length'] = np.array(len(text))
        text = text + [0] * (self.max_text_len - len(text))
        data['label'] = np.array(text)

        label = [0] * len(self.character)
        for x in text:
            label[x] += 1
        data['label_ace'] = np.array(label)
        return data

    def add_special_char(self, dict_character):
@@ -193,12 +220,10 @@ class E2ELabelEncodeTest(BaseRecLabelEncode):
    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(E2ELabelEncodeTest, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        import json
@@ -267,12 +292,10 @@ class AttnLabelEncode(BaseRecLabelEncode):
    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(AttnLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
@@ -309,18 +332,46 @@ class AttnLabelEncode(BaseRecLabelEncode):
        return idx
class SEEDLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(SEEDLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        self.end_str = "eos"
        dict_character = dict_character + [self.end_str]
        return dict_character

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len:
            return None
        data['length'] = np.array(len(text)) + 1  # conclude eos
        text = text + [len(self.character) - 1] * (self.max_text_len - len(text))
        data['label'] = np.array(text)
        return data
class SRNLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length=25,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(SRNLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        dict_character = dict_character + [self.beg_str, self.end_str]
@@ -388,7 +439,6 @@ class TableLabelEncode(object):
        substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
        character_num = int(substr[0])
        elem_num = int(substr[1])
        for cno in range(1, 1 + character_num):
            character = lines[cno].decode('utf-8').strip("\r\n")
            list_character.append(character)
@@ -521,3 +571,47 @@ class TableLabelEncode(object):
        assert False, "Unsupport type %s in char_or_elem" \
            % char_or_elem
        return idx
class SARLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super(SARLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
        unknown_str = "<UKN>"
        padding_str = "<PAD>"
        dict_character = dict_character + [unknown_str]
        self.unknown_idx = len(dict_character) - 1
        dict_character = dict_character + [beg_end_str]
        self.start_idx = len(dict_character) - 1
        self.end_idx = len(dict_character) - 1
        dict_character = dict_character + [padding_str]
        self.padding_idx = len(dict_character) - 1
        return dict_character

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        if len(text) >= self.max_text_len - 1:
            return None
        data['length'] = np.array(len(text))
        target = [self.start_idx] + text + [self.end_idx]
        padded_text = [self.padding_idx for _ in range(self.max_text_len)]
        padded_text[:len(target)] = target
        data['label'] = np.array(padded_text)
        return data

    def get_ignored_tokens(self):
        return [self.padding_idx]
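A toy trace of SAR's target layout under the add_special_char logic above (the indices follow a hypothetical 3-character base dictionary): <UKN> lands at index 3, <BOS/EOS> at 4, serving as both start_idx and end_idx, and <PAD> at 5.

base_dict = ['a', 'b', 'c']
unknown_idx = len(base_dict)             # 3
start_idx = end_idx = unknown_idx + 1    # 4, one token covers BOS and EOS
padding_idx = start_idx + 1              # 5

max_text_len = 6
text = [0, 2]                            # encoded "ac"
target = [start_idx] + text + [end_idx]
padded_text = [padding_idx] * max_text_len
padded_text[:len(target)] = target
assert padded_text == [4, 0, 2, 4, 5, 5]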
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon
__all__ = ['MakePseGt']
class MakePseGt(object):
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''

    def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
        self.kernel_num = kernel_num
        self.min_shrink_ratio = min_shrink_ratio
        self.size = size

    def __call__(self, data):
        image = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']

        h, w, _ = image.shape
        short_edge = min(h, w)
        if short_edge < self.size:
            # keep short_size >= self.size
            scale = self.size / short_edge
            image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
            text_polys *= scale

        gt_kernels = []
        for i in range(1, self.kernel_num + 1):
            # s1->sn, from big to small
            rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1
                                                          ) * i
            text_kernel, ignore_tags = self.generate_kernel(
                image.shape[0:2], rate, text_polys, ignore_tags)
            gt_kernels.append(text_kernel)

        training_mask = np.ones(image.shape[0:2], dtype='uint8')
        for i in range(text_polys.shape[0]):
            if ignore_tags[i]:
                cv2.fillPoly(training_mask,
                             text_polys[i].astype(np.int32)[np.newaxis, :, :],
                             0)

        gt_kernels = np.array(gt_kernels)
        gt_kernels[gt_kernels > 0] = 1

        data['image'] = image
        data['polys'] = text_polys
        data['gt_kernels'] = gt_kernels[0:]
        data['gt_text'] = gt_kernels[0]
        data['mask'] = training_mask.astype('float32')
        return data

    def generate_kernel(self, img_size, shrink_ratio, text_polys,
                        ignore_tags=None):
        h, w = img_size
        text_kernel = np.zeros((h, w), dtype=np.float32)
        for i, poly in enumerate(text_polys):
            polygon = Polygon(poly)
            distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
                polygon.length + 1e-6)
            subject = [tuple(l) for l in poly]
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(subject, pyclipper.JT_ROUND,
                        pyclipper.ET_CLOSEDPOLYGON)
            shrinked = np.array(pco.Execute(-distance))
            if len(shrinked) == 0 or shrinked.size == 0:
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            try:
                shrinked = np.array(shrinked[0]).reshape(-1, 2)
            except:
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
        return text_kernel, ignore_tags
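A minimal sketch of the kernel-shrinking step in generate_kernel, isolated on a hypothetical 100x100 square with shrink ratio r = 0.5: the inward offset distance is area * (1 - r^2) / perimeter, and pyclipper performs the polygon offset.

import numpy as np
import pyclipper
from shapely.geometry import Polygon

poly = [(0, 0), (100, 0), (100, 100), (0, 100)]
rate = 0.5
polygon = Polygon(poly)
# 10000 * (1 - 0.25) / 400 = 18.75 pixels inward
distance = polygon.area * (1 - rate * rate) / (polygon.length + 1e-6)

pco = pyclipper.PyclipperOffset()
pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunk = pco.Execute(-distance)  # negative offset shrinks the polygon
print(np.array(shrunk[0]))       # roughly the square from (19, 19) to (81, 81)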
@@ -23,6 +23,7 @@ import sys
import six
import cv2
import numpy as np
import fasttext
class DecodeImage(object):
@@ -57,6 +58,39 @@ class DecodeImage(object):
        return data
class NRTRDecodeImage(object):
    """ decode image """

    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
        self.img_mode = img_mode
        self.channel_first = channel_first

    def __call__(self, data):
        img = data['image']
        if six.PY2:
            assert type(img) is str and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert type(img) is bytes and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')

        img = cv2.imdecode(img, 1)

        if img is None:
            return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if self.channel_first:
            img = img.transpose((2, 0, 1))
        data['image'] = img
        return data
class NormalizeImage(object):
    """ normalize image such as substract mean, divide std
    """
@@ -81,7 +115,7 @@ class NormalizeImage(object):
        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
@@ -101,6 +135,17 @@ class ToCHWImage(object):
        return data
class Fasttext(object):
    def __init__(self, path="None", **kwargs):
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        label = data['label']
        fast_label = self.fast_model[label]
        data['fast_label'] = fast_label
        return data
class KeepKeys(object):
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys
@@ -112,6 +157,34 @@ class KeepKeys(object):
        return data_list
class Resize(object):
    def __init__(self, size=(640, 640), **kwargs):
        self.size = size

    def resize_image(self, img):
        resize_h, resize_w = self.size
        ori_h, ori_w = img.shape[:2]  # (h, w, c)
        ratio_h = float(resize_h) / ori_h
        ratio_w = float(resize_w) / ori_w
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img, [ratio_h, ratio_w]

    def __call__(self, data):
        img = data['image']
        text_polys = data['polys']

        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
        new_boxes = []
        for box in text_polys:
            new_box = []
            for cord in box:
                new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
            new_boxes.append(new_box)
        data['image'] = img_resize
        data['polys'] = np.array(new_boxes, dtype=np.float32)
        return data
class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
@@ -183,7 +256,7 @@ class DetResizeForTest(object):
            else:
                ratio = 1.
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('not support limit type, image ')
        resize_h = int(h * ratio)
...
@@ -164,47 +164,55 @@ class EastRandomCropData(object):
        return data


class RandomCropImgMask(object):
    def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
        self.size = size
        self.main_key = main_key
        self.crop_keys = crop_keys
        self.p = p

    def __call__(self, data):
        image = data['image']

        h, w = image.shape[0:2]
        th, tw = self.size
        if w == tw and h == th:
            return data

        mask = data[self.main_key]
        if np.max(mask) > 0 and random.random() > self.p:
            # make sure to crop the text region
            tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
            tl[tl < 0] = 0
            br = np.max(np.where(mask > 0), axis=1) - (th, tw)
            br[br < 0] = 0

            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)

            i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
            j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
        else:
            i = random.randint(0, h - th) if h - th > 0 else 0
            j = random.randint(0, w - tw) if w - tw > 0 else 0

        # return i, j, th, tw
        for k in data:
            if k in self.crop_keys:
                if len(data[k].shape) == 3:
                    if np.argmin(data[k].shape) == 0:
                        img = data[k][:, i:i + th, j:j + tw]
                        if img.shape[1] != img.shape[2]:
                            a = 1
                    elif np.argmin(data[k].shape) == 2:
                        img = data[k][i:i + th, j:j + tw, :]
                        if img.shape[1] != img.shape[0]:
                            a = 1
                    else:
                        img = data[k]
                else:
                    img = data[k][i:i + th, j:j + tw]
                    if img.shape[0] != img.shape[1]:
                        a = 1
                data[k] = img
        return data
@@ -16,7 +16,7 @@ import math
import cv2
import numpy as np
import random
from PIL import Image

from .text_image_aug import tia_perspective, tia_stretch, tia_distort
@@ -43,22 +43,64 @@ class ClsResizeImg(object):
        return data
class NRTRRecResizeImg(object):
    def __init__(self, image_shape, resize_type, padding=False, **kwargs):
        self.image_shape = image_shape
        self.resize_type = resize_type
        self.padding = padding

    def __call__(self, data):
        img = data['image']
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        image_shape = self.image_shape
        if self.padding:
            imgC, imgH, imgW = image_shape
            # todo: change to 0 and modified image shape
            h = img.shape[0]
            w = img.shape[1]
            ratio = w / float(h)
            if math.ceil(imgH * ratio) > imgW:
                resized_w = imgW
            else:
                resized_w = int(math.ceil(imgH * ratio))
            resized_image = cv2.resize(img, (resized_w, imgH))
            norm_img = np.expand_dims(resized_image, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            resized_image = norm_img.astype(np.float32) / 128. - 1.
            padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
            padding_im[:, :, 0:resized_w] = resized_image
            data['image'] = padding_im
            return data
        if self.resize_type == 'PIL':
            image_pil = Image.fromarray(np.uint8(img))
            img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
            img = np.array(img)
        if self.resize_type == 'OpenCV':
            img = cv2.resize(img, self.image_shape)
        norm_img = np.expand_dims(img, -1)
        norm_img = norm_img.transpose((2, 0, 1))
        data['image'] = norm_img.astype(np.float32) / 128. - 1.
        return data
class RecResizeImg(object):
    def __init__(self,
                 image_shape,
                 infer_mode=False,
                 character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
                 padding=True,
                 **kwargs):
        self.image_shape = image_shape
        self.infer_mode = infer_mode
        self.character_dict_path = character_dict_path
        self.padding = padding

    def __call__(self, data):
        img = data['image']
        if self.infer_mode and self.character_dict_path is not None:
            norm_img = resize_norm_img_chinese(img, self.image_shape)
        else:
            norm_img = resize_norm_img(img, self.image_shape, self.padding)
        data['image'] = norm_img
        return data
@@ -83,16 +125,72 @@ class SRNRecResizeImg(object):
        return data
class SARRecResizeImg(object):
    def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
        self.image_shape = image_shape
        self.width_downsample_ratio = width_downsample_ratio

    def __call__(self, data):
        img = data['image']
        norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
            img, self.image_shape, self.width_downsample_ratio)
        data['image'] = norm_img
        data['resized_shape'] = resize_shape
        data['pad_shape'] = pad_shape
        data['valid_ratio'] = valid_ratio
        return data


def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
    imgC, imgH, imgW_min, imgW_max = image_shape
    h = img.shape[0]
    w = img.shape[1]
    valid_ratio = 1.0
    # make sure new_width is an integral multiple of width_divisor.
    width_divisor = int(1 / width_downsample_ratio)
    # resize
    ratio = w / float(h)
    resize_w = math.ceil(imgH * ratio)
    if resize_w % width_divisor != 0:
        resize_w = round(resize_w / width_divisor) * width_divisor
    if imgW_min is not None:
        resize_w = max(imgW_min, resize_w)
    if imgW_max is not None:
        valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
        resize_w = min(imgW_max, resize_w)
    resized_image = cv2.resize(img, (resize_w, imgH))
    resized_image = resized_image.astype('float32')
    # norm
    if image_shape[0] == 1:
        resized_image = resized_image / 255
        resized_image = resized_image[np.newaxis, :]
    else:
        resized_image = resized_image.transpose((2, 0, 1)) / 255
    resized_image -= 0.5
    resized_image /= 0.5
    resize_shape = resized_image.shape
    padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
    padding_im[:, :, 0:resize_w] = resized_image
    pad_shape = padding_im.shape

    return padding_im, resize_shape, pad_shape, valid_ratio
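Worked numbers for the SAR resize above (the input sizes are hypothetical): with image_shape = [3, 48, 48, 160] and width_downsample_ratio = 0.25, a 100x400 crop wants width 192, which already sits on the divisor grid; valid_ratio is then capped at 1.0 and the width clamps to imgW_max.

import math

imgH, imgW_min, imgW_max = 48, 48, 160
h, w = 100, 400
width_divisor = int(1 / 0.25)                 # 4

resize_w = math.ceil(imgH * (w / float(h)))   # 192
if resize_w % width_divisor != 0:
    resize_w = round(resize_w / width_divisor) * width_divisor
resize_w = max(imgW_min, resize_w)            # still 192
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)  # min(1.0, 1.2) = 1.0
resize_w = min(imgW_max, resize_w)            # clamped to 160
print(resize_w, valid_ratio)                  # 160 1.0: full padded width valid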
def resize_norm_img(img, image_shape, padding=True):
    imgC, imgH, imgW = image_shape
    h = img.shape[0]
    w = img.shape[1]
    if not padding:
        resized_image = cv2.resize(
            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
        resized_w = imgW
    else:
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
    resized_image = resized_image.astype('float32')
    if image_shape[0] == 1:
        resized_image = resized_image / 255
...
@@ -15,7 +15,6 @@ import numpy as np
import os
import random
from paddle.io import Dataset

from .imaug import transform, create_operators
...
@@ -20,11 +20,15 @@ import paddle.nn as nn
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss

# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
from .rec_nrtr_loss import NRTRLoss
from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss

# cls loss
from .cls_loss import ClsLoss
@@ -41,10 +45,12 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss


def build_loss(config):
    support_dict = [
        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
        'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
        'TableAttentionLoss', 'SARLoss', 'AsterLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')
...
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
class ACELoss(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(
            weight=None,
            ignore_index=0,
            reduction='none',
            soft_label=True,
            axis=-1)

    def __call__(self, predicts, batch):
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]

        B, N = predicts.shape[:2]
        div = paddle.to_tensor([N]).astype('float32')

        predicts = nn.functional.softmax(predicts, axis=-1)
        aggregation_preds = paddle.sum(predicts, axis=1)
        aggregation_preds = paddle.divide(aggregation_preds, div)

        length = batch[2].astype("float32")
        batch = batch[3].astype("float32")
        batch[:, 0] = paddle.subtract(div, length)
        batch = paddle.divide(batch, div)

        loss = self.loss_func(aggregation_preds, batch)
        return {"loss_ace": loss}
@@ -56,31 +56,34 @@ class CELoss(nn.Layer):

class KLJSLoss(object):
    def __init__(self, mode='kl'):
        assert mode in ['kl', 'js', 'KL', 'JS'
                        ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
        self.mode = mode

    def __call__(self, p1, p2, reduction="mean"):
        loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
        if self.mode.lower() == "js":
            loss += paddle.multiply(
                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
            loss *= 0.5
        if reduction == "mean":
            loss = paddle.mean(loss, axis=[1, 2])
        elif reduction == "none" or reduction is None:
            return loss
        else:
            loss = paddle.sum(loss, axis=[1, 2])

        return loss
class DMLLoss(nn.Layer):
    """
    DMLLoss
    """

    def __init__(self, act=None, use_log=False):
        super().__init__()
        if act is not None:
            assert act in ["softmax", "sigmoid"]
@@ -90,20 +93,24 @@ class DMLLoss(nn.Layer):
            self.act = nn.Sigmoid()
        else:
            self.act = None

        self.use_log = use_log
        self.jskl_loss = KLJSLoss(mode="js")

    def forward(self, out1, out2):
        if self.act is not None:
            out1 = self.act(out1)
            out2 = self.act(out2)
        if self.use_log:
            # for recognition distillation, log is needed for feature map
            log_out1 = paddle.log(out1)
            log_out2 = paddle.log(out2)
            loss = (F.kl_div(
                log_out1, out2, reduction='batchmean') + F.kl_div(
                    log_out2, out1, reduction='batchmean')) / 2.0
        else:
            # for detection distillation log is not needed
            loss = self.jskl_loss(out1, out2)
        return loss
...
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CenterLoss(nn.Layer):
    """
    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    """

    def __init__(self,
                 num_classes=6625,
                 feat_dim=96,
                 init_center=False,
                 center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim]).astype("float64")

        if init_center:
            assert os.path.exists(
                center_file_path
            ), f"center path({center_file_path}) must exist when init_center is set as True."
            with open(center_file_path, 'rb') as f:
                char_dict = pickle.load(f)
                for key in char_dict.keys():
                    self.centers[key] = paddle.to_tensor(char_dict[key])

    def __call__(self, predicts, batch):
        assert isinstance(predicts, (list, tuple))
        features, predicts = predicts

        feats_reshape = paddle.reshape(
            features, [-1, features.shape[-1]]).astype("float64")
        label = paddle.argmax(predicts, axis=2)
        label = paddle.reshape(label, [label.shape[0] * label.shape[1]])

        batch_size = feats_reshape.shape[0]

        # calc l2 distance between feats and centers
        square_feat = paddle.sum(paddle.square(feats_reshape),
                                 axis=1,
                                 keepdim=True)
        square_feat = paddle.expand(square_feat,
                                    [batch_size, self.num_classes])

        square_center = paddle.sum(paddle.square(self.centers),
                                   axis=1,
                                   keepdim=True)
        square_center = paddle.expand(
            square_center, [self.num_classes, batch_size]).astype("float64")
        square_center = paddle.transpose(square_center, [1, 0])

        distmat = paddle.add(square_feat, square_center)
        feat_dot_center = paddle.matmul(feats_reshape,
                                        paddle.transpose(self.centers, [1, 0]))
        distmat = distmat - 2.0 * feat_dot_center

        # generate the mask
        classes = paddle.arange(self.num_classes).astype("int64")
        label = paddle.expand(
            paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
        mask = paddle.equal(
            paddle.expand(classes, [batch_size, self.num_classes]),
            label).astype("float64")
        dist = paddle.multiply(distmat, mask)

        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
        return {'loss_center': loss}
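A small numpy check (toy shapes, random values) of the identity the distmat computation above relies on: ||f - c||^2 = ||f||^2 + ||c||^2 - 2 f.c, expanded batch-against-centers.

import numpy as np

feats = np.random.randn(8, 96)      # batch_size x feat_dim
centers = np.random.randn(10, 96)   # num_classes x feat_dim

square_feat = (feats ** 2).sum(1, keepdims=True)        # (8, 1)
square_center = (centers ** 2).sum(1, keepdims=True).T  # (1, 10)
distmat = square_feat + square_center - 2.0 * feats @ centers.T

brute_force = ((feats[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
assert np.allclose(distmat, brute_force)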
@@ -15,6 +15,10 @@
import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .center_loss import CenterLoss
from .ace_loss import ACELoss

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
@@ -49,11 +53,15 @@ class CombinedLoss(nn.Layer):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}

            weight = self.loss_weight[idx]

            loss = {key: loss[key] * weight for key in loss}

            if "loss" in loss:
                loss_all += loss["loss"]
            else:
                loss_all += paddle.add_n(list(loss.values()))
            loss_dict.update(loss)
        loss_dict["loss"] = loss_all
        return loss_dict
@@ -75,12 +75,6 @@ class BalanceLoss(nn.Layer):
        mask (variable): masked maps.
        return: (variable) balanced loss
        """
        positive = gt * mask
        negative = (1 - gt) * mask
@@ -153,53 +147,4 @@ class BCELoss(nn.Layer):
    def forward(self, input, label, mask=None, weight=None, name=None):
        loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
        return loss
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.nn import functional as F
import numpy as np
from ppocr.utils.iou import iou
class PSELoss(nn.Layer):
    def __init__(self,
                 alpha,
                 ohem_ratio=3,
                 kernel_sample_mask='pred',
                 reduction='sum',
                 eps=1e-6,
                 **kwargs):
        """Implement PSE Loss.
        """
        super(PSELoss, self).__init__()
        assert reduction in ['sum', 'mean', 'none']
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.kernel_sample_mask = kernel_sample_mask
        self.reduction = reduction
        self.eps = eps

    def forward(self, outputs, labels):
        predicts = outputs['maps']
        predicts = F.interpolate(predicts, scale_factor=4)

        texts = predicts[:, 0, :, :]
        kernels = predicts[:, 1:, :, :]
        gt_texts, gt_kernels, training_masks = labels[1:]

        # text loss
        selected_masks = self.ohem_batch(texts, gt_texts, training_masks)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)
        iou_text = iou((texts > 0).astype('int64'),
                       gt_texts,
                       training_masks,
                       reduce=False)
        losses = dict(loss_text=loss_text, iou_text=iou_text)

        # kernel loss
        loss_kernels = []
        if self.kernel_sample_mask == 'gt':
            selected_masks = gt_texts * training_masks
        elif self.kernel_sample_mask == 'pred':
            selected_masks = (
                F.sigmoid(texts) > 0.5).astype('float32') * training_masks

        for i in range(kernels.shape[1]):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
                                           selected_masks)
            loss_kernels.append(loss_kernel_i)
        loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
        iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
                         gt_kernels[:, -1, :, :],
                         training_masks * gt_texts,
                         reduce=False)
        losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))

        loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
        losses['loss'] = loss
        if self.reduction == 'sum':
            losses = {x: paddle.sum(v) for x, v in losses.items()}
        elif self.reduction == 'mean':
            losses = {x: paddle.mean(v) for x, v in losses.items()}
        return losses

    def dice_loss(self, input, target, mask):
        input = F.sigmoid(input)

        input = input.reshape([input.shape[0], -1])
        target = target.reshape([target.shape[0], -1])
        mask = mask.reshape([mask.shape[0], -1])

        input = input * mask
        target = target * mask

        a = paddle.sum(input * target, 1)
        b = paddle.sum(input * input, 1) + self.eps
        c = paddle.sum(target * target, 1) + self.eps
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
        pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
            paddle.sum(
                paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
                .astype('float32')))

        if pos_num == 0:
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(
                [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                    'float32')
            return selected_mask

        neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
        neg_num = int(min(pos_num * ohem_ratio, neg_num))

        if neg_num == 0:
            selected_mask = training_mask
            # reshape here (paddle tensors have no torch-style .view)
            selected_mask = selected_mask.reshape(
                [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                    'float32')
            return selected_mask

        neg_score = paddle.masked_select(score, gt_text <= 0.5)
        neg_score_sorted = paddle.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        selected_mask = paddle.logical_and(
            paddle.logical_or((score >= threshold), (gt_text > 0.5)),
            (training_mask > 0.5))
        selected_mask = selected_mask.reshape(
            [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                'float32')
        return selected_mask

    def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
        selected_masks = []
        for i in range(scores.shape[0]):
            selected_masks.append(
                self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
                                 training_masks[i, :, :], ohem_ratio))

        selected_masks = paddle.concat(selected_masks, 0).astype('float32')
        return selected_masks
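A numpy sketch of the masked dice loss PSELoss uses (toy maps; the sigmoid is skipped by feeding probabilities directly): dice = 2|X.Y| / (|X|^2 + |Y|^2) over masked pixels, and the loss is 1 - dice.

import numpy as np

eps = 1e-6
pred = np.array([[0.9, 0.8, 0.1, 0.2]])  # probabilities in (0, 1)
gt = np.array([[1.0, 1.0, 0.0, 0.0]])
mask = np.array([[1.0, 1.0, 1.0, 0.0]])  # last pixel is ignored

pred, gt = pred * mask, gt * mask
a = (pred * gt).sum(1)                   # 1.7
b = (pred * pred).sum(1) + eps           # ~1.46
c = (gt * gt).sum(1) + eps               # ~2.0
print(1 - 2 * a / (b + c))               # ~0.017: high overlap, small loss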