Unverified Commit 6ebbbfe4 authored by xiaoting's avatar xiaoting Committed by GitHub
Browse files

Merge branch 'dygraph' into dygraph_for_srn

parents f2144375 fad40158
......@@ -65,6 +65,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last']
shuffle = loader_config['shuffle']
num_workers = loader_config['num_workers']
if 'use_shared_memory' in loader_config.keys():
use_shared_memory = loader_config['use_shared_memory']
......@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
shuffle=shuffle,
drop_last=drop_last)
else:
#Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
shuffle=shuffle,
drop_last=drop_last)
data_loader = DataLoader(
......
......@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import string
class ClsLabelEncode(object):
......@@ -92,7 +93,10 @@ class BaseRecLabelEncode(object):
character_type='ch',
use_space_char=False):
support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
......@@ -103,9 +107,14 @@ class BaseRecLabelEncode(object):
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type in ["ch", "french", "german", "japan", "korean"]:
elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
......@@ -114,11 +123,6 @@ class BaseRecLabelEncode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
elif character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
import string
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
......
......@@ -58,15 +58,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', 2],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
[3, 240, 80, False, 'hardswish', 2],
[3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hardswish', 2],
[5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
......@@ -75,14 +75,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', 2],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
[5, 96, 40, True, 'hardswish', 2],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', 2],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
......@@ -102,7 +102,7 @@ class MobileNetV3(nn.Layer):
padding=1,
groups=1,
if_act=True,
act='hard_swish',
act='hardswish',
name='conv1')
self.stages = []
......@@ -112,7 +112,8 @@ class MobileNetV3(nn.Layer):
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
if s == 2 and i > 2:
start_idx = 2 if model_name == 'large' else 0
if s == 2 and i > start_idx:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
......@@ -137,7 +138,7 @@ class MobileNetV3(nn.Layer):
padding=0,
groups=1,
if_act=True,
act='hard_swish',
act='hardswish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
......@@ -191,10 +192,11 @@ class ConvBNLayer(nn.Layer):
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hard_swish":
x = F.activation.hard_swish(x)
elif self.act == "hardswish":
x = F.hardswish(x)
else:
print("The activation function is selected incorrectly.")
print("The activation function({}) is selected incorrectly.".
format(self.act))
exit()
return x
......@@ -281,5 +283,5 @@ class SEModule(nn.Layer):
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = F.activation.hard_sigmoid(outputs)
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs
......@@ -51,15 +51,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 1],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
[3, 240, 80, False, 'hardswish', 1],
[3, 200, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 184, 80, False, 'hardswish', 1],
[3, 480, 112, True, 'hardswish', 1],
[3, 672, 112, True, 'hardswish', 1],
[5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
[5, 960, 160, True, 'hardswish', 1],
[5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
......@@ -68,14 +68,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
[5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
[5, 240, 40, True, 'hardswish', 1],
[5, 240, 40, True, 'hardswish', 1],
[5, 120, 48, True, 'hardswish', 1],
[5, 144, 48, True, 'hardswish', 1],
[5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
[5, 576, 96, True, 'hardswish', 1],
[5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
......@@ -96,7 +96,7 @@ class MobileNetV3(nn.Layer):
padding=1,
groups=1,
if_act=True,
act='hard_swish',
act='hardswish',
name='conv1')
i = 0
block_list = []
......@@ -124,7 +124,7 @@ class MobileNetV3(nn.Layer):
padding=0,
groups=1,
if_act=True,
act='hard_swish',
act='hardswish',
name='conv_last')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import string
import paddle
from paddle.nn import functional as F
......@@ -24,9 +25,10 @@ class BaseRecLabelDecode(object):
character_type='ch',
use_space_char=False):
support_character_type = [
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean', 'it',
'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg',
'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', 'ne'
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
......@@ -37,9 +39,14 @@ class BaseRecLabelDecode(object):
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif character_type in ["ch", "french", "german", "japan", "korean"]:
elif character_type == "EN_symbol":
# same with ASTER setting (use 94 char).
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
elif character_type in support_character_type:
self.character_str = ""
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
......@@ -48,11 +55,7 @@ class BaseRecLabelDecode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
elif character_type == "en_sensitive":
# same with ASTER setting (use 94 char).
import string
self.character_str = string.printable[:-6]
dict_character = list(self.character_str)
else:
raise NotImplementedError
self.character_type = character_type
......
......@@ -75,10 +75,17 @@ def main():
]
]
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, 32, 100] if config['Architecture'][
'model_type'] != "det" else [3, 640, 640]
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
......
......@@ -70,7 +70,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument(
"--vis_font_path", type=str, default="./doc/simfang.ttf")
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
......
......@@ -218,7 +218,7 @@ def train(config,
stats['lr'] = lr
train_stats.update(stats)
if cal_metric_during_train: # onlt rec and cls need
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
......@@ -253,19 +253,19 @@ def train(config,
Model_Average.apply()
cur_metirc = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
logger.info(cur_metirc_str)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
# logger metric
if vdl_writer is not None:
for k, v in cur_metirc.items():
for k, v in cur_metric.items():
if isinstance(v, (float, int)):
vdl_writer.add_scalar('EVAL/{}'.format(k),
cur_metirc[k], global_step)
if cur_metirc[main_indicator] >= best_model_dict[
cur_metric[k], global_step)
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
best_model_dict.update(cur_metirc)
best_model_dict.update(cur_metric)
best_model_dict['best_epoch'] = epoch
save_model(
model,
......@@ -276,7 +276,7 @@ def train(config,
prefix='best_accuracy',
best_model_dict=best_model_dict,
epoch=epoch)
best_str = 'best metirc, {}'.format(', '.join([
best_str = 'best metric, {}'.format(', '.join([
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
]))
logger.info(best_str)
......@@ -308,7 +308,7 @@ def train(config,
prefix='iter_epoch_{}'.format(epoch),
best_model_dict=best_model_dict,
epoch=epoch)
best_str = 'best metirc, {}'.format(', '.join(
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
logger.info(best_str)
if dist.get_rank() == 0 and vdl_writer is not None:
......@@ -338,13 +338,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
# Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric()
# Get final metric,eg. acc or hmean
metric = eval_class.get_metric()
pbar.close()
model.train()
metirc['fps'] = total_frame / total_time
return metirc
metric['fps'] = total_frame / total_time
return metric
def preprocess(is_train=False):
......
# for paddle.__version__ >= 2.0rc1
# recommended paddle.__version__ == 2.0.0
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
# for paddle.__version__ < 2.0rc1
# python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment