"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "e2b84da86655eec1b03f320c5a75174864355eba"
Commit 38fc1fae authored by WenmuZhou's avatar WenmuZhou
Browse files

add max_text_length to SRNLabelDecode

parent 88c6ad8a
...@@ -218,6 +218,7 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -218,6 +218,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
**kwargs): **kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path, super(SRNLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char) character_type, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict'] pred = preds['predict']
...@@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
preds_idx = np.argmax(pred, axis=1) preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1) preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [-1, 25]) preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
preds_prob = np.reshape(preds_prob, [-1, 25]) preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
text = self.decode(preds_idx, preds_prob) text = self.decode(preds_idx, preds_prob)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment