Commit c9e1077d authored by tink2123
Browse files

polish code

parent 59cc4efd
......@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
self.unkonwn = "UNKNOWN"
dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str
] + [self.unkonwn]
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
......@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds = preds["rec_pred"]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
......@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
return idx
class SEEDLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index for SEED recognition output.

    The character table is the base dictionary followed by a single
    end-of-sequence token ("eos"); decoding a sequence stops at the first
    occurrence of that token.
    """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        """Build the decoder; extra keyword arguments are accepted and ignored."""
        super(SEEDLabelDecode, self).__init__(character_dict_path,
                                              character_type, use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with the end-of-sequence token appended.

        NOTE(review): ``beg_str`` is recorded on the instance but "sos" is not
        added to the returned list, so ``get_beg_end_flag_idx("sos")`` will
        presumably fail unless the dictionary file itself contains "sos" --
        confirm against BaseRecLabelDecode.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return dict_character + [self.end_str]

    def get_ignored_tokens(self):
        """Return a one-element list holding the index of the "eos" token."""
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the index of the begin ("sos") or end ("eos") token.

        Args:
            beg_or_end: either "sos" or "eos".

        Returns:
            The value of ``self.dict`` for the requested token, wrapped in
            ``np.array``.

        Raises:
            AssertionError: if ``beg_or_end`` is neither "sos" nor "eos".
        """
        if beg_or_end == "sos":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "eos":
            return np.array(self.dict[self.end_str])
        # Raised explicitly (instead of `assert False`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(
            "unsupport type %s in get_beg_end_flag_idx" % beg_or_end)

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert batches of text indices into (text, confidence) pairs.

        Args:
            text_index: indexable batch of index sequences.
            text_prob: optional per-position scores indexed like
                ``text_index``; when None every kept character scores 1.
            is_remove_duplicate: when True, skip a position whose index equals
                the index at the previous position.

        Returns:
            A list with one ``(text, mean_confidence)`` tuple per batch item.
        """
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        end_idx = int(end_idx)
        for batch_idx in range(len(text_index)):
            sequence = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(sequence)):
                # Everything from the end token onwards is padding.
                if int(sequence[idx]) == end_idx:
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and sequence[idx - 1] == sequence[idx]:
                        continue
                char_list.append(self.character[int(sequence[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output (and optionally labels) into text.

        Args:
            preds: mapping with key "rec_pred" and, optionally,
                "rec_pred_scores". When scores are present "rec_pred" is used
                directly as the index array; otherwise indices and scores are
                taken as the argmax / max of "rec_pred" over axis 2.
            label: optional batch of label index sequences to decode as well.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is given.
        """
        rec_pred = preds["rec_pred"]
        # Convert once and keep using the converted array; the previous code
        # re-read preds["rec_pred"] afterwards and lost this conversion.
        if isinstance(rec_pred, paddle.Tensor):
            rec_pred = rec_pred.numpy()
        if "rec_pred_scores" in preds:
            preds_idx = rec_pred
            preds_prob = preds["rec_pred_scores"]
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
        else:
            preds_idx = rec_pred.argmax(axis=2)
            preds_prob = rec_pred.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......
......@@ -105,16 +105,13 @@ def load_dygraph_params(config, model, logger, optimizer):
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
# for k1, k2 in zip(state_dict.keys(), params.keys()):
for k1 in state_dict.keys():
if k1 not in params:
continue
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
)
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
......
......@@ -211,11 +211,10 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
# if use_srn or model_type == 'table' or algorithm == "ASTER":
# preds = model(images, data=batch[1:])
# else:
# preds = model(images)
preds = model(images, data=batch[1:])
if use_srn or model_type == 'table' or model_type == "seed":
preds = model(images, data=batch[1:])
else:
preds = model(images)
state_dict = model.state_dict()
# for key in state_dict:
# print(key)
......@@ -415,6 +414,7 @@ def preprocess(is_train=False):
yaml.dump(
dict(config), f, default_flow_style=False, sort_keys=False)
log_file = '{}/train.log'.format(save_model_dir)
print("log has save in {}/train.log".format(save_model_dir))
else:
log_file = None
logger = get_logger(name='root', log_file=log_file)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment