Commit c9e1077d authored by tink2123

polish code

parent 59cc4efd
@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
     def add_special_char(self, dict_character):
         self.beg_str = "sos"
         self.end_str = "eos"
-        self.unkonwn = "UNKNOWN"
         dict_character = dict_character
-        dict_character = [self.beg_str] + dict_character + [self.end_str
-                                                            ] + [self.unkonwn]
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
     def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
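A minimal sketch, not part of the commit, of the dictionary layout the updated add_special_char produces (toy character set assumed): "sos" lands at index 0 and "eos" at the last index, which is what get_beg_end_flag_idx relies on now that the extra "UNKNOWN" slot is gone.

    # toy character set; the real one comes from character_dict_path
    chars = ["a", "b", "c"]
    beg_str, end_str = "sos", "eos"
    dict_character = [beg_str] + chars + [end_str]   # no trailing "UNKNOWN"

    # BaseRecLabelDecode builds a token -> index map from this list
    token_to_idx = {tok: i for i, tok in enumerate(dict_character)}
    assert token_to_idx[beg_str] == 0
    assert token_to_idx[end_str] == len(dict_character) - 1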
@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
             label = self.decode(label, is_remove_duplicate=False)
             return text, label
         """
-        preds = preds["rec_pred"]
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
 
@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
         return idx
 
 
+class SEEDLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(SEEDLabelDecode, self).__init__(character_dict_path,
+                                              character_type, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character
+        dict_character = dict_character + [self.end_str]
+        return dict_character
+
+    def get_ignored_tokens(self):
+        end_idx = self.get_beg_end_flag_idx("eos")
+        return [end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "sos":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "eos":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
+        return idx
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        [end_idx] = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list)))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        """
+        text = self.decode(text)
+        if label is None:
+            return text
+        else:
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+        """
+        preds_idx = preds["rec_pred"]
+        if isinstance(preds_idx, paddle.Tensor):
+            preds_idx = preds_idx.numpy()
+        if "rec_pred_scores" in preds:
+            preds_idx = preds["rec_pred"]
+            preds_prob = preds["rec_pred_scores"]
+        else:
+            preds_idx = preds["rec_pred"].argmax(axis=2)
+            preds_prob = preds["rec_pred"].max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+        if label is None:
+            return text
+        label = self.decode(label, is_remove_duplicate=False)
+        return text, label
+
+
 class SRNLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
......
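A small usage sketch, not part of the commit, of what the new SEEDLabelDecode does with raw recognition output when no "rec_pred_scores" key is present: argmax/max over the class axis, then decode each sequence up to the "eos" index (toy dictionary and prediction array assumed).

    import numpy as np

    # toy character list as add_special_char leaves it: characters, then "eos"
    character = ["a", "b", "c", "eos"]
    end_idx = character.index("eos")          # decoding stops at this token

    # fake network output: batch of 1, 5 time steps, 4 classes
    rec_pred = np.array([[[0.10, 0.70, 0.10, 0.10],    # -> "b"
                          [0.80, 0.10, 0.05, 0.05],    # -> "a"
                          [0.05, 0.05, 0.10, 0.80],    # -> "eos", stop here
                          [0.25, 0.25, 0.25, 0.25],
                          [0.25, 0.25, 0.25, 0.25]]])

    preds_idx = rec_pred.argmax(axis=2)   # branch taken when "rec_pred_scores"
    preds_prob = rec_pred.max(axis=2)     # is absent from the preds dict

    results = []
    for idx_seq, prob_seq in zip(preds_idx, preds_prob):
        char_list, conf_list = [], []
        for i, p in zip(idx_seq, prob_seq):
            if int(i) == end_idx:
                break
            char_list.append(character[int(i)])
            conf_list.append(p)
        results.append(("".join(char_list), float(np.mean(conf_list))))

    print(results)   # [('ba', 0.75)]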
@@ -105,15 +105,12 @@ def load_dygraph_params(config, model, logger, optimizer):
         params = paddle.load(pm)
         state_dict = model.state_dict()
         new_state_dict = {}
-        # for k1, k2 in zip(state_dict.keys(), params.keys()):
-        for k1 in state_dict.keys():
-            if k1 not in params:
-                continue
-            if list(state_dict[k1].shape) == list(params[k1].shape):
-                new_state_dict[k1] = params[k1]
+        for k1, k2 in zip(state_dict.keys(), params.keys()):
+            if list(state_dict[k1].shape) == list(params[k2].shape):
+                new_state_dict[k1] = params[k2]
             else:
                 logger.info(
-                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
+                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
                 )
         model.set_state_dict(new_state_dict)
         logger.info(f"loaded pretrained_model successful from {pm}")
......
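The rewritten load_dygraph_params loop pairs the model's parameters with the checkpoint's by position rather than by name, and copies a tensor only when the shapes agree. A standalone sketch, with toy dicts in place of the real PaddlePaddle state dicts:

    import numpy as np

    state_dict = {"backbone.w": np.zeros((3, 3)), "head.w": np.zeros((10, 3))}
    params     = {"conv.weight": np.ones((3, 3)), "fc.weight": np.ones((5, 3))}

    new_state_dict = {}
    for k1, k2 in zip(state_dict.keys(), params.keys()):
        # keys are paired purely by position, so only the shapes are compared
        if list(state_dict[k1].shape) == list(params[k2].shape):
            new_state_dict[k1] = params[k2]
        else:
            print(f"The shape of model params {k1} {state_dict[k1].shape} "
                  f"not matched with loaded params {k2} {params[k2].shape} !")

    # only "backbone.w" is filled; the (10, 3) vs (5, 3) pair is skipped

Positional pairing assumes both dicts list their parameters in the same order; any key present in only one of them is silently dropped by zip.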
@@ -211,11 +211,10 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            # if use_srn or model_type == 'table' or algorithm == "ASTER":
-            #     preds = model(images, data=batch[1:])
-            # else:
-            #     preds = model(images)
-            preds = model(images, data=batch[1:])
+            if use_srn or model_type == 'table' or model_type == "seed":
+                preds = model(images, data=batch[1:])
+            else:
+                preds = model(images)
             state_dict = model.state_dict()
             # for key in state_dict:
             #     print(key)
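A minimal sketch (hypothetical helper name, dummy model and batch) of the dispatch the restored if/else performs: heads such as SRN, table and SEED receive the remaining batch tensors through data=batch[1:], while every other model gets the images alone.

    def forward_step(model, batch, model_type, use_srn):
        images = batch[0]
        # SRN, table and SEED consume extra supervision tensors from the batch
        if use_srn or model_type == 'table' or model_type == "seed":
            return model(images, data=batch[1:])
        return model(images)

    # toy usage with a stand-in model
    dummy = lambda images, data=None: {"images": images, "extra": data}
    out = forward_step(dummy, ["img", "label", "length"], model_type="seed", use_srn=False)
    assert out["extra"] == ["label", "length"]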
@@ -415,6 +414,7 @@ def preprocess(is_train=False):
             yaml.dump(
                 dict(config), f, default_flow_style=False, sort_keys=False)
         log_file = '{}/train.log'.format(save_model_dir)
+        print("log has save in {}/train.log".format(save_model_dir))
     else:
         log_file = None
     logger = get_logger(name='root', log_file=log_file)
......