"docs/vscode:/vscode.git/clone" did not exist on "988369a01c4bb910a99cde46baa9e2b5b0b69aab"
Unverified Commit 631fd9fd authored by xiaoting's avatar xiaoting Committed by GitHub
Browse files

Merge branch 'dygraph' into dygraph_doc

parents 8520dd1e 90b968d5
This diff is collapsed.
...@@ -52,6 +52,7 @@ class DetLabelEncode(object): ...@@ -52,6 +52,7 @@ class DetLabelEncode(object):
txt_tags.append(True) txt_tags.append(True)
else: else:
txt_tags.append(False) txt_tags.append(False)
boxes = self.expand_points_num(boxes)
boxes = np.array(boxes, dtype=np.float32) boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool) txt_tags = np.array(txt_tags, dtype=np.bool)
...@@ -70,6 +71,17 @@ class DetLabelEncode(object): ...@@ -70,6 +71,17 @@ class DetLabelEncode(object):
rect[3] = pts[np.argmax(diff)] rect[3] = pts[np.argmax(diff)]
return rect return rect
def expand_points_num(self, boxes):
max_points_num = 0
for box in boxes:
if len(box) > max_points_num:
max_points_num = len(box)
ex_boxes = []
for box in boxes:
ex_box = box + [box[-1]] * (max_points_num - len(box))
ex_boxes.append(ex_box)
return ex_boxes
class BaseRecLabelEncode(object): class BaseRecLabelEncode(object):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
...@@ -83,7 +95,7 @@ class BaseRecLabelEncode(object): ...@@ -83,7 +95,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' 'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, self.character_str) support_character_type, character_type)
self.max_text_len = max_text_length self.max_text_len = max_text_length
if character_type == "en": if character_type == "en":
......
This diff is collapsed.
This diff is collapsed.
...@@ -27,14 +27,13 @@ class SimpleDataSet(Dataset): ...@@ -27,14 +27,13 @@ class SimpleDataSet(Dataset):
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
self.delimiter = dataset_config.get('delimiter', '\t') self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list) data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0]) ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)): if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * len(data_source_num) ratio_list = [float(ratio_list)] * int(data_source_num)
assert len( assert len(
ratio_list ratio_list
...@@ -76,6 +75,8 @@ class SimpleDataSet(Dataset): ...@@ -76,6 +75,8 @@ class SimpleDataSet(Dataset):
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label} data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
......
...@@ -18,6 +18,8 @@ import copy ...@@ -18,6 +18,8 @@ import copy
def build_loss(config): def build_loss(config):
# det loss # det loss
from .det_db_loss import DBLoss from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
...@@ -25,7 +27,7 @@ def build_loss(config): ...@@ -25,7 +27,7 @@ def build_loss(config):
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss'] support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
This diff is collapsed.
This diff is collapsed.
...@@ -16,7 +16,7 @@ from __future__ import division ...@@ -16,7 +16,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle import nn from paddle import nn
from ppocr.modeling.transform import build_transform from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
......
...@@ -19,6 +19,7 @@ def build_backbone(config, model_type): ...@@ -19,6 +19,7 @@ def build_backbone(config, model_type):
if model_type == 'det': if model_type == 'det':
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet from .det_resnet_vd import ResNet
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST'] support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
elif model_type == 'rec' or model_type == 'cls': elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment