wangsen / MinerU · Commit 7230bfe3
authored Jun 11, 2025 by myhloli

refactor: add DonutSwin model implementation and enhance character decoding logic

parent 8f0cc148
Showing 5 changed files with 1407 additions and 26 deletions (+1407 −26)
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py       +2    −0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_donut_swin.py +1277 −0
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py                +18   −18
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py        +4    −4
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py       +106  −4
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/__init__.py

@@ -20,6 +20,7 @@ def build_backbone(config, model_type):
         from .det_mobilenet_v3 import MobileNetV3
         from .rec_hgnet import PPHGNet_small
         from .rec_lcnetv3 import PPLCNetV3
+        from .rec_pphgnetv2 import PPHGNetV2_B4
         support_dict = [
             "MobileNetV3",
@@ -28,6 +29,7 @@ def build_backbone(config, model_type):
             "ResNet_SAST",
             "PPLCNetV3",
             "PPHGNet_small",
+            'PPHGNetV2_B4',
         ]
     elif model_type == "rec" or model_type == "cls":
         from .rec_hgnet import PPHGNet_small
...
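
Editorial aside, not part of the commit: with the import and the `support_dict` entry above, a recognition config can now name the new backbone. A hypothetical sketch; the config keys follow the PaddleOCR-style convention this port mirrors, and exact values depend on the model YAML actually used.

```python
# hypothetical: select the newly registered backbone through build_backbone
config = {"name": "PPHGNetV2_B4"}  # plus any backbone-specific kwargs
# backbone = build_backbone(config, model_type="rec")
```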
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_donut_swin.py (new file, 0 → 100644)

import collections.abc
from collections import OrderedDict
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class DonutSwinConfig(object):
    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                print(f"Can't set {key} with value {value} for {self}")
                raise err


@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(OrderedDict):
    last_hidden_state = None
    hidden_states = None
    attentions = None
    reshaped_hidden_states = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def to_tuple(self):
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())


@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
class DonutSwinModelOutput(OrderedDict):
    last_hidden_state = None
    pooler_output = None
    hidden_states = None
    attentions = None
    reshaped_hidden_states = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def to_tuple(self):
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[k] for k in self.keys())


# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.reshape(
        [
            batch_size,
            height // window_size,
            window_size,
            width // window_size,
            window_size,
            num_channels,
        ]
    )
    windows = input_feature.transpose([0, 1, 3, 2, 4, 5]).reshape(
        [-1, window_size, window_size, num_channels]
    )
    return windows


# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.reshape(
        [
            -1,
            height // window_size,
            width // window_size,
            window_size,
            window_size,
            num_channels,
        ]
    )
    windows = windows.transpose([0, 1, 3, 2, 4, 5]).reshape(
        [-1, height, width, num_channels]
    )
    return windows
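
Editorial aside, not part of the committed file: `window_partition` and `window_reverse` above keep Paddle-style `transpose([...])` calls from the original port. In stock PyTorch the same round trip can be checked with `permute`; a minimal sketch (function names here are illustrative):

```python
import torch

def torch_window_partition(x, ws):
    # plain-PyTorch equivalent of window_partition above
    b, h, w, c = x.shape
    x = x.reshape(b, h // ws, ws, w // ws, ws, c)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, c)

def torch_window_reverse(windows, ws, h, w):
    # plain-PyTorch equivalent of window_reverse above
    c = windows.shape[-1]
    x = windows.reshape(-1, h // ws, w // ws, ws, ws, c)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, h, w, c)

x = torch.randn(2, 14, 14, 96)        # (batch, height, width, channels)
wins = torch_window_partition(x, 7)   # -> (2 * 2 * 2, 7, 7, 96): 4 windows per image
assert torch.equal(torch_window_reverse(wins, 7, 14, 14), x)  # exact round trip
```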
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        if use_mask_token:
            # self.mask_token = paddle.create_parameter(
            #     [1, 1, config.embed_dim], dtype="float32"
            # )
            self.mask_token = nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.zeros(1, 1, config.embed_dim).to(torch.float32)
                )
            )
            nn.init.zeros_(self.mask_token)
        else:
            self.mask_token = None

        if config.use_absolute_embeddings:
            # self.position_embeddings = paddle.create_parameter(
            #     [1, num_patches + 1, config.embed_dim], dtype="float32"
            # )
            self.position_embeddings = nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.zeros(1, num_patches + 1, config.embed_dim).to(torch.float32)
                )
            )
            nn.init.zeros_(self.position_embedding)
        else:
            self.position_embeddings = None
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values, bool_masked_pos=None):
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.shape

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class MyConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding="SAME",
        dilation=1,
        groups=1,
        bias_attr=False,
        eps=1e-6,
    ):
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr,
        )
        # self.weight = paddle.create_parameter(
        #     [out_channels, in_channel, kernel_size[0], kernel_size[1]], dtype="float32"
        # )
        self.weight = torch.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(
                    out_channels, in_channel, kernel_size[0], kernel_size[1]
                ).to(torch.float32)
            )
        )
        # self.bias = paddle.create_parameter([out_channels], dtype="float32")
        self.bias = torch.Parameter(
            nn.init.xavier_uniform_(torch.zeros(out_channels).to(torch.float32))
        )
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self._stride,
            self._padding,
            self._dilation,
            self._groups,
        )
        return x


# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
class DonutSwinPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        patch_size = (
            patch_size
            if isinstance(patch_size, collections.abc.Iterable)
            else (patch_size, patch_size)
        )
        num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.is_export = config.is_export
        self.grid_size = (
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        )
        self.projection = nn.Conv2D(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            if self.is_export:
                pad_values = torch.tensor(pad_values, dtype=torch.int32)
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            if self.is_export:
                pad_values = torch.tensor(pad_values, dtype=torch.int32)
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose([0, 2, 1])
        return embeddings, output_dimensions


# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(
        self,
        input_resolution: Tuple[int],
        dim: int,
        norm_layer: nn.Module = nn.LayerNorm,
        is_export=False,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
        self.norm = norm_layer(4 * dim)
        self.is_export = is_export

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            if self.is_export:
                pad_values = torch.tensor(pad_values, dtype=torch.int32)
            input_feature = nn.functional.pad(input_feature, pad_values)
        return input_feature

    def forward(
        self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]
    ) -> torch.Tensor:
        height, width = input_dimensions
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.reshape(
            [batch_size, height, width, num_channels]
        )
        input_feature = self.maybe_pad(input_feature, height, width)
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        input_feature = torch.cat(
            [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
        )
        input_feature = input_feature.reshape(
            [batch_size, -1, 4 * num_channels]
        )  # batch_size height/2*width/2 4*C
        input_feature = self.norm(input_feature)
        input_feature = self.reduction(input_feature)

        return input_feature


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (
        input.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape,
        dtype=input.dtype,
    )
    random_tensor.floor_()  # binarize
    output = input / keep_prob * random_tensor
    return output
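
Editorial aside: the stochastic-depth arithmetic in `drop_path` is easiest to see numerically. A minimal sketch, assuming `drop_prob = 0.2`; each sample is zeroed with probability 0.2 and survivors are rescaled by `1 / keep_prob` so the expected value is unchanged:

```python
import torch

torch.manual_seed(0)
x = torch.ones(4, 3)                             # 4 samples in the batch
keep_prob = 0.8
mask = (keep_prob + torch.rand(4, 1)).floor_()   # 0 or 1 per sample
out = x / keep_prob * mask
print(out)  # each row is either all 1.25 (kept, rescaled) or all 0.0 (dropped)
```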
# Copied from transformers.models.swin.modeling_swin.SwinDropPath
class DonutSwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class DonutSwinSelfAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size
            if isinstance(window_size, collections.abc.Iterable)
            else (window_size, window_size)
        )
        # self.relative_position_bias_table = paddle.create_parameter(
        #     [(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads],
        #     dtype="float32",
        # )
        self.relative_position_bias_table = torch.Parameter(
            nn.init.xavier_normal_(
                torch.zeros(
                    (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                    num_heads,
                ).to(torch.float32)
            )
        )
        nn.init.zeros_(self.relative_position_bias_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.query = nn.Linear(
            self.all_head_size, self.all_head_size, bias_attr=config.qkv_bias
        )
        self.key = nn.Linear(
            self.all_head_size, self.all_head_size, bias_attr=config.qkv_bias
        )
        self.value = nn.Linear(
            self.all_head_size, self.all_head_size, bias_attr=config.qkv_bias
        )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.shape[:-1] + [
            self.num_attention_heads,
            self.attention_head_size,
        ]
        x = x.reshape(new_x_shape)
        return x.transpose([0, 2, 1, 3])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ) -> Tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose([0, 1, 3, 2]))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.reshape([-1])
        ]
        relative_position_bias = relative_position_bias.reshape(
            [
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            ]
        )

        relative_position_bias = relative_position_bias.transpose([2, 0, 1])
        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
            mask_shape = attention_mask.shape[0]
            attention_scores = attention_scores.reshape(
                [
                    batch_size // mask_shape,
                    mask_shape,
                    self.num_attention_heads,
                    dim,
                    dim,
                ]
            )
            attention_scores = attention_scores + attention_mask.unsqueeze(
                1
            ).unsqueeze(0)
            attention_scores = attention_scores.reshape(
                [-1, self.num_attention_heads, dim, dim]
            )

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose([0, 2, 1, 3])
        new_context_layer_shape = tuple(context_layer.shape[:-2]) + (
            self.all_head_size,
        )
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        return outputs
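
Editorial aside: the relative-position bookkeeping in `__init__` above is compact but dense. A small sketch with a 2×2 window (using `permute` in place of the Paddle-style `transpose([...])`) shows the index table it builds; values land in `[0, (2*ws-1)**2 - 1]` and index the learned bias table:

```python
import torch

ws = 2  # tiny window for illustration
coords = torch.stack(
    torch.meshgrid(torch.arange(ws), torch.arange(ws), indexing="ij")
)
flat = torch.flatten(coords, 1)                # (2, ws*ws)
rel = flat[:, :, None] - flat[:, None, :]      # pairwise (dh, dw) offsets
rel = rel.permute(1, 2, 0)                     # (ws*ws, ws*ws, 2)
rel[:, :, 0] += ws - 1                         # shift offsets to be >= 0
rel[:, :, 1] += ws - 1
rel[:, :, 0] *= 2 * ws - 1                     # row-major flattening
index = rel.sum(-1)                            # (4, 4), values in [0, 8]
print(index)                                   # diagonal is 4: the "no offset" entry
```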
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
class DonutSwinAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
        self.output = DonutSwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, output_attentions
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs


# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        self.intermediate_act_fn = F.gelu

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.swin.modeling_swin.SwinOutput
class DonutSwinOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(
            config, dim, num_heads, window_size=self.window_size
        )
        self.drop_path = (
            DonutSwinDropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)
        self.is_export = config.is_export

    def set_shift_and_window_size(self, input_resolution):
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(input_resolution)

    def get_attn_mask_export(self, height, width, dtype):
        attn_mask = None
        height_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        width_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                if self.shift_size > 0:
                    img_mask[:, height_slice, width_slice, :] = count
                count += 1

        if torch.Tensor(self.shift_size > 0).to(torch.bool):
            # calculate attention mask for SW-MSA
            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape(
                [-1, self.window_size * self.window_size]
            )
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(
                attn_mask != 0, float(-100.0)
            ).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def get_attn_mask(self, height, width, dtype):
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape(
                [-1, self.window_size * self.window_size]
            )
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(
                attn_mask != 0, float(-100.0)
            ).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_bottom, 0, pad_right, 0, 0)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask=None,
        output_attentions=False,
        always_partition=False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        else:
            pass
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.shape
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.reshape([batch_size, height, width, channels])

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shift_value = (-self.shift_size, -self.shift_size)
            if self.is_export:
                shift_value = torch.tensor(shift_value, dtype=torch.int32)
            shifted_hidden_states = torch.roll(
                hidden_states, shifts=shift_value, dims=(1, 2)
            )
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(
            shifted_hidden_states, self.window_size
        )
        hidden_states_windows = hidden_states_windows.reshape(
            [-1, self.window_size * self.window_size, channels]
        )
        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)

        attention_outputs = self.attention(
            hidden_states_windows,
            attn_mask,
            head_mask,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.reshape(
            [-1, self.window_size, self.window_size, channels]
        )
        shifted_windows = window_reverse(
            attention_windows, self.window_size, height_pad, width_pad
        )

        # reverse cyclic shift
        if self.shift_size > 0:
            shift_value = (self.shift_size, self.shift_size)
            if self.is_export:
                shift_value = torch.tensor(shift_value, dtype=torch.int32)
            attention_windows = torch.roll(
                shifted_windows, shifts=shift_value, dims=(1, 2)
            )
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.reshape(
            [batch_size, height * width, channels]
        )

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (
            (layer_output, attention_outputs[1])
            if output_attentions
            else (layer_output,)
        )
        return layer_outputs
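
Editorial aside: `get_attn_mask` above builds the shifted-window (SW-MSA) mask. A minimal sketch on a 4×4 map with `window_size=2`, `shift_size=1` (plain-PyTorch `permute` stands in for `window_partition`): after the cyclic shift, window positions that came from different image regions get −100.0 on their cross-region pairs, which drives those attention weights to ~0 after softmax.

```python
import torch

ws, shift, h, w = 2, 1, 4, 4
img_mask = torch.zeros((1, h, w, 1))
slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
count = 0
for hs in slices:
    for wsl in slices:
        img_mask[:, hs, wsl, :] = count  # label each region with an integer
        count += 1

# partition the label map into 2x2 windows
m = img_mask.reshape(1, h // ws, ws, w // ws, ws, 1)
m = m.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws)
attn_mask = m.unsqueeze(1) - m.unsqueeze(2)          # nonzero where labels differ
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(
    attn_mask == 0, 0.0
)
print(attn_mask.shape)  # torch.Size([4, 4, 4]): one (ws*ws, ws*ws) mask per window
```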
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
    def __init__(
        self, config, dim, input_resolution, depth, num_heads, drop_path, downsample
    ):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                DonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                )
                for i in range(depth)
            ]
        )
        self.is_export = config.is_export

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution,
                dim=dim,
                norm_layer=nn.LayerNorm,
                is_export=self.is_export,
            )
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask=None,
        output_attentions=False,
        always_partition=False,
    ) -> Tuple[torch.Tensor]:
        height, width = input_dimensions
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                input_dimensions,
                layer_head_mask,
                output_attentions,
                always_partition,
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(
                hidden_states_before_downsampling, input_dimensions
            )
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (
            hidden_states,
            hidden_states_before_downsampling,
            output_dimensions,
        )

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
class DonutSwinEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        dpr = [
            x.item()
            for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))
        ]
        self.layers = nn.ModuleList(
            [
                DonutSwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(
                        grid_size[0] // (2**i_layer),
                        grid_size[1] // (2**i_layer),
                    ),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path=dpr[
                        sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])
                    ],
                    downsample=(
                        DonutSwinPatchMerging
                        if (i_layer < self.num_layers - 1)
                        else None
                    ),
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        output_hidden_states_before_downsampling=False,
        always_partition=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            reshaped_hidden_state = hidden_states.view(
                batch_size, *input_dimensions, hidden_size
            )
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    input_dimensions,
                    layer_head_mask,
                    output_attentions,
                    always_partition,
                )

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                reshaped_hidden_state = hidden_states_before_downsampling.reshape(
                    [
                        batch_size,
                        *(output_dimensions[0], output_dimensions[1]),
                        hidden_size,
                    ]
                )
                reshaped_hidden_state = reshaped_hidden_state.transpose([0, 3, 1, 2])
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                reshaped_hidden_state = hidden_states.reshape(
                    [batch_size, *input_dimensions, hidden_size]
                )
                reshaped_hidden_state = reshaped_hidden_state.transpose([0, 3, 1, 2])
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


class DonutSwinPreTrainedModel(nn.Module):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DonutSwinConfig
    base_model_prefix = "swin"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2D)):
            # normal_ = Normal(mean=0.0, std=self.config.initializer_range)
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)

    def post_init(self):
        self.apply(self._initialize_weights)

    def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask


class DonutSwinModel(DonutSwinPreTrainedModel):
    def __init__(
        self,
        in_channels=3,
        hidden_size=1024,
        num_layers=4,
        num_heads=[4, 8, 16, 32],
        add_pooling_layer=True,
        use_mask_token=False,
        is_export=False,
    ):
        super().__init__()
        donut_swin_config = {
            "return_dict": True,
            "output_hidden_states": False,
            "output_attentions": False,
            "use_bfloat16": False,
            "tf_legacy_loss": False,
            "pruned_heads": {},
            "tie_word_embeddings": True,
            "chunk_size_feed_forward": 0,
            "is_encoder_decoder": False,
            "is_decoder": False,
            "cross_attention_hidden_size": None,
            "add_cross_attention": False,
            "tie_encoder_decoder": False,
            "max_length": 20,
            "min_length": 0,
            "do_sample": False,
            "early_stopping": False,
            "num_beams": 1,
            "num_beam_groups": 1,
            "diversity_penalty": 0.0,
            "temperature": 1.0,
            "top_k": 50,
            "top_p": 1.0,
            "typical_p": 1.0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "encoder_no_repeat_ngram_size": 0,
            "bad_words_ids": None,
            "num_return_sequences": 1,
            "output_scores": False,
            "return_dict_in_generate": False,
            "forced_bos_token_id": None,
            "forced_eos_token_id": None,
            "remove_invalid_values": False,
            "exponential_decay_length_penalty": None,
            "suppress_tokens": None,
            "begin_suppress_tokens": None,
            "architectures": None,
            "finetuning_task": None,
            "id2label": {0: "LABEL_0", 1: "LABEL_1"},
            "label2id": {"LABEL_0": 0, "LABEL_1": 1},
            "tokenizer_class": None,
            "prefix": None,
            "bos_token_id": None,
            "pad_token_id": None,
            "eos_token_id": None,
            "sep_token_id": None,
            "decoder_start_token_id": None,
            "task_specific_params": None,
            "problem_type": None,
            "_name_or_path": "",
            "_commit_hash": None,
            "_attn_implementation_internal": None,
            "transformers_version": None,
            "hidden_size": hidden_size,
            "num_layers": num_layers,
            "path_norm": True,
            "use_2d_embeddings": False,
            "image_size": [420, 420],
            "patch_size": 4,
            "num_channels": in_channels,
            "embed_dim": 128,
            "depths": [2, 2, 14, 2],
            "num_heads": num_heads,
            "window_size": 5,
            "mlp_ratio": 4.0,
            "qkv_bias": True,
            "hidden_dropout_prob": 0.0,
            "attention_probs_dropout_prob": 0.0,
            "drop_path_rate": 0.1,
            "hidden_act": "gelu",
            "use_absolute_embeddings": False,
            "layer_norm_eps": 1e-05,
            "initializer_range": 0.02,
            "is_export": is_export,
        }
        config = DonutSwinConfig(**donut_swin_config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
        self.pooler = nn.AdaptiveAvgPool1D(1) if add_pooling_layer else None
        self.out_channels = hidden_size
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def forward(
        self,
        input_data=None,
        bool_masked_pos=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
        bool_masked_pos (`paddle.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        if self.training:
            pixel_values, label, attention_mask = input_data
        else:
            if isinstance(input_data, list):
                pixel_values = input_data[0]
            else:
                pixel_values = input_data
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = torch.repeat_interleave(pixel_values, repeats=3, dim=1)
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))
        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose([0, 2, 1]))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        donut_swin_output = DonutSwinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
        if self.training:
            return donut_swin_output, label, attention_mask
        else:
            return donut_swin_output
\ No newline at end of file
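
Editorial aside: a hypothetical end-to-end call of the new backbone. This is a sketch of intended usage, not something verified against this commit; it assumes the Paddle-style names the port still carries (`nn.Conv2D`, `bias_attr=`, `nn.AdaptiveAvgPool1D`, the list-style `transpose(...)`) resolve to their PyTorch equivalents through the project's compatibility layer.

```python
import torch

# hypothetical: DonutSwinModel imported from rec_donut_swin above
model = DonutSwinModel(in_channels=3, hidden_size=1024, is_export=False)
model.eval()
with torch.no_grad():
    pixels = torch.randn(1, 3, 420, 420)  # matches the hard-coded image_size
    out = model(pixels)
# last_hidden_state: (1, tokens_after_3_patch_mergings, 1024)
print(out.last_hidden_state.shape)
```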
mineru/model/ocr/paddleocr2pytorch/pytorchocr/modeling/necks/rnn.py

@@ -9,28 +9,28 @@ class Im2Seq(nn.Module):
         super().__init__()
         self.out_channels = in_channels

+    # def forward(self, x):
+    #     B, C, H, W = x.shape
+    #     # assert H == 1
+    #     x = x.squeeze(dim=2)
+    #     # x = x.transpose([0, 2, 1])  # paddle (NTC)(batch, width, channels)
+    #     x = x.permute(0, 2, 1)
+    #     return x
+
     def forward(self, x):
         B, C, H, W = x.shape
-        # assert H == 1
-        x = x.squeeze(dim=2)
-        # x = x.transpose([0, 2, 1])  # paddle (NTC)(batch, width, channels)
-        x = x.permute(0, 2, 1)
+        # Handle the 4-D tensor: flatten the spatial dims into a sequence
+        if H == 1:
+            # Original logic, for the H == 1 case
+            x = x.squeeze(dim=2)
+            x = x.permute(0, 2, 1)  # (B, W, C)
+        else:
+            # Handle the case where H != 1
+            x = x.permute(0, 2, 3, 1)  # (B, H, W, C)
+            x = x.reshape(B, H * W, C)  # (B, H*W, C)
         return x
+
+    # def forward(self, x):
+    #     B, C, H, W = x.shape
+    #     # Handle the 4-D tensor: flatten the spatial dims into a sequence
+    #     if H == 1:
+    #         # Original logic, for the H == 1 case
+    #         x = x.squeeze(dim=2)
+    #         x = x.permute(0, 2, 1)  # (B, W, C)
+    #     else:
+    #         # Handle the case where H != 1
+    #         x = x.permute(0, 2, 3, 1)  # (B, H, W, C)
+    #         x = x.reshape(B, H * W, C)  # (B, H*W, C)
+    #
+    #     return x


 class EncoderWithRNN_(nn.Module):
     def __init__(self, in_channels, hidden_size):
         super(EncoderWithRNN_, self).__init__()
...
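
Editorial aside: a quick shape check of the new `Im2Seq` branch for `H != 1` (standalone tensors, no model needed):

```python
import torch

x = torch.randn(2, 64, 4, 80)                    # (B, C, H, W) with H != 1
seq = x.permute(0, 2, 3, 1).reshape(2, 4 * 80, 64)  # row-major flatten of H x W
print(seq.shape)                                 # torch.Size([2, 320, 64])
```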
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/db_postprocess.py

@@ -124,10 +124,10 @@ class DBPostProcess(object):
         '''
         h, w = bitmap.shape[:2]
         box = _box.copy()
-        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int64), 0, w - 1)
+        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, w - 1)
-        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int64), 0, w - 1)
+        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, w - 1)
-        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int64), 0, h - 1)
+        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, h - 1)
-        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int64), 0, h - 1)
+        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int if 'int' in np.__dict__ else np.int32), 0, h - 1)

         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         box[:, 0] = box[:, 0] - xmin
...
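
Editorial aside: the `astype` fallback above targets NumPy ≥ 1.24, where the deprecated `np.int` alias was removed; the conditional only evaluates `np.int` when the alias still exists, so it is safe on both old and new NumPy:

```python
import numpy as np

# np.int is only touched when 'int' is still in np.__dict__ (NumPy < 1.24)
int_dtype = np.int if 'int' in np.__dict__ else np.int32
print(np.floor(3.7).astype(int_dtype))  # 3
```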
mineru/model/ocr/paddleocr2pytorch/pytorchocr/postprocess/rec_postprocess.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import numpy as np
 import torch
@@ -24,8 +25,9 @@ class BaseRecLabelDecode(object):
         self.beg_str = "sos"
         self.end_str = "eos"
+        self.reverse = False
         self.character_str = []
         if character_dict_path is None:
             self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
             dict_character = list(self.character_str)
@@ -38,6 +40,8 @@ class BaseRecLabelDecode(object):
             if use_space_char:
                 self.character_str.append(" ")
             dict_character = list(self.character_str)
+            if "arabic" in character_dict_path:
+                self.reverse = True
         dict_character = self.add_special_char(dict_character)
         self.dict = {}
@@ -45,10 +49,98 @@ class BaseRecLabelDecode(object):
             self.dict[char] = i
         self.character = dict_character

+    def pred_reverse(self, pred):
+        pred_re = []
+        c_current = ""
+        for c in pred:
+            if not bool(re.search("[a-zA-Z0-9 :*./%+-]", c)):
+                if c_current != "":
+                    pred_re.append(c_current)
+                pred_re.append(c)
+                c_current = ""
+            else:
+                c_current += c
+        if c_current != "":
+            pred_re.append(c_current)
+
+        return "".join(pred_re[::-1])
+
     def add_special_char(self, dict_character):
         return dict_character

-    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+    def get_word_info(self, text, selection):
+        """
+        Group the decoded characters and record the corresponding decoded positions.
+
+        Args:
+            text: the decoded text
+            selection: the bool array that identifies which columns of features are decoded as non-separated characters
+        Returns:
+            word_list: list of the grouped words
+            word_col_list: list of decoding positions corresponding to each character in the grouped word
+            state_list: list of marker to identify the type of grouping words, including two types of grouping words:
+                - 'cn': continuous chinese characters (e.g., 你好啊)
+                - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
+            The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
+        """
+        state = None
+        word_content = []
+        word_col_content = []
+        word_list = []
+        word_col_list = []
+        state_list = []
+        valid_col = np.where(selection == True)[0]
+
+        for c_i, char in enumerate(text):
+            if "\u4e00" <= char <= "\u9fff":
+                c_state = "cn"
+            elif bool(re.search("[a-zA-Z0-9]", char)):
+                c_state = "en&num"
+            else:
+                c_state = "splitter"
+
+            if (
+                char == "."
+                and state == "en&num"
+                and c_i + 1 < len(text)
+                and bool(re.search("[0-9]", text[c_i + 1]))
+            ):  # grouping floating number
+                c_state = "en&num"
+            if (
+                char == "-" and state == "en&num"
+            ):  # grouping word with '-', such as 'state-of-the-art'
+                c_state = "en&num"
+
+            if state == None:
+                state = c_state
+
+            if state != c_state:
+                if len(word_content) != 0:
+                    word_list.append(word_content)
+                    word_col_list.append(word_col_content)
+                    state_list.append(state)
+                    word_content = []
+                    word_col_content = []
+                state = c_state
+
+            if state != "splitter":
+                word_content.append(char)
+                word_col_content.append(valid_col[c_i])
+
+        if len(word_content) != 0:
+            word_list.append(word_content)
+            word_col_list.append(word_col_content)
+            state_list.append(state)
+
+        return word_list, word_col_list, state_list
+
+    def decode(
+        self,
+        text_index,
+        text_prob=None,
+        is_remove_duplicate=False,
+        return_word_box=False,
+    ):
         """ convert text-index into text-label. """
         result_list = []
         ignored_tokens = self.get_ignored_tokens()
@@ -88,12 +180,22 @@ class CTCLabelDecode(BaseRecLabelDecode):
         super(CTCLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

-    def __call__(self, preds, label=None, *args, **kwargs):
+    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
         if isinstance(preds, torch.Tensor):
             preds = preds.numpy()
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
-        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        text = self.decode(
+            preds_idx,
+            preds_prob,
+            is_remove_duplicate=True,
+            return_word_box=return_word_box,
+        )
+        if return_word_box:
+            for rec_idx, rec in enumerate(text):
+                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
+                max_wh_ratio = kwargs["max_wh_ratio"]
+                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
         if label is None:
             return text
...
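
Editorial aside: an illustrative call of the new `get_word_info` grouping helper. This assumes `BaseRecLabelDecode` can be constructed with its defaults (no dict file, built-in lowercase alphabet); the method itself only looks at `text` and `selection`, so the expected grouping follows directly from the state machine above:

```python
import numpy as np

# hypothetical: BaseRecLabelDecode imported from rec_postprocess
decoder = BaseRecLabelDecode()                 # built-in "0-9a-z" character set
text = "VGG-16 你好"
selection = np.ones(len(text), dtype=bool)     # pretend every column decoded a char
word_list, word_col_list, state_list = decoder.get_word_info(text, selection)
print(word_list)   # [['V', 'G', 'G', '-', '1', '6'], ['你', '好']]
print(state_list)  # ['en&num', 'cn'] — the space acts as a splitter
```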