Commit 22a9f2ad authored by WenmuZhou's avatar WenmuZhou
Browse files

update faq

parents b544a561 0c5c9f69
doc/joinus.PNG

107 KB | W: | H:

doc/joinus.PNG

105 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -33,7 +33,7 @@ import paddle.distributed as dist ...@@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp) ...@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp) signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))
assert mode in ['Train', 'Eval', 'Test' assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test." ], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger) dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card'] batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last'] drop_last = loader_config['drop_last']
shuffle = loader_config['shuffle']
num_workers = loader_config['num_workers'] num_workers = loader_config['num_workers']
if 'use_shared_memory' in loader_config.keys(): if 'use_shared_memory' in loader_config.keys():
use_shared_memory = loader_config['use_shared_memory'] use_shared_memory = loader_config['use_shared_memory']
...@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger): ...@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger):
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
else: else:
#Distribute data to single card #Distribute data to single card
batch_sampler = BatchSampler( batch_sampler = BatchSampler(
dataset=dataset, dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
data_loader = DataLoader( data_loader = DataLoader(
......
...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap ...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment from .randaugment import RandAugment
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *
......
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy as np import numpy as np
import string
class ClsLabelEncode(object): class ClsLabelEncode(object):
...@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object): ...@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object):
character_type='ch', character_type='ch',
use_space_char=False): use_space_char=False):
support_character_type = [ support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
self.max_text_len = max_text_length self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en": if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type in ["ch", "french", "german", "japan", "korean"]: elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = "" self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
...@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object): ...@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object):
if use_space_char: if use_space_char:
self.character_str += " " self.character_str += " "
dict_character = list(self.character_str) dict_character = list(self.character_str)
elif character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
import string
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
self.character_type = character_type self.character_type = character_type
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
...@@ -213,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode): ...@@ -213,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
char_num = len(self.character_str)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [char_num] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
...@@ -12,20 +12,6 @@ ...@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import cv2 import cv2
import numpy as np import numpy as np
...@@ -77,6 +63,26 @@ class RecResizeImg(object): ...@@ -77,6 +63,26 @@ class RecResizeImg(object):
return data return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
self.num_heads = num_heads
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_srn(img, self.image_shape)
data['image'] = norm_img
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
data['encoder_word_pos'] = encoder_word_pos
data['gsrm_word_pos'] = gsrm_word_pos
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
return data
def resize_norm_img(img, image_shape): def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
h = img.shape[0] h = img.shape[0]
...@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape): ...@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape): def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape # todo: change to 0 and modified image shape
max_wh_ratio = 0 max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1] h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio) max_wh_ratio = max(max_wh_ratio, ratio)
...@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape): ...@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im return padding_im
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
[num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
[num_heads, 1, 1]) * [-1e9]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def flag(): def flag():
""" """
flag flag
......
...@@ -20,9 +20,9 @@ import cv2 ...@@ -20,9 +20,9 @@ import cv2
from .imaug import transform, create_operators from .imaug import transform, create_operators
class LMDBDateSet(Dataset): class LMDBDataSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(LMDBDateSet, self).__init__() super(LMDBDataSet, self).__init__()
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
......
...@@ -20,7 +20,7 @@ from .imaug import transform, create_operators ...@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset): class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger): def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
...@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset): ...@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
...@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset): ...@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
random.seed(self.seed)
lines = random.sample(lines, lines = random.sample(lines,
round(len(lines) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
data_lines.extend(lines) data_lines.extend(lines)
...@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset): ...@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
......
...@@ -23,11 +23,14 @@ def build_loss(config): ...@@ -23,11 +23,14 @@ def build_loss(config):
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .rec_srn_loss import SRNLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SRNLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
def forward(self, predicts, batch):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = batch[1]
casted_label = paddle.cast(x=label, dtype='int64')
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
cost_word = self.loss_func(word_predict, label=casted_label)
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
cost_vsfd = self.loss_func(predict, label=casted_label)
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
...@@ -33,8 +33,6 @@ class RecMetric(object): ...@@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target: if pred == target:
correct_num += 1 correct_num += 1
all_num += 1 all_num += 1
# if all_num < 10 and kwargs.get('show_str', False):
# print('{} -> {}'.format(pred, target))
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
...@@ -50,7 +48,7 @@ class RecMetric(object): ...@@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = self.correct_num / self.all_num acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
......
...@@ -68,11 +68,14 @@ class BaseModel(nn.Layer): ...@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"]) self.head = build_head(config["Head"])
def forward(self, x): def forward(self, x, data=None):
if self.use_transform: if self.use_transform:
x = self.transform(x) x = self.transform(x)
x = self.backbone(x) x = self.backbone(x)
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
x = self.head(x) if data is None:
x = self.head(x)
else:
x = self.head(x, data)
return x return x
...@@ -24,7 +24,8 @@ def build_backbone(config, model_type): ...@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls': elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN'] from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else: else:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment