Commit 07b6d635 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'upstream/dygraph' into dy1

parents ee9c1bcf 3ce97f18
...@@ -44,21 +44,34 @@ class MakeShrinkMap(object): ...@@ -44,21 +44,34 @@ class MakeShrinkMap(object):
ignore_tags[i] = True ignore_tags[i] = True
else: else:
polygon_shape = Polygon(polygon) polygon_shape = Polygon(polygon)
distance = polygon_shape.area * ( subject = [tuple(l) for l in polygon]
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in text_polys[i]]
padding = pyclipper.PyclipperOffset() padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON) pyclipper.ET_CLOSEDPOLYGON)
shrinked = padding.Execute(-distance) shrinked = []
# Increase the shrink ratio every time we get multiple polygon returned back
possible_ratios = np.arange(self.shrink_ratio, 1,
self.shrink_ratio)
np.append(possible_ratios, 1)
# print(possible_ratios)
for ratio in possible_ratios:
# print(f"Change shrink ratio to {ratio}")
distance = polygon_shape.area * (
1 - np.power(ratio, 2)) / polygon_shape.length
shrinked = padding.Execute(-distance)
if len(shrinked) == 1:
break
if shrinked == []: if shrinked == []:
cv2.fillPoly(mask, cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0) polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True ignore_tags[i] = True
continue continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1) for each_shirnk in shrinked:
# cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) shirnk = np.array(each_shirnk).reshape(-1, 2)
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
data['shrink_map'] = gt data['shrink_map'] = gt
data['shrink_mask'] = mask data['shrink_mask'] = mask
...@@ -84,11 +97,12 @@ class MakeShrinkMap(object): ...@@ -84,11 +97,12 @@ class MakeShrinkMap(object):
return polygons, ignore_tags return polygons, ignore_tags
def polygon_area(self, polygon): def polygon_area(self, polygon):
# return cv2.contourArea(polygon.astype(np.float32)) """
edge = 0 compute polygon area
for i in range(polygon.shape[0]): """
next_index = (i + 1) % polygon.shape[0] area = 0
edge += (polygon[next_index, 0] - polygon[i, 0]) * ( q = polygon[-1]
polygon[next_index, 1] - polygon[i, 1]) for p in polygon:
area += p[0] * q[1] - p[1] * q[0]
return edge / 2. q = p
return area / 2.0
...@@ -185,8 +185,8 @@ class DetResizeForTest(object): ...@@ -185,8 +185,8 @@ class DetResizeForTest(object):
resize_h = int(h * ratio) resize_h = int(h * ratio)
resize_w = int(w * ratio) resize_w = int(w * ratio)
resize_h = int(round(resize_h / 32) * 32) resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = int(round(resize_w / 32) * 32) resize_w = max(int(round(resize_w / 32) * 32), 32)
try: try:
if int(resize_w) <= 0 or int(resize_h) <= 0: if int(resize_w) <= 0 or int(resize_h) <= 0:
......
...@@ -23,13 +23,15 @@ def build_loss(config): ...@@ -23,13 +23,15 @@ def build_loss(config):
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss from .rec_srn_loss import SRNLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss' 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio): ...@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
i, :, :], ohem_ratio)) i, :, :], ohem_ratio))
selected_masks = np.concatenate(selected_masks, 0) selected_masks = np.concatenate(selected_masks, 0)
selected_masks = paddle.to_variable(selected_masks) selected_masks = paddle.to_tensor(selected_masks)
return selected_masks return selected_masks
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class AttentionLoss(nn.Layer):
    """Cross-entropy loss over per-step attention-decoder predictions.

    Uses an unreduced CrossEntropyLoss and sums over all (batch * step)
    positions, so longer label sequences contribute proportionally more.
    """

    def __init__(self, **kwargs):
        super(AttentionLoss, self).__init__()
        # reduction='none' keeps per-position losses; they are summed below.
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')

    def forward(self, predicts, batch):
        """Compute the summed cross-entropy loss.

        Args:
            predicts: decoder logits of shape [N, num_steps, num_classes].
            batch: data list where batch[1] holds the target character
                indices of shape [N, num_steps].

        Returns:
            dict with key 'loss' holding a scalar tensor.
        """
        targets = batch[1].astype("int64")
        # predicts is [N, num_steps, num_classes]; targets must be [N, num_steps].
        assert len(targets.shape) == len(predicts.shape) - 1, \
            "The target's shape and inputs's shape is [N, num_steps] and [N, num_steps, num_classes]"
        # Flatten to [N * num_steps, num_classes] vs [N * num_steps].
        inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
        targets = paddle.reshape(targets, [-1])
        return {'loss': paddle.sum(self.loss_func(inputs, targets))}
...@@ -29,7 +29,7 @@ class RecMetric(object): ...@@ -29,7 +29,7 @@ class RecMetric(object):
pred = pred.replace(" ", "") pred = pred.replace(" ", "")
target = target.replace(" ", "") target = target.replace(" ", "")
norm_edit_dis += Levenshtein.distance(pred, target) / max( norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target)) len(pred), len(target), 1)
if pred == target: if pred == target:
correct_num += 1 correct_num += 1
all_num += 1 all_num += 1
......
...@@ -23,12 +23,14 @@ def build_head(config): ...@@ -23,12 +23,14 @@ def build_head(config):
# rec head # rec head
from .rec_ctc_head import CTCHead from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead from .rec_srn_head import SRNHead
# cls head # cls head
from .cls_head import ClsHead from .cls_head import ClsHead
support_dict = [ support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead' 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead'
] ]
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
class AttentionHead(nn.Layer):
    """Attention-based sequence decoder head built on a GRU attention cell.

    When `targets` are supplied (training) the decoder is teacher-forced on
    the ground-truth characters; otherwise (inference) it decodes greedily,
    feeding its own argmax prediction back in at every step.
    """

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(AttentionHead, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels
        self.attention_cell = AttentionGRUCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        # One-hot encode a batch of character indices.
        return F.one_hot(input_char, onehot_dim)

    def forward(self, inputs, targets=None, batch_max_length=25):
        batch_size = inputs.shape[0]
        num_steps = batch_max_length
        hidden = paddle.zeros((batch_size, self.hidden_size))

        if targets is not None:
            # Teacher forcing: condition each step on the ground-truth char.
            step_outputs = []
            for step in range(num_steps):
                onehots = self._char_to_onehot(
                    targets[:, step], onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               onehots)
                step_outputs.append(paddle.unsqueeze(outputs, axis=1))
            probs = self.generator(paddle.concat(step_outputs, axis=1))
        else:
            # Greedy decoding: start from index 0 and feed back the argmax.
            decoder_input = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None
            for step in range(num_steps):
                onehots = self._char_to_onehot(
                    decoder_input, onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               onehots)
                probs_step = self.generator(outputs)
                step_probs = paddle.unsqueeze(probs_step, axis=1)
                if probs is None:
                    probs = step_probs
                else:
                    probs = paddle.concat([probs, step_probs], axis=1)
                decoder_input = probs_step.argmax(axis=1)
        return probs
class AttentionGRUCell(nn.Layer):
    """One additive-attention step over encoder features plus a GRU update.

    NOTE(review): the `use_gru` flag is accepted but ignored here — the
    recurrent unit is always a GRUCell.
    """

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionGRUCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
        self.rnn = nn.GRUCell(
            input_size=input_size + num_embeddings, hidden_size=hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # Additive (Bahdanau-style) scores over the time axis.
        summed = paddle.add(
            self.i2h(batch_H),
            paddle.unsqueeze(self.h2h(prev_hidden), axis=1))
        e = self.score(paddle.tanh(summed))
        alpha = paddle.transpose(F.softmax(e, axis=1), [0, 2, 1])
        # Attention-weighted sum of the encoder features.
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        # GRU update conditioned on [context; one-hot char].
        cur_hidden = self.rnn(
            paddle.concat([context, char_onehots], 1), prev_hidden)
        return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
    """Attention decoder head using an LSTM attention cell.

    Teacher-forced when `targets` are provided; otherwise decodes greedily,
    feeding its own argmax prediction back in at every step.
    """

    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(AttentionLSTM, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels
        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        # One-hot encode a batch of character indices.
        return F.one_hot(input_char, onehot_dim)

    def forward(self, inputs, targets=None, batch_max_length=25):
        batch_size = inputs.shape[0]
        num_steps = batch_max_length
        # LSTM state is an (h, c) pair, both zero-initialized.
        hidden = (paddle.zeros((batch_size, self.hidden_size)),
                  paddle.zeros((batch_size, self.hidden_size)))

        if targets is not None:
            # Teacher forcing over the ground-truth characters.
            step_hiddens = []
            for step in range(num_steps):
                # One-hot vector for the step-th ground-truth char.
                onehots = self._char_to_onehot(
                    targets[:, step], onehot_dim=self.num_classes)
                hidden, alpha = self.attention_cell(hidden, inputs, onehots)
                # The cell returns (outputs, (h, c)); carry (h, c) forward.
                hidden = (hidden[1][0], hidden[1][1])
                step_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
            probs = self.generator(paddle.concat(step_hiddens, axis=1))
        else:
            # Greedy decoding starting from index 0.
            decoder_input = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None
            for step in range(num_steps):
                onehots = self._char_to_onehot(
                    decoder_input, onehot_dim=self.num_classes)
                hidden, alpha = self.attention_cell(hidden, inputs, onehots)
                # Generate from the cell's raw outputs, then unpack (h, c).
                probs_step = self.generator(hidden[0])
                hidden = (hidden[1][0], hidden[1][1])
                step_probs = paddle.unsqueeze(probs_step, axis=1)
                if probs is None:
                    probs = step_probs
                else:
                    probs = paddle.concat([probs, step_probs], axis=1)
                decoder_input = probs_step.argmax(axis=1)
        return probs
class AttentionLSTMCell(nn.Layer):
    """One additive-attention step over encoder features plus an RNN update.

    The recurrent unit is an LSTMCell by default; `use_gru=True` swaps in a
    GRUCell while keeping the same attention mechanism.
    """

    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionLSTMCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
        if not use_gru:
            self.rnn = nn.LSTMCell(
                input_size=input_size + num_embeddings,
                hidden_size=hidden_size)
        else:
            self.rnn = nn.GRUCell(
                input_size=input_size + num_embeddings,
                hidden_size=hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # Score each encoder position against the previous hidden state h
        # (prev_hidden is an (h, c) pair; only h drives the attention).
        summed = paddle.add(
            self.i2h(batch_H),
            paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1))
        e = self.score(paddle.tanh(summed))
        alpha = paddle.transpose(F.softmax(e, axis=1), [0, 2, 1])
        # Context vector: attention-weighted sum over encoder features.
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        cur_hidden = self.rnn(
            paddle.concat([context, char_onehots], 1), prev_hidden)
        return cur_hidden, alpha
...@@ -246,7 +246,7 @@ class SRNHead(nn.Layer): ...@@ -246,7 +246,7 @@ class SRNHead(nn.Layer):
num_encoder_tus=self.num_encoder_TUs, num_encoder_tus=self.num_encoder_TUs,
num_decoder_tus=self.num_decoder_TUs, num_decoder_tus=self.num_decoder_TUs,
hidden_dims=self.hidden_dims) hidden_dims=self.hidden_dims)
self.vsfd = VSFD(in_channels=in_channels) self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
......
...@@ -135,16 +135,62 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -135,16 +135,62 @@ class AttnLabelDecode(BaseRecLabelDecode):
**kwargs): **kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path, super(AttnLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) character_type, use_space_char)
self.beg_str = "sos"
self.end_str = "eos"
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = [self.beg_str, self.end_str] + dict_character self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character return dict_character
def __call__(self, text): def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
"""
text = self.decode(text) text = self.decode(text)
return text if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
def get_ignored_tokens(self): def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg") beg_idx = self.get_beg_end_flag_idx("beg")
......
...@@ -47,6 +47,7 @@ def main(): ...@@ -47,6 +47,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = len( config['Architecture']["Head"]['out_channels'] = len(
getattr(post_process_class, 'character')) getattr(post_process_class, 'character'))
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
best_model_dict = init_model(config, model, logger) best_model_dict = init_model(config, model, logger)
if len(best_model_dict): if len(best_model_dict):
...@@ -59,7 +60,7 @@ def main(): ...@@ -59,7 +60,7 @@ def main():
# start eval # start eval
metirc = program.eval(model, valid_dataloader, post_process_class, metirc = program.eval(model, valid_dataloader, post_process_class,
eval_class) eval_class, use_srn)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metirc.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
...@@ -64,7 +64,7 @@ class TextDetector(object): ...@@ -64,7 +64,7 @@ class TextDetector(object):
postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = True postprocess_params["use_dilation"] = args.use_dilation
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
......
...@@ -54,6 +54,13 @@ class TextRecognizer(object): ...@@ -54,6 +54,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char "use_space_char": args.use_space_char
} }
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \ self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
...@@ -241,9 +248,11 @@ class TextRecognizer(object): ...@@ -241,9 +248,11 @@ class TextRecognizer(object):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args) text_recognizer = TextRecognizer(args)
total_run_time = 0.0
total_images_num = 0
valid_image_file_list = [] valid_image_file_list = []
img_list = [] img_list = []
for image_file in image_file_list: for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
...@@ -252,22 +261,29 @@ def main(args): ...@@ -252,22 +261,29 @@ def main(args):
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
try: if len(img_list) >= args.rec_batch_num or idx == len(
rec_res, predict_time = text_recognizer(img_list) image_file_list) - 1:
except: try:
logger.info(traceback.format_exc()) rec_res, predict_time = text_recognizer(img_list)
logger.info( total_run_time += predict_time
"ERROR!!!! \n" except:
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" logger.info(traceback.format_exc())
"If your model has tps module: " logger.info(
"TPS does not support variable shape.\n" "ERROR!!!! \n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
exit() "If your model has tps module: "
for ino in range(len(img_list)): "TPS does not support variable shape.\n"
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
rec_res[ino])) )
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[
ino], rec_res[ino]))
total_images_num += len(valid_image_file_list)
valid_image_file_list = []
img_list = []
logger.info("Total predict time for {} images, cost: {:.3f}".format( logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time)) total_images_num, total_run_time))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -184,4 +184,4 @@ def main(args): ...@@ -184,4 +184,4 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) main(utility.parse_args())
\ No newline at end of file
...@@ -47,6 +47,7 @@ def parse_args(): ...@@ -47,6 +47,7 @@ def parse_args():
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--use_dilation", type=bool, default=False)
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
...@@ -123,6 +124,8 @@ def create_predictor(args, mode, logger): ...@@ -123,6 +124,8 @@ def create_predictor(args, mode, logger):
# cache 10 different shapes for mkldnn to avoid memory leak # cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10) config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
# TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# config.enable_memory_optim() # config.enable_memory_optim()
......
...@@ -163,6 +163,11 @@ def train(config, ...@@ -163,6 +163,11 @@ def train(config,
if type(eval_batch_step) == list and len(eval_batch_step) >= 2: if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0] start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1] eval_batch_step = eval_batch_step[1]
if len(valid_dataloader) == 0:
logger.info(
'No Images in eval dataset, evaluation during training will be disabled'
)
start_eval_step = 1e111
logger.info( logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations". "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step)) format(start_eval_step, eval_batch_step))
...@@ -177,6 +182,8 @@ def train(config, ...@@ -177,6 +182,8 @@ def train(config,
model_average = False model_average = False
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
else: else:
...@@ -195,7 +202,7 @@ def train(config, ...@@ -195,7 +202,7 @@ def train(config,
break break
lr = optimizer.get_lr() lr = optimizer.get_lr()
images = batch[0] images = batch[0]
if config['Architecture']['algorithm'] == "SRN": if use_srn:
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
model_average = True model_average = True
...@@ -222,8 +229,8 @@ def train(config, ...@@ -222,8 +229,8 @@ def train(config,
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch) eval_class(post_result, batch)
metirc = eval_class.get_metric() metric = eval_class.get_metric()
train_stats.update(metirc) train_stats.update(metric)
if vdl_writer is not None and dist.get_rank() == 0: if vdl_writer is not None and dist.get_rank() == 0:
for k, v in train_stats.get().items(): for k, v in train_stats.get().items():
...@@ -251,8 +258,12 @@ def train(config, ...@@ -251,8 +258,12 @@ def train(config,
min_average_window=10000, min_average_window=10000,
max_average_window=15625) max_average_window=15625)
Model_Average.apply() Model_Average.apply()
cur_metirc = eval(model, valid_dataloader, post_process_class, cur_metric = eval(
eval_class) model,
valid_dataloader,
post_process_class,
eval_class,
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
...@@ -316,7 +327,8 @@ def train(config, ...@@ -316,7 +327,8 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class): def eval(model, valid_dataloader, post_process_class, eval_class,
use_srn=False):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -326,9 +338,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class): ...@@ -326,9 +338,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
if idx >= len(valid_dataloader): if idx >= len(valid_dataloader):
break break
images = batch[0] images = batch[0]
others = batch[-4:]
start = time.time() start = time.time()
preds = model(images, others)
if use_srn:
others = batch[-4:]
preds = model(images, others)
else:
preds = model(images)
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
...@@ -378,6 +394,7 @@ def preprocess(is_train=False): ...@@ -378,6 +394,7 @@ def preprocess(is_train=False):
logger = get_logger(name='root', log_file=log_file) logger = get_logger(name='root', log_file=log_file)
if config['Global']['use_visualdl']: if config['Global']['use_visualdl']:
from visualdl import LogWriter from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
os.makedirs(vdl_writer_path, exist_ok=True) os.makedirs(vdl_writer_path, exist_ok=True)
vdl_writer = LogWriter(logdir=vdl_writer_path) vdl_writer = LogWriter(logdir=vdl_writer_path)
......
...@@ -50,6 +50,12 @@ def main(config, device, logger, vdl_writer): ...@@ -50,6 +50,12 @@ def main(config, device, logger, vdl_writer):
# build dataloader # build dataloader
train_dataloader = build_dataloader(config, 'Train', device, logger) train_dataloader = build_dataloader(config, 'Train', device, logger)
if len(train_dataloader) == 0:
logger.error(
'No Images in train dataset, please check annotation file and path in the configuration file'
)
return
if config['Eval']: if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
else: else:
......
# recommended paddle.__version__ == 2.0.0 # recommended paddle.__version__ == 2.0.0
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment