Unverified Commit 631fd9fd authored by xiaoting's avatar xiaoting Committed by GitHub
Browse files

Merge branch 'dygraph' into dygraph_doc

parents 8520dd1e 90b968d5
This diff is collapsed.
......@@ -52,6 +52,7 @@ class DetLabelEncode(object):
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
......@@ -70,6 +71,17 @@ class DetLabelEncode(object):
rect[3] = pts[np.argmax(diff)]
return rect
def expand_points_num(self, boxes):
max_points_num = 0
for box in boxes:
if len(box) > max_points_num:
max_points_num = len(box)
ex_boxes = []
for box in boxes:
ex_box = box + [box[-1]] * (max_points_num - len(box))
ex_boxes.append(ex_box)
return ex_boxes
class BaseRecLabelEncode(object):
""" Convert between text-label and text-index """
......@@ -83,7 +95,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, self.character_str)
support_character_type, character_type)
self.max_text_len = max_text_length
if character_type == "en":
......
This diff is collapsed.
This diff is collapsed.
......@@ -27,14 +27,13 @@ class SimpleDataSet(Dataset):
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * len(data_source_num)
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
......@@ -76,6 +75,8 @@ class SimpleDataSet(Dataset):
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
......
......@@ -18,6 +18,8 @@ import copy
def build_loss(config):
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
# rec loss
from .rec_ctc_loss import CTCLoss
......@@ -25,7 +27,7 @@ def build_loss(config):
# cls loss
from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss']
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
config = copy.deepcopy(config)
module_name = config.pop('name')
......
This diff is collapsed.
This diff is collapsed.
......@@ -16,7 +16,7 @@ from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transform import build_transform
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
......
......@@ -19,6 +19,7 @@ def build_backbone(config, model_type):
if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment