Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
7230bfe3
Commit
7230bfe3
authored
Jun 11, 2025
by
myhloli
Browse files
refactor: add DonutSwin model implementation and enhance character decoding logic
parent
8f0cc148
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1407 additions
and
26 deletions
+1407
-26
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
...ddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
+2
-0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_donut_swin.py
...r2pytorch/pytorchocr/modeling/backbones/rec_donut_swin.py
+1277
-0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py
...el/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py
+18
-18
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py
...addleocr2pytorch/pytorchocr/postprocess/db_postprocess.py
+4
-4
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py
...ddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py
+106
-4
No files found.
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py
View file @
7230bfe3
...
@@ -20,6 +20,7 @@ def build_backbone(config, model_type):
...
@@ -20,6 +20,7 @@ def build_backbone(config, model_type):
from
.det_mobilenet_v3
import
MobileNetV3
from
.det_mobilenet_v3
import
MobileNetV3
from
.rec_hgnet
import
PPHGNet_small
from
.rec_hgnet
import
PPHGNet_small
from
.rec_lcnetv3
import
PPLCNetV3
from
.rec_lcnetv3
import
PPLCNetV3
from
.rec_pphgnetv2
import
PPHGNetV2_B4
support_dict
=
[
support_dict
=
[
"MobileNetV3"
,
"MobileNetV3"
,
...
@@ -28,6 +29,7 @@ def build_backbone(config, model_type):
...
@@ -28,6 +29,7 @@ def build_backbone(config, model_type):
"ResNet_SAST"
,
"ResNet_SAST"
,
"PPLCNetV3"
,
"PPLCNetV3"
,
"PPHGNet_small"
,
"PPHGNet_small"
,
'PPHGNetV2_B4'
,
]
]
elif
model_type
==
"rec"
or
model_type
==
"cls"
:
elif
model_type
==
"rec"
or
model_type
==
"cls"
:
from
.rec_hgnet
import
PPHGNet_small
from
.rec_hgnet
import
PPHGNet_small
...
...
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_donut_swin.py
0 → 100644
View file @
7230bfe3
This diff is collapsed.
Click to expand it.
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py
View file @
7230bfe3
...
@@ -9,28 +9,28 @@ class Im2Seq(nn.Module):
...
@@ -9,28 +9,28 @@ class Im2Seq(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
out_channels
=
in_channels
self
.
out_channels
=
in_channels
# def forward(self, x):
# B, C, H, W = x.shape
# # assert H == 1
# x = x.squeeze(dim=2)
# # x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
# x = x.permute(0, 2, 1)
# return x
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
# 处理四维张量,将空间维度展平为序列
# assert H == 1
if
H
==
1
:
x
=
x
.
squeeze
(
dim
=
2
)
# 原来的处理逻辑,适用于H=1的情况
# x = x.transpose([0, 2, 1]) # paddle (NTC)(batch, width, channels)
x
=
x
.
squeeze
(
dim
=
2
)
x
=
x
.
permute
(
0
,
2
,
1
)
x
=
x
.
permute
(
0
,
2
,
1
)
# (B, W, C)
else
:
# 处理H不为1的情况
x
=
x
.
permute
(
0
,
2
,
3
,
1
)
# (B, H, W, C)
x
=
x
.
reshape
(
B
,
H
*
W
,
C
)
# (B, H*W, C)
return
x
return
x
# def forward(self, x):
# B, C, H, W = x.shape
# # 处理四维张量,将空间维度展平为序列
# if H == 1:
# # 原来的处理逻辑,适用于H=1的情况
# x = x.squeeze(dim=2)
# x = x.permute(0, 2, 1) # (B, W, C)
# else:
# # 处理H不为1的情况
# x = x.permute(0, 2, 3, 1) # (B, H, W, C)
# x = x.reshape(B, H * W, C) # (B, H*W, C)
#
# return x
class
EncoderWithRNN_
(
nn
.
Module
):
class
EncoderWithRNN_
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
hidden_size
):
def
__init__
(
self
,
in_channels
,
hidden_size
):
super
(
EncoderWithRNN_
,
self
).
__init__
()
super
(
EncoderWithRNN_
,
self
).
__init__
()
...
...
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py
View file @
7230bfe3
...
@@ -124,10 +124,10 @@ class DBPostProcess(object):
...
@@ -124,10 +124,10 @@ class DBPostProcess(object):
'''
'''
h
,
w
=
bitmap
.
shape
[:
2
]
h
,
w
=
bitmap
.
shape
[:
2
]
box
=
_box
.
copy
()
box
=
_box
.
copy
()
xmin
=
np
.
clip
(
np
.
floor
(
box
[:,
0
].
min
()).
astype
(
np
.
int
64
),
0
,
w
-
1
)
xmin
=
np
.
clip
(
np
.
floor
(
box
[:,
0
].
min
()).
astype
(
np
.
int
if
'int'
in
np
.
__dict__
else
np
.
int32
),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
ceil
(
box
[:,
0
].
max
()).
astype
(
np
.
int
64
),
0
,
w
-
1
)
xmax
=
np
.
clip
(
np
.
ceil
(
box
[:,
0
].
max
()).
astype
(
np
.
int
if
'int'
in
np
.
__dict__
else
np
.
int32
),
0
,
w
-
1
)
ymin
=
np
.
clip
(
np
.
floor
(
box
[:,
1
].
min
()).
astype
(
np
.
int
64
),
0
,
h
-
1
)
ymin
=
np
.
clip
(
np
.
floor
(
box
[:,
1
].
min
()).
astype
(
np
.
int
if
'int'
in
np
.
__dict__
else
np
.
int32
),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
ceil
(
box
[:,
1
].
max
()).
astype
(
np
.
int
64
),
0
,
h
-
1
)
ymax
=
np
.
clip
(
np
.
ceil
(
box
[:,
1
].
max
()).
astype
(
np
.
int
if
'int'
in
np
.
__dict__
else
np
.
int32
),
0
,
h
-
1
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
mask
=
np
.
zeros
((
ymax
-
ymin
+
1
,
xmax
-
xmin
+
1
),
dtype
=
np
.
uint8
)
box
[:,
0
]
=
box
[:,
0
]
-
xmin
box
[:,
0
]
=
box
[:,
0
]
-
xmin
...
...
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py
View file @
7230bfe3
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
re
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -24,8 +25,9 @@ class BaseRecLabelDecode(object):
...
@@ -24,8 +25,9 @@ class BaseRecLabelDecode(object):
self
.
beg_str
=
"sos"
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
end_str
=
"eos"
self
.
reverse
=
False
self
.
character_str
=
[]
self
.
character_str
=
[]
if
character_dict_path
is
None
:
if
character_dict_path
is
None
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
...
@@ -38,6 +40,8 @@ class BaseRecLabelDecode(object):
...
@@ -38,6 +40,8 @@ class BaseRecLabelDecode(object):
if
use_space_char
:
if
use_space_char
:
self
.
character_str
.
append
(
" "
)
self
.
character_str
.
append
(
" "
)
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
if
"arabic"
in
character_dict_path
:
self
.
reverse
=
True
dict_character
=
self
.
add_special_char
(
dict_character
)
dict_character
=
self
.
add_special_char
(
dict_character
)
self
.
dict
=
{}
self
.
dict
=
{}
...
@@ -45,10 +49,98 @@ class BaseRecLabelDecode(object):
...
@@ -45,10 +49,98 @@ class BaseRecLabelDecode(object):
self
.
dict
[
char
]
=
i
self
.
dict
[
char
]
=
i
self
.
character
=
dict_character
self
.
character
=
dict_character
def pred_reverse(self, pred):
    """Reverse the segment order of a decoded prediction.

    The prediction is split into segments: each maximal run of consecutive
    items matching ``[a-zA-Z0-9 :*./%+-]`` forms one segment (its internal
    order is kept), and every other item is a segment of its own. The
    segments are then emitted in reverse order and concatenated.

    Args:
        pred: iterable of strings (typically a ``str``, iterated per
            character).

    Returns:
        str: the segments of ``pred`` joined in reverse order.
    """
    # Look the pattern up once instead of once per character.
    keep_together = re.compile("[a-zA-Z0-9 :*./%+-]").search

    pred_re = []
    c_current = ""
    for c in pred:
        if keep_together(c):
            # Still inside a run that must keep its internal order.
            c_current += c
            continue
        # Any other item ends the pending run and stands alone.
        if c_current != "":
            pred_re.append(c_current)
        pred_re.append(c)
        c_current = ""
    if c_current != "":
        pred_re.append(c_current)

    return "".join(pred_re[::-1])
def
add_special_char
(
self
,
dict_character
):
def
add_special_char
(
self
,
dict_character
):
return
dict_character
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
def get_word_info(self, text, selection):
    """
    Group the decoded characters and record the corresponding decoded positions.

    Args:
        text: the decoded text
        selection: the bool array (or any bool sequence) that identifies which
            columns of features are decoded as non-separated characters; it
            must contain at least ``len(text)`` true entries
    Returns:
        word_list: list of the grouped words (each word is a list of characters)
        word_col_list: list of decoding positions corresponding to each character in the grouped word
        state_list: list of marker to identify the type of grouping words, including two types of grouping words:
            - 'cn': continuous chinese characters (e.g., 你好啊)
            - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
    The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
    """
    # Look the patterns up once instead of once per character.
    is_alnum = re.compile("[a-zA-Z0-9]").search
    is_digit = re.compile("[0-9]").search

    state = None
    word_content = []
    word_col_content = []
    word_list = []
    word_col_list = []
    state_list = []

    # Coerce to a bool ndarray first: comparing a plain list with ``True``
    # yields a scalar instead of an element-wise mask.
    valid_col = np.where(np.asarray(selection, dtype=bool))[0]
    text_len = len(text)

    for c_i, char in enumerate(text):
        if "\u4e00" <= char <= "\u9fff":
            c_state = "cn"
        elif is_alnum(char):
            c_state = "en&num"
        else:
            c_state = "splitter"

        if (
            char == "."
            and state == "en&num"
            and c_i + 1 < text_len
            and is_digit(text[c_i + 1])
        ):
            # grouping floating number
            c_state = "en&num"
        if char == "-" and state == "en&num":
            # grouping word with '-', such as 'state-of-the-art'
            c_state = "en&num"

        if state is None:
            state = c_state

        if state != c_state:
            # The character class changed: flush the group collected so far.
            if word_content:
                word_list.append(word_content)
                word_col_list.append(word_col_content)
                state_list.append(state)
                word_content = []
                word_col_content = []
            state = c_state

        if state != "splitter":
            word_content.append(char)
            word_col_content.append(valid_col[c_i])

    # Flush the trailing group, if any.
    if word_content:
        word_list.append(word_content)
        word_col_list.append(word_col_content)
        state_list.append(state)

    return word_list, word_col_list, state_list
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
,
return_word_box
=
False
,
):
""" convert text-index into text-label. """
""" convert text-index into text-label. """
result_list
=
[]
result_list
=
[]
ignored_tokens
=
self
.
get_ignored_tokens
()
ignored_tokens
=
self
.
get_ignored_tokens
()
...
@@ -88,12 +180,22 @@ class CTCLabelDecode(BaseRecLabelDecode):
...
@@ -88,12 +180,22 @@ class CTCLabelDecode(BaseRecLabelDecode):
super
(
CTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
super
(
CTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
label
=
None
,
return_word_box
=
False
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
torch
.
Tensor
):
if
isinstance
(
preds
,
torch
.
Tensor
):
preds
=
preds
.
numpy
()
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
True
,
return_word_box
=
return_word_box
,
)
if
return_word_box
:
for
rec_idx
,
rec
in
enumerate
(
text
):
wh_ratio
=
kwargs
[
"wh_ratio_list"
][
rec_idx
]
max_wh_ratio
=
kwargs
[
"max_wh_ratio"
]
rec
[
2
][
0
]
=
rec
[
2
][
0
]
*
(
wh_ratio
/
max_wh_ratio
)
if
label
is
None
:
if
label
is
None
:
return
text
return
text
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment