Commit 7230bfe3 authored by myhloli
Browse files

refactor: add DonutSwin model implementation and enhance character decoding logic

parent 8f0cc148
...@@ -20,6 +20,7 @@ def build_backbone(config, model_type): ...@@ -20,6 +20,7 @@ def build_backbone(config, model_type):
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .rec_hgnet import PPHGNet_small from .rec_hgnet import PPHGNet_small
from .rec_lcnetv3 import PPLCNetV3 from .rec_lcnetv3 import PPLCNetV3
from .rec_pphgnetv2 import PPHGNetV2_B4
support_dict = [ support_dict = [
"MobileNetV3", "MobileNetV3",
...@@ -28,6 +29,7 @@ def build_backbone(config, model_type): ...@@ -28,6 +29,7 @@ def build_backbone(config, model_type):
"ResNet_SAST", "ResNet_SAST",
"PPLCNetV3", "PPLCNetV3",
"PPHGNet_small", "PPHGNet_small",
'PPHGNetV2_B4',
] ]
elif model_type == "rec" or model_type == "cls": elif model_type == "rec" or model_type == "cls":
from .rec_hgnet import PPHGNet_small from .rec_hgnet import PPHGNet_small
......
...@@ -9,28 +9,28 @@ class Im2Seq(nn.Module): ...@@ -9,28 +9,28 @@ class Im2Seq(nn.Module):
super().__init__() super().__init__()
self.out_channels = in_channels self.out_channels = in_channels
# def forward(self, x):
# B, C, H, W = x.shape
# # assert H == 1
# x = x.squeeze(dim=2)
# # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
# x = x.permute(0, 2, 1)
# return x
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
# 处理四维张量,将空间维度展平为序列 # assert H == 1
if H == 1: x = x.squeeze(dim=2)
# 原来的处理逻辑,适用于H=1的情况 # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
x = x.squeeze(dim=2) x = x.permute(0, 2, 1)
x = x.permute(0, 2, 1) # (B, W, C)
else:
# 处理H不为1的情况
x = x.permute(0, 2, 3, 1) # (B, H, W, C)
x = x.reshape(B, H * W, C) # (B, H*W, C)
return x return x
# def forward(self, x):
# B, C, H, W = x.shape
# # 处理四维张量,将空间维度展平为序列
# if H == 1:
# # 原来的处理逻辑,适用于H=1的情况
# x = x.squeeze(dim=2)
# x = x.permute(0, 2, 1) # (B, W, C)
# else:
# # 处理H不为1的情况
# x = x.permute(0, 2, 3, 1) # (B, H, W, C)
# x = x.reshape(B, H * W, C) # (B, H*W, C)
#
# return x
class EncoderWithRNN_(nn.Module): class EncoderWithRNN_(nn.Module):
def __init__(self, in_channels, hidden_size): def __init__(self, in_channels, hidden_size):
super(EncoderWithRNN_, self).__init__() super(EncoderWithRNN_, self).__init__()
......
...@@ -124,10 +124,10 @@ class DBPostProcess(object): ...@@ -124,10 +124,10 @@ class DBPostProcess(object):
''' '''
h, w = bitmap.shape[:2] h, w = bitmap.shape[:2]
box = _box.copy() box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1) xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1) xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1) ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1) ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin box[:, 0] = box[:, 0] - xmin
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
import numpy as np import numpy as np
import torch import torch
...@@ -24,8 +25,9 @@ class BaseRecLabelDecode(object): ...@@ -24,8 +25,9 @@ class BaseRecLabelDecode(object):
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
self.reverse = False
self.character_str = [] self.character_str = []
if character_dict_path is None: if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
...@@ -38,6 +40,8 @@ class BaseRecLabelDecode(object): ...@@ -38,6 +40,8 @@ class BaseRecLabelDecode(object):
if use_space_char: if use_space_char:
self.character_str.append(" ") self.character_str.append(" ")
dict_character = list(self.character_str) dict_character = list(self.character_str)
if "arabic" in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
...@@ -45,10 +49,98 @@ class BaseRecLabelDecode(object): ...@@ -45,10 +49,98 @@ class BaseRecLabelDecode(object):
self.dict[char] = i self.dict[char] = i
self.character = dict_character self.character = dict_character
def pred_reverse(self, pred):
    """Reverse the order of the segments in ``pred`` and join them.

    ``pred`` is walked item by item (character by character for a string).
    Consecutive items that match ``[a-zA-Z0-9 :*./%+-]`` are merged into a
    single segment whose internal order is kept; every other item forms a
    segment of its own.  The segments are then emitted last-to-first.

    NOTE(review): presumably meant for right-to-left scripts, since
    ``self.reverse`` is switched on for "arabic" dictionaries in
    ``__init__`` -- confirm at the call site.

    Args:
        pred: the decoded text, a string or an iterable of strings.

    Returns:
        str: the segments of ``pred`` concatenated in reverse order.
    """
    # Built once per call instead of being looked up for every character.
    keep_together = re.compile("[a-zA-Z0-9 :*./%+-]")

    segments = []
    run = ""
    for c in pred:
        if keep_together.search(c):
            # Still inside a left-to-right run: keep accumulating.
            run += c
            continue
        # A non-matching item ends the pending run and stands alone.
        if run:
            segments.append(run)
            run = ""
        segments.append(c)
    if run:
        segments.append(run)
    return "".join(reversed(segments))
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
return dict_character return dict_character
def get_word_info(self, text, selection):
    """Group the decoded characters and record their decoding positions.

    Each character is classified as one of:
      - 'cn': a character in the range U+4E00..U+9FFF (e.g., 你好啊)
      - 'en&num': an ASCII letter or digit; a '.' that follows an 'en&num'
        character and precedes a digit (e.g., 1.123) and a '-' that follows
        an 'en&num' character (e.g., VGG-16) are kept in the same group
      - 'splitter': anything else (e.g., space, '(', ')'); splitters
        separate groups and are not included in the output

    Args:
        text: the decoded text.
        selection: boolean mask (array or plain sequence) marking which
            feature columns were decoded as non-separated characters.  It
            needs at least ``len(text)`` true entries, because the i-th
            character is looked up at the i-th true column.

    Returns:
        word_list: list of the grouped words, each a list of characters.
        word_col_list: for each word, the list of column indices that its
            characters were decoded from.
        state_list: for each word, its group type ('cn' or 'en&num').
    """
    # Hoisted out of the loop: both patterns are loop-invariant.
    alnum = re.compile("[a-zA-Z0-9]")
    digit = re.compile("[0-9]")

    # Indices of the true entries; asarray lets a plain list be passed too.
    valid_col = np.where(np.asarray(selection, dtype=bool))[0]

    word_list = []
    word_col_list = []
    state_list = []
    word_content = []
    word_col_content = []
    state = None

    for c_i, char in enumerate(text):
        if "\u4e00" <= char <= "\u9fff":
            c_state = "cn"
        elif alnum.search(char):
            c_state = "en&num"
        else:
            c_state = "splitter"

        if state == "en&num":
            if (
                char == "."
                and c_i + 1 < len(text)
                and digit.search(text[c_i + 1])
            ):
                # grouping floating number
                c_state = "en&num"
            elif char == "-":
                # grouping word with '-', such as 'state-of-the-art'
                c_state = "en&num"

        if state is None:
            state = c_state

        if state != c_state:
            # The group type changed: close the word collected so far.
            if word_content:
                word_list.append(word_content)
                word_col_list.append(word_col_content)
                state_list.append(state)
                word_content = []
                word_col_content = []
            state = c_state

        if state != "splitter":
            word_content.append(char)
            word_col_content.append(valid_col[c_i])

    # Flush the trailing word, if any.
    if word_content:
        word_list.append(word_content)
        word_col_list.append(word_col_content)
        state_list.append(state)
    return word_list, word_col_list, state_list
def decode(
self,
text_index,
text_prob=None,
is_remove_duplicate=False,
return_word_box=False,
):
""" convert text-index into text-label. """ """ convert text-index into text-label. """
result_list = [] result_list = []
ignored_tokens = self.get_ignored_tokens() ignored_tokens = self.get_ignored_tokens()
...@@ -88,12 +180,22 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -88,12 +180,22 @@ class CTCLabelDecode(BaseRecLabelDecode):
super(CTCLabelDecode, self).__init__(character_dict_path, super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char) use_space_char)
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, torch.Tensor): if isinstance(preds, torch.Tensor):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) text = self.decode(
preds_idx,
preds_prob,
is_remove_duplicate=True,
return_word_box=return_word_box,
)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs["wh_ratio_list"][rec_idx]
max_wh_ratio = kwargs["max_wh_ratio"]
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
if label is None: if label is None:
return text return text
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment