Unverified Commit f6532a0e authored by andyjpaddle's avatar andyjpaddle Committed by GitHub
Browse files

add ppocrv3 rec (#6033)

* add ppocrv3 rec
parent 6902d160
...@@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder):
class SARHead(nn.Layer): class SARHead(nn.Layer):
def __init__(self, def __init__(self,
in_channels,
out_channels, out_channels,
enc_dim=512,
max_text_length=30,
enc_bi_rnn=False, enc_bi_rnn=False,
enc_drop_rnn=0.1, enc_drop_rnn=0.1,
enc_gru=False, enc_gru=False,
...@@ -358,14 +361,17 @@ class SARHead(nn.Layer): ...@@ -358,14 +361,17 @@ class SARHead(nn.Layer):
dec_gru=False, dec_gru=False,
d_k=512, d_k=512,
pred_dropout=0.1, pred_dropout=0.1,
max_text_length=30,
pred_concat=True, pred_concat=True,
**kwargs): **kwargs):
super(SARHead, self).__init__() super(SARHead, self).__init__()
# encoder module # encoder module
self.encoder = SAREncoder( self.encoder = SAREncoder(
enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru) enc_bi_rnn=enc_bi_rnn,
enc_drop_rnn=enc_drop_rnn,
enc_gru=enc_gru,
d_model=in_channels,
d_enc=enc_dim)
# decoder module # decoder module
self.decoder = ParallelSARDecoder( self.decoder = ParallelSARDecoder(
...@@ -374,6 +380,8 @@ class SARHead(nn.Layer): ...@@ -374,6 +380,8 @@ class SARHead(nn.Layer):
dec_bi_rnn=dec_bi_rnn, dec_bi_rnn=dec_bi_rnn,
dec_drop_rnn=dec_drop_rnn, dec_drop_rnn=dec_drop_rnn,
dec_gru=dec_gru, dec_gru=dec_gru,
d_model=in_channels,
d_enc=enc_dim,
d_k=d_k, d_k=d_k,
pred_dropout=pred_dropout, pred_dropout=pred_dropout,
max_text_length=max_text_length, max_text_length=max_text_length,
...@@ -390,7 +398,7 @@ class SARHead(nn.Layer): ...@@ -390,7 +398,7 @@ class SARHead(nn.Layer):
label = paddle.to_tensor(label, dtype='int64') label = paddle.to_tensor(label, dtype='int64')
final_out = self.decoder( final_out = self.decoder(
feat, holistic_feat, label, img_metas=targets) feat, holistic_feat, label, img_metas=targets)
if not self.training: else:
final_out = self.decoder( final_out = self.decoder(
feat, feat,
holistic_feat, holistic_feat,
......
...@@ -16,9 +16,11 @@ from __future__ import absolute_import ...@@ -16,9 +16,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle
from paddle import nn from paddle import nn
from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_
class Im2Seq(nn.Layer): class Im2Seq(nn.Layer):
...@@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer): ...@@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer):
return x return x
class EncoderWithSVTR(nn.Layer):
def __init__(
self,
in_channels,
dims=64, # XS
depth=2,
hidden_dims=120,
use_guide=False,
num_heads=8,
qkv_bias=True,
mlp_ratio=2.0,
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path=0.,
qk_scale=None):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(
in_channels, in_channels // 8, padding=1, act=nn.Swish)
self.conv2 = ConvBNLayer(
in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish)
self.svtr_block = nn.LayerList([
Block(
dim=hidden_dims,
num_heads=num_heads,
mixer='Global',
HW=None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=nn.Swish,
attn_drop=attn_drop_rate,
drop_path=drop_path,
norm_layer='nn.LayerNorm',
epsilon=1e-05,
prenorm=False) for i in range(depth)
])
self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
self.conv3 = ConvBNLayer(
hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self.conv4 = ConvBNLayer(
2 * in_channels, in_channels // 8, padding=1, act=nn.Swish)
self.conv1x1 = ConvBNLayer(
in_channels // 8, dims, kernel_size=1, act=nn.Swish)
self.out_channels = dims
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward(self, x):
# for use guide
if self.use_guide:
z = x.clone()
z.stop_gradient = True
else:
z = x
# for short cut
h = z
# reduce dim
z = self.conv1(z)
z = self.conv2(z)
# SVTR global block
B, C, H, W = z.shape
z = z.flatten(2).transpose([0, 2, 1])
for blk in self.svtr_block:
z = blk(z)
z = self.norm(z)
# last stage
z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
z = self.conv3(z)
z = paddle.concat((h, z), axis=1)
z = self.conv1x1(self.conv4(z))
return z
class SequenceEncoder(nn.Layer): class SequenceEncoder(nn.Layer):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs): def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__() super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels) self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
if encoder_type == 'reshape': if encoder_type == 'reshape':
self.only_reshape = True self.only_reshape = True
else: else:
support_encoder_dict = { support_encoder_dict = {
'reshape': Im2Seq, 'reshape': Im2Seq,
'fc': EncoderWithFC, 'fc': EncoderWithFC,
'rnn': EncoderWithRNN 'rnn': EncoderWithRNN,
'svtr': EncoderWithSVTR
} }
assert encoder_type in support_encoder_dict, '{} must in {}'.format( assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys()) encoder_type, support_encoder_dict.keys())
if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type]( self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size) self.encoder_reshape.out_channels, **kwargs)
else:
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size)
self.out_channels = self.encoder.out_channels self.out_channels = self.encoder.out_channels
self.only_reshape = False self.only_reshape = False
def forward(self, x): def forward(self, x):
x = self.encoder_reshape(x) if self.encoder_type != 'svtr':
if not self.only_reshape: x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
return x
else:
x = self.encoder(x) x = self.encoder(x)
return x x = self.encoder_reshape(x)
return x
...@@ -41,7 +41,8 @@ def build_post_process(config, global_config=None): ...@@ -41,7 +41,8 @@ def build_post_process(config, global_config=None):
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode' 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode): ...@@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
use_space_char=False, use_space_char=False,
model_name=["student"], model_name=["student"],
key=None, key=None,
multi_head=False,
**kwargs): **kwargs):
super(DistillationCTCLabelDecode, self).__init__(character_dict_path, super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
use_space_char) use_space_char)
...@@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode): ...@@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
self.multi_head = multi_head
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
output = dict() output = dict()
...@@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode): ...@@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
pred = preds[name] pred = preds[name]
if self.key is not None: if self.key is not None:
pred = pred[self.key] pred = pred[self.key]
if self.multi_head and isinstance(pred, dict):
pred = pred['ctc']
output[name] = super().__call__(pred, label=label, *args, **kwargs) output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output return output
...@@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode): ...@@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode):
return [self.padding_idx] return [self.padding_idx]
class DistillationSARLabelDecode(SARLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
use_space_char=False,
model_name=["student"],
key=None,
multi_head=False,
**kwargs):
super(DistillationSARLabelDecode, self).__init__(character_dict_path,
use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
self.multi_head = multi_head
def __call__(self, preds, label=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
if self.multi_head and isinstance(pred, dict):
pred = pred['sar']
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class PRENLabelDecode(BaseRecLabelDecode): class PRENLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -47,14 +47,38 @@ def main(): ...@@ -47,14 +47,38 @@ def main():
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
out_channels_list = {}
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
extra_input = config['Architecture'][ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][
'algorithm'] in extra_input_models
else:
extra_input = config['Architecture']['algorithm'] in extra_input_models
if "model_type" in config['Architecture'].keys(): if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
else: else:
......
...@@ -55,6 +55,13 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -55,6 +55,13 @@ def export_single_model(model, arch_config, save_path, logger):
shape=[None, 3, 48, 160], dtype="float32"), shape=[None, 3, 48, 160], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR":
if arch_config["Head"]["name"] == 'MultiHead':
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, -1], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN": elif arch_config["algorithm"] == "PREN":
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
...@@ -105,13 +112,36 @@ def main(): ...@@ -105,13 +112,36 @@ def main():
if config["Architecture"]["algorithm"] in ["Distillation", if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config["Architecture"]["Models"]: for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][ if config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num "name"] == 'MultiHead': # multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
loss_list = config['Loss']['loss_config_list']
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
# just one final tensor needs to to exported for inference # just one final tensor needs to to exported for inference
config["Architecture"]["Models"][key][ config["Architecture"]["Models"][key][
"return_all_feats"] = False "return_all_feats"] = False
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # multi head
out_channels_list = {}
char_num = len(getattr(post_process_class, 'character'))
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
load_model(config, model) load_model(config, model)
model.eval() model.eval()
......
...@@ -107,7 +107,7 @@ class TextRecognizer(object): ...@@ -107,7 +107,7 @@ class TextRecognizer(object):
return norm_img.astype(np.float32) / 128. - 1. return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2] assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio)) imgW = int((imgH * max_wh_ratio))
if self.use_onnx: if self.use_onnx:
w = self.input_tensor.shape[3:][0] w = self.input_tensor.shape[3:][0]
if w is not None and w > 0: if w is not None and w > 0:
...@@ -255,7 +255,9 @@ class TextRecognizer(object): ...@@ -255,7 +255,9 @@ class TextRecognizer(object):
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
max_wh_ratio = 0 imgC, imgH, imgW = self.rec_image_shape
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2] h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
......
...@@ -51,8 +51,28 @@ def main(): ...@@ -51,8 +51,28 @@ def main():
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head loss
out_channels_list = {}
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
......
...@@ -201,12 +201,17 @@ def train(config, ...@@ -201,12 +201,17 @@ def train(config,
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input = config['Architecture'][ extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"] if config['Architecture']['algorithm'] == 'Distillation':
extra_input = config['Architecture']['Models']['Teacher'][
'algorithm'] in extra_input_models
else:
extra_input = config['Architecture']['algorithm'] in extra_input_models
try: try:
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
except: except:
model_type = None model_type = None
algorithm = config['Architecture']['algorithm'] algorithm = config['Architecture']['algorithm']
start_epoch = best_model_dict[ start_epoch = best_model_dict[
...@@ -269,7 +274,12 @@ def train(config, ...@@ -269,7 +274,12 @@ def train(config,
if model_type in ['table', 'kie']: if model_type in ['table', 'kie']:
eval_class(preds, batch) eval_class(preds, batch)
else: else:
post_result = post_process_class(preds, batch[1]) if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss
post_result = post_process_class(
preds['ctc'], batch[1]) # for CTC head out
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch) eval_class(post_result, batch)
metric = eval_class.get_metric() metric = eval_class.get_metric()
train_stats.update(metric) train_stats.update(metric)
...@@ -541,7 +551,7 @@ def preprocess(is_train=False): ...@@ -541,7 +551,7 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE' 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
] ]
device = 'cpu' device = 'cpu'
......
...@@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer): ...@@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer):
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment