Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
f6532a0e
Unverified
Commit
f6532a0e
authored
Apr 26, 2022
by
andyjpaddle
Committed by
GitHub
Apr 26, 2022
Browse files
add ppocrv3 rec (#6033)
* add ppocrv3 rec
parent
6902d160
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
297 additions
and
27 deletions
+297
-27
ppocr/modeling/heads/rec_sar_head.py
ppocr/modeling/heads/rec_sar_head.py
+11
-3
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+106
-7
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-1
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+38
-0
tools/eval.py
tools/eval.py
+28
-4
tools/export_model.py
tools/export_model.py
+32
-2
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+4
-2
tools/infer_rec.py
tools/infer_rec.py
+22
-2
tools/program.py
tools/program.py
+14
-4
tools/train.py
tools/train.py
+40
-2
No files found.
ppocr/modeling/heads/rec_sar_head.py
View file @
f6532a0e
...
@@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder):
...
@@ -349,7 +349,10 @@ class ParallelSARDecoder(BaseDecoder):
class
SARHead
(
nn
.
Layer
):
class
SARHead
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
out_channels
,
out_channels
,
enc_dim
=
512
,
max_text_length
=
30
,
enc_bi_rnn
=
False
,
enc_bi_rnn
=
False
,
enc_drop_rnn
=
0.1
,
enc_drop_rnn
=
0.1
,
enc_gru
=
False
,
enc_gru
=
False
,
...
@@ -358,14 +361,17 @@ class SARHead(nn.Layer):
...
@@ -358,14 +361,17 @@ class SARHead(nn.Layer):
dec_gru
=
False
,
dec_gru
=
False
,
d_k
=
512
,
d_k
=
512
,
pred_dropout
=
0.1
,
pred_dropout
=
0.1
,
max_text_length
=
30
,
pred_concat
=
True
,
pred_concat
=
True
,
**
kwargs
):
**
kwargs
):
super
(
SARHead
,
self
).
__init__
()
super
(
SARHead
,
self
).
__init__
()
# encoder module
# encoder module
self
.
encoder
=
SAREncoder
(
self
.
encoder
=
SAREncoder
(
enc_bi_rnn
=
enc_bi_rnn
,
enc_drop_rnn
=
enc_drop_rnn
,
enc_gru
=
enc_gru
)
enc_bi_rnn
=
enc_bi_rnn
,
enc_drop_rnn
=
enc_drop_rnn
,
enc_gru
=
enc_gru
,
d_model
=
in_channels
,
d_enc
=
enc_dim
)
# decoder module
# decoder module
self
.
decoder
=
ParallelSARDecoder
(
self
.
decoder
=
ParallelSARDecoder
(
...
@@ -374,6 +380,8 @@ class SARHead(nn.Layer):
...
@@ -374,6 +380,8 @@ class SARHead(nn.Layer):
dec_bi_rnn
=
dec_bi_rnn
,
dec_bi_rnn
=
dec_bi_rnn
,
dec_drop_rnn
=
dec_drop_rnn
,
dec_drop_rnn
=
dec_drop_rnn
,
dec_gru
=
dec_gru
,
dec_gru
=
dec_gru
,
d_model
=
in_channels
,
d_enc
=
enc_dim
,
d_k
=
d_k
,
d_k
=
d_k
,
pred_dropout
=
pred_dropout
,
pred_dropout
=
pred_dropout
,
max_text_length
=
max_text_length
,
max_text_length
=
max_text_length
,
...
@@ -390,7 +398,7 @@ class SARHead(nn.Layer):
...
@@ -390,7 +398,7 @@ class SARHead(nn.Layer):
label
=
paddle
.
to_tensor
(
label
,
dtype
=
'int64'
)
label
=
paddle
.
to_tensor
(
label
,
dtype
=
'int64'
)
final_out
=
self
.
decoder
(
final_out
=
self
.
decoder
(
feat
,
holistic_feat
,
label
,
img_metas
=
targets
)
feat
,
holistic_feat
,
label
,
img_metas
=
targets
)
if
not
self
.
training
:
else
:
final_out
=
self
.
decoder
(
final_out
=
self
.
decoder
(
feat
,
feat
,
holistic_feat
,
holistic_feat
,
...
...
ppocr/modeling/necks/rnn.py
View file @
f6532a0e
...
@@ -16,9 +16,11 @@ from __future__ import absolute_import
...
@@ -16,9 +16,11 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
from
paddle
import
nn
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
from
ppocr.modeling.backbones.rec_svtrnet
import
Block
,
ConvBNLayer
,
trunc_normal_
,
zeros_
,
ones_
class
Im2Seq
(
nn
.
Layer
):
class
Im2Seq
(
nn
.
Layer
):
...
@@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer):
...
@@ -64,29 +66,126 @@ class EncoderWithFC(nn.Layer):
return
x
return
x
class
EncoderWithSVTR
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
dims
=
64
,
# XS
depth
=
2
,
hidden_dims
=
120
,
use_guide
=
False
,
num_heads
=
8
,
qkv_bias
=
True
,
mlp_ratio
=
2.0
,
drop_rate
=
0.1
,
attn_drop_rate
=
0.1
,
drop_path
=
0.
,
qk_scale
=
None
):
super
(
EncoderWithSVTR
,
self
).
__init__
()
self
.
depth
=
depth
self
.
use_guide
=
use_guide
self
.
conv1
=
ConvBNLayer
(
in_channels
,
in_channels
//
8
,
padding
=
1
,
act
=
nn
.
Swish
)
self
.
conv2
=
ConvBNLayer
(
in_channels
//
8
,
hidden_dims
,
kernel_size
=
1
,
act
=
nn
.
Swish
)
self
.
svtr_block
=
nn
.
LayerList
([
Block
(
dim
=
hidden_dims
,
num_heads
=
num_heads
,
mixer
=
'Global'
,
HW
=
None
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
act_layer
=
nn
.
Swish
,
attn_drop
=
attn_drop_rate
,
drop_path
=
drop_path
,
norm_layer
=
'nn.LayerNorm'
,
epsilon
=
1e-05
,
prenorm
=
False
)
for
i
in
range
(
depth
)
])
self
.
norm
=
nn
.
LayerNorm
(
hidden_dims
,
epsilon
=
1e-6
)
self
.
conv3
=
ConvBNLayer
(
hidden_dims
,
in_channels
,
kernel_size
=
1
,
act
=
nn
.
Swish
)
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self
.
conv4
=
ConvBNLayer
(
2
*
in_channels
,
in_channels
//
8
,
padding
=
1
,
act
=
nn
.
Swish
)
self
.
conv1x1
=
ConvBNLayer
(
in_channels
//
8
,
dims
,
kernel_size
=
1
,
act
=
nn
.
Swish
)
self
.
out_channels
=
dims
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
zeros_
(
m
.
bias
)
ones_
(
m
.
weight
)
def
forward
(
self
,
x
):
# for use guide
if
self
.
use_guide
:
z
=
x
.
clone
()
z
.
stop_gradient
=
True
else
:
z
=
x
# for short cut
h
=
z
# reduce dim
z
=
self
.
conv1
(
z
)
z
=
self
.
conv2
(
z
)
# SVTR global block
B
,
C
,
H
,
W
=
z
.
shape
z
=
z
.
flatten
(
2
).
transpose
([
0
,
2
,
1
])
for
blk
in
self
.
svtr_block
:
z
=
blk
(
z
)
z
=
self
.
norm
(
z
)
# last stage
z
=
z
.
reshape
([
0
,
H
,
W
,
C
]).
transpose
([
0
,
3
,
1
,
2
])
z
=
self
.
conv3
(
z
)
z
=
paddle
.
concat
((
h
,
z
),
axis
=
1
)
z
=
self
.
conv1x1
(
self
.
conv4
(
z
))
return
z
class
SequenceEncoder
(
nn
.
Layer
):
class
SequenceEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
super
(
SequenceEncoder
,
self
).
__init__
()
super
(
SequenceEncoder
,
self
).
__init__
()
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
self
.
encoder_type
=
encoder_type
if
encoder_type
==
'reshape'
:
if
encoder_type
==
'reshape'
:
self
.
only_reshape
=
True
self
.
only_reshape
=
True
else
:
else
:
support_encoder_dict
=
{
support_encoder_dict
=
{
'reshape'
:
Im2Seq
,
'reshape'
:
Im2Seq
,
'fc'
:
EncoderWithFC
,
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
'rnn'
:
EncoderWithRNN
,
'svtr'
:
EncoderWithSVTR
}
}
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
encoder_type
,
support_encoder_dict
.
keys
())
encoder_type
,
support_encoder_dict
.
keys
())
if
encoder_type
==
"svtr"
:
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
hidden_size
)
self
.
encoder_reshape
.
out_channels
,
**
kwargs
)
else
:
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
hidden_size
)
self
.
out_channels
=
self
.
encoder
.
out_channels
self
.
out_channels
=
self
.
encoder
.
out_channels
self
.
only_reshape
=
False
self
.
only_reshape
=
False
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
encoder_reshape
(
x
)
if
self
.
encoder_type
!=
'svtr'
:
if
not
self
.
only_reshape
:
x
=
self
.
encoder_reshape
(
x
)
if
not
self
.
only_reshape
:
x
=
self
.
encoder
(
x
)
return
x
else
:
x
=
self
.
encoder
(
x
)
x
=
self
.
encoder
(
x
)
return
x
x
=
self
.
encoder_reshape
(
x
)
return
x
ppocr/postprocess/__init__.py
View file @
f6532a0e
...
@@ -41,7 +41,8 @@ def build_post_process(config, global_config=None):
...
@@ -41,7 +41,8 @@ def build_post_process(config, global_config=None):
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
]
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
if
config
[
'name'
]
==
'PSEPostProcess'
:
...
...
ppocr/postprocess/rec_postprocess.py
View file @
f6532a0e
...
@@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
...
@@ -117,6 +117,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
use_space_char
=
False
,
use_space_char
=
False
,
model_name
=
[
"student"
],
model_name
=
[
"student"
],
key
=
None
,
key
=
None
,
multi_head
=
False
,
**
kwargs
):
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
use_space_char
)
...
@@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
...
@@ -125,6 +126,7 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
self
.
model_name
=
model_name
self
.
model_name
=
model_name
self
.
key
=
key
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
output
=
dict
()
...
@@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
...
@@ -132,6 +134,8 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
pred
=
preds
[
name
]
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'ctc'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
return
output
...
@@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode):
...
@@ -656,6 +660,40 @@ class SARLabelDecode(BaseRecLabelDecode):
return
[
self
.
padding_idx
]
return
[
self
.
padding_idx
]
class
DistillationSARLabelDecode
(
SARLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
multi_head
=
False
,
**
kwargs
):
super
(
DistillationSARLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
self
.
multi_head
=
multi_head
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
if
self
.
multi_head
and
isinstance
(
pred
,
dict
):
pred
=
pred
[
'sar'
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
PRENLabelDecode
(
BaseRecLabelDecode
):
class
PRENLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
...
...
tools/eval.py
View file @
f6532a0e
...
@@ -47,14 +47,38 @@ def main():
...
@@ -47,14 +47,38 @@ def main():
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels'
]
=
char_num
'name'
]
==
'MultiHead'
:
# for multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
extra_input
=
config
[
'Architecture'
][
extra_input_models
=
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SVTR"
]
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
if
config
[
'Architecture'
][
'algorithm'
]
==
'Distillation'
:
extra_input
=
config
[
'Architecture'
][
'Models'
][
'Teacher'
][
'algorithm'
]
in
extra_input_models
else
:
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
extra_input_models
if
"model_type"
in
config
[
'Architecture'
].
keys
():
if
"model_type"
in
config
[
'Architecture'
].
keys
():
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
else
:
...
...
tools/export_model.py
View file @
f6532a0e
...
@@ -55,6 +55,13 @@ def export_single_model(model, arch_config, save_path, logger):
...
@@ -55,6 +55,13 @@ def export_single_model(model, arch_config, save_path, logger):
shape
=
[
None
,
3
,
48
,
160
],
dtype
=
"float32"
),
shape
=
[
None
,
3
,
48
,
160
],
dtype
=
"float32"
),
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"SVTR"
:
if
arch_config
[
"Head"
][
"name"
]
==
'MultiHead'
:
other_shape
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
48
,
-
1
],
dtype
=
"float32"
),
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"PREN"
:
elif
arch_config
[
"algorithm"
]
==
"PREN"
:
other_shape
=
[
other_shape
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
...
@@ -105,13 +112,36 @@ def main():
...
@@ -105,13 +112,36 @@ def main():
if
config
[
"Architecture"
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
"Architecture"
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
"Architecture"
][
"Models"
]:
for
key
in
config
[
"Architecture"
][
"Models"
]:
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
if
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
"out_channels"
]
=
char_num
"name"
]
==
'MultiHead'
:
# multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
loss_list
=
config
[
'Loss'
][
'loss_config_list'
]
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
"Architecture"
][
"Models"
][
key
][
"Head"
][
"out_channels"
]
=
char_num
# just one final tensor needs to to exported for inference
# just one final tensor needs to to exported for inference
config
[
"Architecture"
][
"Models"
][
key
][
config
[
"Architecture"
][
"Models"
][
key
][
"return_all_feats"
]
=
False
"return_all_feats"
]
=
False
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# multi head
out_channels_list
=
{}
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
else
:
# base rec model
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
model
=
build_model
(
config
[
"Architecture"
])
model
=
build_model
(
config
[
"Architecture"
])
load_model
(
config
,
model
)
load_model
(
config
,
model
)
model
.
eval
()
model
.
eval
()
...
...
tools/infer/predict_rec.py
View file @
f6532a0e
...
@@ -107,7 +107,7 @@ class TextRecognizer(object):
...
@@ -107,7 +107,7 @@ class TextRecognizer(object):
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
assert
imgC
==
img
.
shape
[
2
]
assert
imgC
==
img
.
shape
[
2
]
imgW
=
int
((
32
*
max_wh_ratio
))
imgW
=
int
((
imgH
*
max_wh_ratio
))
if
self
.
use_onnx
:
if
self
.
use_onnx
:
w
=
self
.
input_tensor
.
shape
[
3
:][
0
]
w
=
self
.
input_tensor
.
shape
[
3
:][
0
]
if
w
is
not
None
and
w
>
0
:
if
w
is
not
None
and
w
>
0
:
...
@@ -255,7 +255,9 @@ class TextRecognizer(object):
...
@@ -255,7 +255,9 @@ class TextRecognizer(object):
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
norm_img_batch
=
[]
max_wh_ratio
=
0
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
max_wh_ratio
=
imgW
/
imgH
# max_wh_ratio = 0
for
ino
in
range
(
beg_img_no
,
end_img_no
):
for
ino
in
range
(
beg_img_no
,
end_img_no
):
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
h
,
w
=
img_list
[
indices
[
ino
]].
shape
[
0
:
2
]
wh_ratio
=
w
*
1.0
/
h
wh_ratio
=
w
*
1.0
/
h
...
...
tools/infer_rec.py
View file @
f6532a0e
...
@@ -51,8 +51,28 @@ def main():
...
@@ -51,8 +51,28 @@ def main():
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels'
]
=
char_num
'name'
]
==
'MultiHead'
:
# for multi head
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head loss
out_channels_list
=
{}
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
...
...
tools/program.py
View file @
f6532a0e
...
@@ -201,12 +201,17 @@ def train(config,
...
@@ -201,12 +201,17 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
extra_input
=
config
[
'Architecture'
][
extra_input_models
=
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SVTR"
]
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
if
config
[
'Architecture'
][
'algorithm'
]
==
'Distillation'
:
extra_input
=
config
[
'Architecture'
][
'Models'
][
'Teacher'
][
'algorithm'
]
in
extra_input_models
else
:
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
extra_input_models
try
:
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
except
:
model_type
=
None
model_type
=
None
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
start_epoch
=
best_model_dict
[
start_epoch
=
best_model_dict
[
...
@@ -269,7 +274,12 @@ def train(config,
...
@@ -269,7 +274,12 @@ def train(config,
if
model_type
in
[
'table'
,
'kie'
]:
if
model_type
in
[
'table'
,
'kie'
]:
eval_class
(
preds
,
batch
)
eval_class
(
preds
,
batch
)
else
:
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
if
config
[
'Loss'
][
'name'
]
in
[
'MultiLoss'
,
'MultiLoss_v2'
]:
# for multi head loss
post_result
=
post_process_class
(
preds
[
'ctc'
],
batch
[
1
])
# for CTC head out
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
eval_class
(
post_result
,
batch
)
metric
=
eval_class
.
get_metric
()
metric
=
eval_class
.
get_metric
()
train_stats
.
update
(
metric
)
train_stats
.
update
(
metric
)
...
@@ -541,7 +551,7 @@ def preprocess(is_train=False):
...
@@ -541,7 +551,7 @@ def preprocess(is_train=False):
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
,
'SVTR'
]
]
device
=
'cpu'
device
=
'cpu'
...
...
tools/train.py
View file @
f6532a0e
...
@@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer):
...
@@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer):
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels'
]
=
char_num
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
-
1
].
keys
())[
0
]
==
'DistillationSARLoss'
config
[
'Loss'
][
'loss_config_list'
][
-
1
][
'DistillationSARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
1
].
keys
())[
0
]
==
'SARLoss'
if
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
is
None
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
=
{
'ignore_index'
:
char_num
+
1
}
else
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
# for SAR model
config
[
'Loss'
][
'ignore_index'
]
=
char_num
-
1
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
if
config
[
'Global'
][
'distributed'
]:
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment