Unverified Commit 3f1cb773 authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge branch 'dygraph' into lite

parents af1ac7c2 f687e092
doc/joinus.PNG

107 KB | W: | H:

doc/joinus.PNG

109 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -146,7 +146,8 @@ def parse_args(mMain=True, add_help=True): ...@@ -146,7 +146,8 @@ def parse_args(mMain=True, add_help=True):
# DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
...@@ -193,7 +194,8 @@ def parse_args(mMain=True, add_help=True): ...@@ -193,7 +194,8 @@ def parse_args(mMain=True, add_help=True):
det_limit_type='max', det_limit_type='max',
det_db_thresh=0.3, det_db_thresh=0.3,
det_db_box_thresh=0.5, det_db_box_thresh=0.5,
det_db_unclip_ratio=2.0, det_db_unclip_ratio=1.6,
use_dilation=False,
det_east_score_thresh=0.8, det_east_score_thresh=0.8,
det_east_cover_thresh=0.1, det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2, det_east_nms_thresh=0.2,
......
...@@ -33,7 +33,7 @@ import paddle.distributed as dist ...@@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp) ...@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp) signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
assert mode in ['Train', 'Eval', 'Test' assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test." ], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger) dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card'] batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last'] drop_last = loader_config['drop_last']
shuffle = loader_config['shuffle']
num_workers = loader_config['num_workers'] num_workers = loader_config['num_workers']
if 'use_shared_memory' in loader_config.keys(): if 'use_shared_memory' in loader_config.keys():
use_shared_memory = loader_config['use_shared_memory'] use_shared_memory = loader_config['use_shared_memory']
...@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger): ...@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger):
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
else: else:
#Distribute data to single card #Distribute data to single card
batch_sampler = BatchSampler( batch_sampler = BatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
data_loader = DataLoader( data_loader = DataLoader(
......
...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap ...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment from .randaugment import RandAugment
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
......
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy as np import numpy as np
import string
class ClsLabelEncode(object): class ClsLabelEncode(object):
...@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object): ...@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object):
character_type='ch', character_type='ch',
use_space_char=False): use_space_char=False):
support_character_type = [ support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
self.max_text_len = max_text_length self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en": if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type in ["ch", "french", "german", "japan", "korean"]: elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = "" self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
...@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object): ...@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object):
if use_space_char: if use_space_char:
self.character_str += " " self.character_str += " "
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
import string
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
self.character_type = character_type self.character_type = character_type
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
...@@ -193,16 +199,76 @@ class AttnLabelEncode(BaseRecLabelEncode): ...@@ -193,16 +199,76 @@ class AttnLabelEncode(BaseRecLabelEncode):
super(AttnLabelEncode, super(AttnLabelEncode,
self).__init__(max_text_length, character_dict_path, self).__init__(max_text_length, character_dict_path,
character_type, use_space_char) character_type, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text))
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
- len(text) - 2)
data['label'] = np.array(text)
return data
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = [self.beg_str, self.end_str] + dict_character dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character return dict_character
def __call__(self, text): def __call__(self, data):
text = data['label']
text = self.encode(text) text = self.encode(text)
return text char_num = len(self.character)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [char_num - 1] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end): def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg": if beg_or_end == "beg":
......
...@@ -32,7 +32,6 @@ class MakeShrinkMap(object): ...@@ -32,7 +32,6 @@ class MakeShrinkMap(object):
text_polys, ignore_tags = self.validate_polygons(text_polys, text_polys, ignore_tags = self.validate_polygons(text_polys,
ignore_tags, h, w) ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32) gt = np.zeros((h, w), dtype=np.float32)
# gt = np.zeros((1, h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32) mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)): for i in range(len(text_polys)):
polygon = text_polys[i] polygon = text_polys[i]
...@@ -44,21 +43,34 @@ class MakeShrinkMap(object): ...@@ -44,21 +43,34 @@ class MakeShrinkMap(object):
ignore_tags[i] = True ignore_tags[i] = True
else: else:
polygon_shape = Polygon(polygon) polygon_shape = Polygon(polygon)
distance = polygon_shape.area * ( subject = [tuple(l) for l in polygon]
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in text_polys[i]]
padding = pyclipper.PyclipperOffset() padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON) pyclipper.ET_CLOSEDPOLYGON)
shrinked = []
# Increase the shrink ratio every time we get multiple polygon returned back
possible_ratios = np.arange(self.shrink_ratio, 1,
self.shrink_ratio)
np.append(possible_ratios, 1)
# print(possible_ratios)
for ratio in possible_ratios:
# print(f"Change shrink ratio to {ratio}")
distance = polygon_shape.area * (
1 - np.power(ratio, 2)) / polygon_shape.length
shrinked = padding.Execute(-distance) shrinked = padding.Execute(-distance)
if len(shrinked) == 1:
break
if shrinked == []: if shrinked == []:
cv2.fillPoly(mask, cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0) polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True ignore_tags[i] = True
continue continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1) for each_shirnk in shrinked:
# cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) shirnk = np.array(each_shirnk).reshape(-1, 2)
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
data['shrink_map'] = gt data['shrink_map'] = gt
data['shrink_mask'] = mask data['shrink_mask'] = mask
...@@ -84,11 +96,12 @@ class MakeShrinkMap(object): ...@@ -84,11 +96,12 @@ class MakeShrinkMap(object):
return polygons, ignore_tags return polygons, ignore_tags
def polygon_area(self, polygon): def polygon_area(self, polygon):
# return cv2.contourArea(polygon.astype(np.float32)) """
edge = 0 compute polygon area
for i in range(polygon.shape[0]): """
next_index = (i + 1) % polygon.shape[0] area = 0
edge += (polygon[next_index, 0] - polygon[i, 0]) * ( q = polygon[-1]
polygon[next_index, 1] - polygon[i, 1]) for p in polygon:
area += p[0] * q[1] - p[1] * q[0]
return edge / 2. q = p
return area / 2.0
...@@ -185,8 +185,8 @@ class DetResizeForTest(object): ...@@ -185,8 +185,8 @@ class DetResizeForTest(object):
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32) resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = int(round(resize_w / 32) * 32) resize_w = max(int(round(resize_w / 32) * 32), 32)
try: try:
if int(resize_w) <= 0 or int(resize_h) <= 0: if int(resize_w) <= 0 or int(resize_h) <= 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment