wangsen / paddle_dbnet · Commits · f6532a0e

Unverified commit f6532a0e, authored Apr 26, 2022 by andyjpaddle; committed by GitHub, Apr 26, 2022.

add ppocrv3 rec (#6033)

* add ppocrv3 rec

parent 6902d160

Changes: 30 files in the commit. Showing 20 changed files with 1251 additions and 28 deletions (+1251 −28) on this page; the remaining files are on the next page.
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml               +131  −0
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml  +205  −0
ppocr/data/imaug/__init__.py                              +1    −1
ppocr/data/imaug/label_ops.py                             +32   −0
ppocr/data/imaug/rec_img_aug.py                           +53   −5
ppocr/data/simple_dataset.py                              +7    −3
ppocr/losses/__init__.py                                  +2    −1
ppocr/losses/basic_loss.py                                +2    −2
ppocr/losses/combined_loss.py                             +2    −0
ppocr/losses/distillation_loss.py                         +55   −3
ppocr/losses/rec_multi_loss.py                            +58   −0
ppocr/losses/rec_sar_loss.py                              +2    −1
ppocr/metrics/rec_metric.py                               +9    −3
ppocr/modeling/architectures/base_model.py                +5    −1
ppocr/modeling/architectures/distillation_model.py        +2    −2
ppocr/modeling/backbones/__init__.py                      +3    −1
ppocr/modeling/backbones/rec_mv1_enhance.py               +11   −4
ppocr/modeling/backbones/rec_svtrnet.py                   +595  −0
ppocr/modeling/heads/__init__.py                          +3    −1
ppocr/modeling/heads/rec_multi_head.py                    +73   −0
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec.yml  (new file, 0 → 100644, view file @ f6532a0e)

Global:
  debug: false
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05

Architecture:
  model_type: rec
  algorithm: SVTR
  Transform:
  Backbone:
    name: MobileNetV1Enhance
    scale: 0.5
    last_conv_stride: [1, 2]
    last_pool_type: avg
  Head:
    name: MultiHead
    head_list:
      - CTCHead:
          Neck:
            name: svtr
            dims: 64
            depth: 2
            hidden_dims: 120
            use_guide: True
          Head:
            fc_decay: 0.00001
      - SARHead:
          enc_dim: 512
          max_text_length: *max_text_length

Loss:
  name: MultiLoss
  loss_config_list:
    - CTCLoss:
    - SARLoss:

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc
  ignore_space: True

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
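A note on the `max_text_length: &max_text_length 25` entry above: it is a YAML anchor, and the SARHead block reuses the same value via the `*max_text_length` alias, so the head length and the label-encoding length cannot drift apart. A minimal sketch (plain PyYAML, outside the repo) of how that resolves:

import yaml

# `&max_text_length` names the scalar 25; `*max_text_length` reuses it.
snippet = """
Global:
  max_text_length: &max_text_length 25
SARHead:
  max_text_length: *max_text_length
"""
cfg = yaml.safe_load(snippet)
assert cfg["SARHead"]["max_text_length"] == cfg["Global"]["max_text_length"] == 25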
configs/rec/ch_PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml  (new file, 0 → 100644, view file @ f6532a0e)

Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [700, 800]
    values: [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05

Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"
  ignore_space: True

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
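The two DistillationDMLLoss entries above configure mutual learning between Student and Teacher on the CTC and SAR heads. A minimal sketch (plain paddle ops, not the repo's DMLLoss class) of the term they compute: softmax both outputs, then average the two KL divergences.

import paddle
import paddle.nn.functional as F

def dml(out1, out2):
    # act: "softmax", use_log: true; the 1e-10 mirrors the basic_loss.py change below
    p1 = F.softmax(out1, axis=-1) + 1e-10
    p2 = F.softmax(out2, axis=-1) + 1e-10
    # symmetric KL: each model is pulled toward the other's distribution
    return (F.kl_div(paddle.log(p1), p2, reduction='mean') +
            F.kl_div(paddle.log(p2), p1, reduction='mean')) / 2.0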
ppocr/data/imaug/__init__.py  (view file @ f6532a0e)

@@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, \
+from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
     SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
ppocr/data/imaug/label_ops.py  (view file @ f6532a0e)

@@ -22,6 +22,7 @@ import numpy as np
 import string
 from shapely.geometry import LineString, Point, Polygon
 import json
+import copy
 from ppocr.utils.logging import get_logger

@@ -1007,3 +1008,34 @@ class VQATokenLabelEncode(object):
             gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                             (len(encode_res["input_ids"]) - 1))
         return gt_label
+
+
+class MultiLabelEncode(BaseRecLabelEncode):
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(MultiLabelEncode, self).__init__(max_text_length,
+                                               character_dict_path,
+                                               use_space_char)
+
+        self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
+                                         use_space_char, **kwargs)
+        self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
+                                         use_space_char, **kwargs)
+
+    def __call__(self, data):
+        data_ctc = copy.deepcopy(data)
+        data_sar = copy.deepcopy(data)
+        data_out = dict()
+        data_out['img_path'] = data.get('img_path', None)
+        data_out['image'] = data['image']
+        ctc = self.ctc_encode.__call__(data_ctc)
+        sar = self.sar_encode.__call__(data_sar)
+        if ctc is None or sar is None:
+            return None
+        data_out['label_ctc'] = ctc['label']
+        data_out['label_sar'] = sar['label']
+        data_out['length'] = ctc['length']
+        return data_out
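The new MultiLabelEncode wraps CTCLabelEncode and SARLabelEncode so one sample yields both label formats. A minimal sketch (pure Python, with hypothetical stand-in encoders) of the deep-copy pattern it uses, so each encoder can mutate its own copy of the sample:

import copy

def encode_ctc(d):  # hypothetical stand-in for CTCLabelEncode
    d['label'] = ('ctc', d['label'])
    return d

def encode_sar(d):  # hypothetical stand-in for SARLabelEncode
    d['label'] = ('sar', d['label'])
    return d

data = {'image': 'pixels', 'label': 'text'}
out = {'image': data['image'],
       'label_ctc': encode_ctc(copy.deepcopy(data))['label'],
       'label_sar': encode_sar(copy.deepcopy(data))['label']}
assert data['label'] == 'text'  # the original sample is left untouched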
ppocr/data/imaug/rec_img_aug.py  (view file @ f6532a0e)

@@ -32,6 +32,49 @@ class RecAug(object):
         return data


+class RecConAug(object):
+    def __init__(self,
+                 prob=0.5,
+                 image_shape=(32, 320, 3),
+                 max_text_length=25,
+                 ext_data_num=1,
+                 **kwargs):
+        self.ext_data_num = ext_data_num
+        self.prob = prob
+        self.max_text_length = max_text_length
+        self.image_shape = image_shape
+        self.max_wh_ratio = self.image_shape[1] / self.image_shape[0]
+
+    def merge_ext_data(self, data, ext_data):
+        ori_w = round(data['image'].shape[1] / data['image'].shape[0] *
+                      self.image_shape[0])
+        ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] *
+                      self.image_shape[0])
+        data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0]))
+        ext_data['image'] = cv2.resize(ext_data['image'],
+                                       (ext_w, self.image_shape[0]))
+        data['image'] = np.concatenate(
+            [data['image'], ext_data['image']], axis=1)
+        data["label"] += ext_data["label"]
+        return data
+
+    def __call__(self, data):
+        rnd_num = random.random()
+        if rnd_num > self.prob:
+            return data
+        for idx, ext_data in enumerate(data["ext_data"]):
+            if len(data["label"]) + len(ext_data[
+                    "label"]) > self.max_text_length:
+                break
+            concat_ratio = data['image'].shape[1] / data['image'].shape[
+                0] + ext_data['image'].shape[1] / ext_data['image'].shape[0]
+            if concat_ratio > self.max_wh_ratio:
+                break
+            data = self.merge_ext_data(data, ext_data)
+        data.pop("ext_data")
+        return data
+
+
 class ClsResizeImg(object):
     def __init__(self, image_shape, **kwargs):
         self.image_shape = image_shape

@@ -98,10 +141,13 @@ class RecResizeImg(object):
     def __call__(self, data):
         img = data['image']
         if self.infer_mode and self.character_dict_path is not None:
-            norm_img = resize_norm_img_chinese(img, self.image_shape)
+            norm_img, valid_ratio = resize_norm_img_chinese(img,
+                                                            self.image_shape)
         else:
-            norm_img = resize_norm_img(img, self.image_shape, self.padding)
+            norm_img, valid_ratio = resize_norm_img(img, self.image_shape,
+                                                    self.padding)
         data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
         return data

@@ -220,7 +266,8 @@ def resize_norm_img(img, image_shape, padding=True):
     resized_image /= 0.5
     padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
     padding_im[:, :, 0:resized_w] = resized_image
-    return padding_im
+    valid_ratio = min(1.0, float(resized_w / imgW))
+    return padding_im, valid_ratio

 def resize_norm_img_chinese(img, image_shape):

@@ -230,7 +277,7 @@ def resize_norm_img_chinese(img, image_shape):
     h, w = img.shape[0], img.shape[1]
     ratio = w * 1.0 / h
     max_wh_ratio = max(max_wh_ratio, ratio)
-    imgW = int(32 * max_wh_ratio)
+    imgW = int(imgH * max_wh_ratio)
     if math.ceil(imgH * ratio) > imgW:
         resized_w = imgW
     else:

@@ -246,7 +293,8 @@ def resize_norm_img_chinese(img, image_shape):
     resized_image /= 0.5
     padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
     padding_im[:, :, 0:resized_w] = resized_image
-    return padding_im
+    valid_ratio = min(1.0, float(resized_w / imgW))
+    return padding_im, valid_ratio

 def resize_norm_img_srn(img, image_shape):
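Both resize helpers now return a valid_ratio alongside the padded image. A quick numeric illustration (numpy only, hypothetical sizes) of what it measures: the fraction of the padded canvas width actually covered by the resized text, which the SAR branch can use to ignore the zero-padded region.

import numpy as np

imgC, imgH, imgW = 3, 48, 320  # shape from the v3 rec configs above
resized_w = 200                # hypothetical width after aspect-preserving resize
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
# padding_im[:, :, 0:resized_w] would hold the resized image; the rest stays zero
valid_ratio = min(1.0, float(resized_w / imgW))
print(valid_ratio)             # 0.625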
ppocr/data/simple_dataset.py  (view file @ f6532a0e)

@@ -49,7 +49,8 @@ class SimpleDataSet(Dataset):
         if self.mode == "train" and self.do_shuffle:
             self.shuffle_data_random()
         self.ops = create_operators(dataset_config['transforms'], global_config)
+        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
+                                                       2)
         self.need_reset = True in [x < 1 for x in ratio_list]

     def get_image_info_list(self, file_list, ratio_list):

@@ -87,7 +88,7 @@ class SimpleDataSet(Dataset):
             if hasattr(op, 'ext_data_num'):
                 ext_data_num = getattr(op, 'ext_data_num')
                 break
-        load_data_ops = self.ops[:2]
+        load_data_ops = self.ops[:self.ext_op_transform_idx]
         ext_data = []

         while len(ext_data) < ext_data_num:

@@ -108,7 +109,10 @@ class SimpleDataSet(Dataset):
             data['image'] = img
             data = transform(data, load_data_ops)
-            if data is None or data['polys'].shape[1] != 4:
-                continue
+            if data is None:
+                continue
+            if 'polys' in data.keys():
+                if data['polys'].shape[1] != 4:
+                    continue
             ext_data.append(data)
         return ext_data
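The new ext_op_transform_idx hook replaces the hard-coded self.ops[:2]. A minimal sketch (pure Python) of what it controls: only the first N transforms are applied when loading the extra samples that RecConAug concatenates, and the v3 rec configs set it to 1 so only DecodeImage runs on them.

# transform pipeline from the ch_PP-OCRv3_rec.yml Train section above
ops = ['DecodeImage', 'RecConAug', 'RecAug', 'MultiLabelEncode',
       'RecResizeImg', 'KeepKeys']
ext_op_transform_idx = 1
load_data_ops = ops[:ext_op_transform_idx]
assert load_data_ops == ['DecodeImage']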
ppocr/losses/__init__.py  (view file @ f6532a0e)

@@ -34,6 +34,7 @@ from .rec_nrtr_loss import NRTRLoss
 from .rec_sar_loss import SARLoss
 from .rec_aster_loss import AsterLoss
 from .rec_pren_loss import PRENLoss
+from .rec_multi_loss import MultiLoss

 # cls loss
 from .cls_loss import ClsLoss

@@ -60,7 +61,7 @@ def build_loss(config):
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
-        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss'
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
ppocr/losses/basic_loss.py  (view file @ f6532a0e)

@@ -106,8 +106,8 @@ class DMLLoss(nn.Layer):
     def forward(self, out1, out2):
         if self.act is not None:
-            out1 = self.act(out1)
-            out2 = self.act(out2)
+            out1 = self.act(out1) + 1e-10
+            out2 = self.act(out2) + 1e-10
         if self.use_log:
             # for recognition distillation, log is needed for feature map
             log_out1 = paddle.log(out1)
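A quick numeric illustration (numpy stand-in, not the repo code) of why the 1e-10 epsilon was added: a softmax output can contain exact zeros, log(0) is -inf, and that poisons the DML loss; the epsilon keeps the logarithm finite.

import numpy as np

p = np.array([1.0, 0.0])   # hypothetical post-softmax probabilities
print(np.log(p))           # [0., -inf]: infinities propagate into the loss
print(np.log(p + 1e-10))   # finite everywhere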
ppocr/losses/combined_loss.py  (view file @ f6532a0e)

@@ -18,8 +18,10 @@ import paddle.nn as nn
 from .rec_ctc_loss import CTCLoss
 from .center_loss import CenterLoss
 from .ace_loss import ACELoss
+from .rec_sar_loss import SARLoss

 from .distillation_loss import DistillationCTCLoss
+from .distillation_loss import DistillationSARLoss
 from .distillation_loss import DistillationDMLLoss
 from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
ppocr/losses/distillation_loss.py  (view file @ f6532a0e)

@@ -18,6 +18,7 @@ import numpy as np
 import cv2

 from .rec_ctc_loss import CTCLoss
+from .rec_sar_loss import SARLoss
 from .basic_loss import DMLLoss
 from .basic_loss import DistanceLoss
 from .det_db_loss import DBLoss

@@ -46,11 +47,15 @@ class DistillationDMLLoss(DMLLoss):
                  act=None,
                  use_log=False,
                  key=None,
+                 multi_head=False,
+                 dis_head='ctc',
                  maps_name=None,
                  name="dml"):
         super().__init__(act=act, use_log=use_log)
         assert isinstance(model_name_pairs, list)
         self.key = key
+        self.multi_head = multi_head
+        self.dis_head = dis_head
         self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
         self.name = name
         self.maps_name = self._check_maps_name(maps_name)

@@ -97,6 +102,10 @@ class DistillationDMLLoss(DMLLoss):
                 out2 = out2[self.key]

             if self.maps_name is None:
-                loss = super().forward(out1, out2)
+                if self.multi_head:
+                    loss = super().forward(out1[self.dis_head],
+                                           out2[self.dis_head])
+                else:
+                    loss = super().forward(out1, out2)
                 if isinstance(loss, dict):
                     for key in loss:

@@ -123,11 +132,50 @@ class DistillationDMLLoss(DMLLoss):
 class DistillationCTCLoss(CTCLoss):
-    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
+    def __init__(self,
+                 model_name_list=[],
+                 key=None,
+                 multi_head=False,
+                 name="loss_ctc"):
         super().__init__()
         self.model_name_list = model_name_list
         self.key = key
         self.name = name
+        self.multi_head = multi_head
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, model_name in enumerate(self.model_name_list):
+            out = predicts[model_name]
+            if self.key is not None:
+                out = out[self.key]
+            if self.multi_head:
+                assert 'ctc' in out, 'multi head has multi out'
+                loss = super().forward(out['ctc'], batch[:2] + batch[3:])
+            else:
+                loss = super().forward(out, batch)
+            if isinstance(loss, dict):
+                for key in loss:
+                    loss_dict["{}_{}_{}".format(self.name, model_name,
+                                                idx)] = loss[key]
+            else:
+                loss_dict["{}_{}".format(self.name, model_name)] = loss
+        return loss_dict
+
+
+class DistillationSARLoss(SARLoss):
+    def __init__(self,
+                 model_name_list=[],
+                 key=None,
+                 multi_head=False,
+                 name="loss_sar",
+                 **kwargs):
+        ignore_index = kwargs.get('ignore_index', 92)
+        super().__init__(ignore_index=ignore_index)
+        self.model_name_list = model_name_list
+        self.key = key
+        self.name = name
+        self.multi_head = multi_head

     def forward(self, predicts, batch):
         loss_dict = dict()

@@ -135,6 +183,10 @@ class DistillationCTCLoss(CTCLoss):
             out = predicts[model_name]
             if self.key is not None:
                 out = out[self.key]
-            loss = super().forward(out, batch)
+            if self.multi_head:
+                assert 'sar' in out, 'multi head has multi out'
+                loss = super().forward(out['sar'], batch[:1] + batch[2:])
+            else:
+                loss = super().forward(out, batch)
             if isinstance(loss, dict):
                 for key in loss:
ppocr/losses/rec_multi_loss.py  (new file, 0 → 100644, view file @ f6532a0e)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn

from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss


class MultiLoss(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.loss_funcs = {}
        self.loss_list = kwargs.pop('loss_config_list')
        self.weight_1 = kwargs.get('weight_1', 1.0)
        self.weight_2 = kwargs.get('weight_2', 1.0)
        self.gtc_loss = kwargs.get('gtc_loss', 'sar')
        for loss_info in self.loss_list:
            for name, param in loss_info.items():
                if param is not None:
                    kwargs.update(param)
                loss = eval(name)(**kwargs)
                self.loss_funcs[name] = loss

    def forward(self, predicts, batch):
        self.total_loss = {}
        total_loss = 0.0
        # batch [image, label_ctc, label_sar, length, valid_ratio]
        for name, loss_func in self.loss_funcs.items():
            if name == 'CTCLoss':
                loss = loss_func(predicts['ctc'],
                                 batch[:2] + batch[3:])['loss'] * self.weight_1
            elif name == 'SARLoss':
                loss = loss_func(predicts['sar'],
                                 batch[:1] + batch[2:])['loss'] * self.weight_2
            else:
                raise NotImplementedError(
                    '{} is not supported in MultiLoss yet'.format(name))
            self.total_loss[name] = loss
            total_loss += loss
        self.total_loss['loss'] = total_loss
        return self.total_loss
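The slicing in MultiLoss.forward follows the KeepKeys order set in the configs. A minimal sketch (pure Python) of the index arithmetic: each branch drops the other branch's label from the collated batch.

# collated batch order per KeepKeys: [image, label_ctc, label_sar, length, valid_ratio]
batch = ['image', 'label_ctc', 'label_sar', 'length', 'valid_ratio']

ctc_batch = batch[:2] + batch[3:]   # drop label_sar
sar_batch = batch[:1] + batch[2:]   # drop label_ctc

assert ctc_batch == ['image', 'label_ctc', 'length', 'valid_ratio']
assert sar_batch == ['image', 'label_sar', 'length', 'valid_ratio']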
ppocr/losses/rec_sar_loss.py  (view file @ f6532a0e)

@@ -9,8 +9,9 @@ from paddle import nn
 class SARLoss(nn.Layer):
     def __init__(self, **kwargs):
         super(SARLoss, self).__init__()
+        ignore_index = kwargs.get('ignore_index', 92)  # 6626
         self.loss_func = paddle.nn.loss.CrossEntropyLoss(
-            reduction="mean", ignore_index=92)
+            reduction="mean", ignore_index=ignore_index)

     def forward(self, predicts, batch):
ppocr/metrics/rec_metric.py  (view file @ f6532a0e)

@@ -17,9 +17,14 @@ import string
 class RecMetric(object):
-    def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
+    def __init__(self,
+                 main_indicator='acc',
+                 is_filter=False,
+                 ignore_space=True,
+                 **kwargs):
         self.main_indicator = main_indicator
         self.is_filter = is_filter
+        self.ignore_space = ignore_space
         self.eps = 1e-5
         self.reset()

@@ -34,6 +39,7 @@ class RecMetric(object):
         all_num = 0
         norm_edit_dis = 0.0
         for (pred, pred_conf), (target, _) in zip(preds, labels):
-            pred = pred.replace(" ", "")
-            target = target.replace(" ", "")
+            if self.ignore_space:
+                pred = pred.replace(" ", "")
+                target = target.replace(" ", "")
             if self.is_filter:
ppocr/modeling/architectures/base_model.py  (view file @ f6532a0e)

@@ -83,7 +83,11 @@ class BaseModel(nn.Layer):
             y["neck_out"] = x
         if self.use_head:
             x = self.head(x, targets=data)
-        if isinstance(x, dict):
+        # for multi head, save ctc neck out for udml
+        if isinstance(x, dict) and 'ctc_neck' in x.keys():
+            y["neck_out"] = x["ctc_neck"]
+            y["head_out"] = x
+        elif isinstance(x, dict):
             y.update(x)
         else:
             y["head_out"] = x
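A minimal sketch (pure Python, hypothetical values) of the new routing in BaseModel.forward: when a multi-head returns a dict containing 'ctc_neck', that tensor is re-exported as neck_out (the source comment says it is saved for UDML) and the whole dict becomes head_out.

x = {'ctc': 'ctc_logits', 'sar': 'sar_logits', 'ctc_neck': 'shared_feature'}
y = {}
if isinstance(x, dict) and 'ctc_neck' in x.keys():
    y['neck_out'] = x['ctc_neck']   # CTC neck feature, kept for feature distillation
    y['head_out'] = x
elif isinstance(x, dict):
    y.update(x)
else:
    y['head_out'] = x
assert y['neck_out'] == 'shared_feature'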
ppocr/modeling/architectures/distillation_model.py  (view file @ f6532a0e)

@@ -53,8 +53,8 @@ class DistillationModel(nn.Layer):
             self.model_list.append(self.add_sublayer(key, model))
             self.model_name_list.append(key)

-    def forward(self, x):
+    def forward(self, x, data=None):
         result_dict = dict()
         for idx, model_name in enumerate(self.model_name_list):
-            result_dict[model_name] = self.model_list[idx](x)
+            result_dict[model_name] = self.model_list[idx](x, data)
         return result_dict
ppocr/modeling/backbones/__init__.py  (view file @ f6532a0e)

@@ -31,9 +31,11 @@ def build_backbone(config, model_type):
         from .rec_resnet_aster import ResNet_ASTER
         from .rec_micronet import MicroNet
         from .rec_efficientb3_pren import EfficientNetb3_PREN
+        from .rec_svtrnet import SVTRNet
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
-            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN'
+            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
+            'SVTRNet'
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
ppocr/modeling/backbones/rec_mv1_enhance.py  (view file @ f6532a0e)

@@ -103,7 +103,12 @@ class DepthwiseSeparable(nn.Layer):
 class MobileNetV1Enhance(nn.Layer):
-    def __init__(self, in_channels=3, scale=0.5, **kwargs):
+    def __init__(self,
+                 in_channels=3,
+                 scale=0.5,
+                 last_conv_stride=1,
+                 last_pool_type='max',
+                 **kwargs):
         super().__init__()
         self.scale = scale
         self.block_list = []

@@ -200,7 +205,7 @@ class MobileNetV1Enhance(nn.Layer):
             num_filters1=1024,
             num_filters2=1024,
             num_groups=1024,
-            stride=1,
+            stride=last_conv_stride,
             dw_size=5,
             padding=2,
             use_se=True,

@@ -208,7 +213,9 @@ class MobileNetV1Enhance(nn.Layer):
         self.block_list.append(conv6)
         self.block_list = nn.Sequential(*self.block_list)
-        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        if last_pool_type == 'avg':
+            self.pool = nn.AvgPool2D(kernel_size=2, stride=2, padding=0)
+        else:
+            self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
         self.out_channels = int(1024 * scale)
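The configs pass last_conv_stride: [1, 2], an anisotropic stride that keeps feature height while halving width. A quick shape check (paddle only, hypothetical channel count) that such a stride is legal for nn.Conv2D:

import paddle
import paddle.nn as nn

conv = nn.Conv2D(8, 8, kernel_size=5, stride=[1, 2], padding=2)
x = paddle.randn([1, 8, 6, 40])
print(conv(x).shape)   # [1, 8, 6, 20]: height preserved, width halved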
ppocr/modeling/backbones/rec_svtrnet.py  (new file, 0 → 100644, view file @ f6532a0e)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import Callable
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal
import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal

trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 bias_attr=False,
                 groups=1,
                 act=nn.GELU):
        super().__init__()
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.KaimingUniform()),
            bias_attr=bias_attr)
        self.norm = nn.BatchNorm2D(out_channels)
        self.act = act()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out


class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Layer):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMixer(nn.Layer):
    def __init__(
            self,
            dim,
            num_heads=8,
            HW=[8, 25],
            local_k=[3, 3], ):
        super().__init__()
        self.HW = HW
        self.dim = dim
        self.local_mixer = nn.Conv2D(
            dim,
            dim,
            local_k,
            1, [local_k[0] // 2, local_k[1] // 2],
            groups=num_heads,
            weight_attr=ParamAttr(initializer=KaimingNormal()))

    def forward(self, x):
        h = self.HW[0]
        w = self.HW[1]
        x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
        x = self.local_mixer(x)
        x = x.flatten(2).transpose([0, 2, 1])
        return x


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 mixer='Global',
                 HW=[8, 25],
                 local_k=[7, 11],
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.HW = HW
        if HW is not None:
            H = HW[0]
            W = HW[1]
            self.N = H * W
            self.C = dim
        if mixer == 'Local' and HW is not None:
            hk = local_k[0]
            wk = local_k[1]
            mask = np.ones([H * W, H * W])
            for h in range(H):
                for w in range(W):
                    for kh in range(-(hk // 2), (hk // 2) + 1):
                        for kw in range(-(wk // 2), (wk // 2) + 1):
                            if H > (h + kh) >= 0 and W > (w + kw) >= 0:
                                mask[h * W + w][(h + kh) * W + (w + kw)] = 0
            mask_paddle = paddle.to_tensor(mask, dtype='float32')
            mask_inf = paddle.full([H * W, H * W], '-inf', dtype='float32')
            mask = paddle.where(mask_paddle < 1, mask_paddle, mask_inf)
            self.mask = mask.unsqueeze([0, 1])
        self.mixer = mixer

    def forward(self, x):
        if self.HW is not None:
            N = self.N
            C = self.C
        else:
            _, N, C = x.shape
        qkv = self.qkv(x).reshape((0, N, 3, self.num_heads, C //
                                   self.num_heads)).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = (q.matmul(k.transpose((0, 1, 3, 2))))
        if self.mixer == 'Local':
            attn += self.mask
        attn = nn.functional.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mixer='Global',
                 local_mixer=[7, 11],
                 HW=[8, 25],
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-6,
                 prenorm=True):
        super().__init__()
        if isinstance(norm_layer, str):
            self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
        elif isinstance(norm_layer, Callable):
            self.norm1 = norm_layer(dim)
        else:
            raise TypeError(
                "The norm_layer must be str or paddle.nn.layer.Layer class")
        if mixer == 'Global' or mixer == 'Local':
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                mixer=mixer,
                HW=HW,
                local_k=local_mixer,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop)
        elif mixer == 'Conv':
            self.mixer = ConvMixer(
                dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
        else:
            raise TypeError("The mixer must be one of [Global, Local, Conv]")

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        if isinstance(norm_layer, str):
            self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
        elif isinstance(norm_layer, Callable):
            self.norm2 = norm_layer(dim)
        else:
            raise TypeError(
                "The norm_layer must be str or paddle.nn.layer.Layer class")
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)
        self.prenorm = prenorm

    def forward(self, x):
        if self.prenorm:
            x = self.norm1(x + self.drop_path(self.mixer(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        else:
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=[32, 100],
                 in_channels=3,
                 embed_dim=768,
                 sub_num=2):
        super().__init__()
        num_patches = (img_size[1] // (2 ** sub_num)) * \
                      (img_size[0] // (2 ** sub_num))
        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.norm = None
        if sub_num == 2:
            self.proj = nn.Sequential(
                ConvBNLayer(
                    in_channels, embed_dim // 2, 3, 2, 1,
                    act=nn.GELU, bias_attr=None),
                ConvBNLayer(
                    embed_dim // 2, embed_dim, 3, 2, 1,
                    act=nn.GELU, bias_attr=None))
        if sub_num == 3:
            self.proj = nn.Sequential(
                ConvBNLayer(
                    in_channels, embed_dim // 4, 3, 2, 1,
                    act=nn.GELU, bias_attr=None),
                ConvBNLayer(
                    embed_dim // 4, embed_dim // 2, 3, 2, 1,
                    act=nn.GELU, bias_attr=None),
                ConvBNLayer(
                    embed_dim // 2, embed_dim, 3, 2, 1,
                    act=nn.GELU, bias_attr=None), )

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose((0, 2, 1))
        return x


class SubSample(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 types='Pool',
                 stride=[2, 1],
                 sub_norm='nn.LayerNorm',
                 act=None):
        super().__init__()
        self.types = types
        if types == 'Pool':
            self.avgpool = nn.AvgPool2D(
                kernel_size=[3, 5], stride=stride, padding=[1, 2])
            self.maxpool = nn.MaxPool2D(
                kernel_size=[3, 5], stride=stride, padding=[1, 2])
            self.proj = nn.Linear(in_channels, out_channels)
        else:
            self.conv = nn.Conv2D(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                weight_attr=ParamAttr(initializer=KaimingNormal()))
        self.norm = eval(sub_norm)(out_channels)
        if act is not None:
            self.act = act()
        else:
            self.act = None

    def forward(self, x):
        if self.types == 'Pool':
            x1 = self.avgpool(x)
            x2 = self.maxpool(x)
            x = (x1 + x2) * 0.5
            out = self.proj(x.flatten(2).transpose((0, 2, 1)))
        else:
            x = self.conv(x)
            out = x.flatten(2).transpose((0, 2, 1))
        out = self.norm(out)
        if self.act is not None:
            out = self.act(out)
        return out


class SVTRNet(nn.Layer):
    def __init__(
            self,
            img_size=[32, 100],
            in_channels=3,
            embed_dim=[64, 128, 256],
            depth=[3, 6, 3],
            num_heads=[2, 4, 8],
            mixer=['Local'] * 6 + ['Global'] * 6,  # Local atten, Global atten, Conv
            local_mixer=[[7, 11], [7, 11], [7, 11]],
            patch_merging='Conv',  # Conv, Pool, None
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            last_drop=0.1,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer='nn.LayerNorm',
            sub_norm='nn.LayerNorm',
            epsilon=1e-6,
            out_channels=192,
            out_char_num=25,
            block_unit='Block',
            act='nn.GELU',
            last_stage=True,
            sub_num=2,
            prenorm=True,
            use_lenhead=False,
            **kwargs):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = out_channels
        self.prenorm = prenorm
        patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            in_channels=in_channels,
            embed_dim=embed_dim[0],
            sub_num=sub_num)
        num_patches = self.patch_embed.num_patches
        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
        self.pos_embed = self.create_parameter(
            shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)
        self.pos_drop = nn.Dropout(p=drop_rate)
        Block_unit = eval(block_unit)

        dpr = np.linspace(0, drop_path_rate, sum(depth))
        self.blocks1 = nn.LayerList([
            Block_unit(
                dim=embed_dim[0],
                num_heads=num_heads[0],
                mixer=mixer[0:depth[0]][i],
                HW=self.HW,
                local_mixer=local_mixer[0],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=nn.Swish,
                attn_drop=attn_drop_rate,
                drop_path=dpr[0:depth[0]][i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth[0])
        ])
        if patch_merging is not None:
            self.sub_sample1 = SubSample(
                embed_dim[0],
                embed_dim[1],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging)
            HW = [self.HW[0] // 2, self.HW[1]]
        else:
            HW = self.HW
        self.patch_merging = patch_merging
        self.blocks2 = nn.LayerList([
            Block_unit(
                dim=embed_dim[1],
                num_heads=num_heads[1],
                mixer=mixer[depth[0]:depth[0] + depth[1]][i],
                HW=HW,
                local_mixer=local_mixer[1],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=eval(act),
                attn_drop=attn_drop_rate,
                drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth[1])
        ])
        if patch_merging is not None:
            self.sub_sample2 = SubSample(
                embed_dim[1],
                embed_dim[2],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging)
            HW = [self.HW[0] // 4, self.HW[1]]
        else:
            HW = self.HW
        self.blocks3 = nn.LayerList([
            Block_unit(
                dim=embed_dim[2],
                num_heads=num_heads[2],
                mixer=mixer[depth[0] + depth[1]:][i],
                HW=HW,
                local_mixer=local_mixer[2],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=eval(act),
                attn_drop=attn_drop_rate,
                drop_path=dpr[depth[0] + depth[1]:][i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth[2])
        ])
        self.last_stage = last_stage
        if last_stage:
            self.avg_pool = nn.AdaptiveAvgPool2D([1, out_char_num])
            self.last_conv = nn.Conv2D(
                in_channels=embed_dim[2],
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias_attr=False)
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
        if not prenorm:
            self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
        self.use_lenhead = use_lenhead
        if use_lenhead:
            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
            self.hardswish_len = nn.Hardswish()
            self.dropout_len = nn.Dropout(
                p=last_drop, mode="downscale_in_infer")

        trunc_normal_(self.pos_embed)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample1(
                x.transpose([0, 2, 1]).reshape(
                    [0, self.embed_dim[0], self.HW[0], self.HW[1]]))
        for blk in self.blocks2:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample2(
                x.transpose([0, 2, 1]).reshape(
                    [0, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
        for blk in self.blocks3:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.use_lenhead:
            len_x = self.len_conv(x.mean(1))
            len_x = self.dropout_len(self.hardswish_len(len_x))
        if self.last_stage:
            if self.patch_merging is not None:
                h = self.HW[0] // 4
            else:
                h = self.HW[0]
            x = self.avg_pool(
                x.transpose([0, 2, 1]).reshape(
                    [0, self.embed_dim[2], h, self.HW[1]]))
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        if self.use_lenhead:
            return x, len_x
        return x
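A quick shape check (a sketch using the defaults of the file above) of what the backbone emits when last_stage is enabled: one out_channels-dim feature per output character slot.

import paddle
from ppocr.modeling.backbones.rec_svtrnet import SVTRNet

model = SVTRNet()                  # defaults: img_size=[32, 100], out_channels=192, out_char_num=25
x = paddle.randn([1, 3, 32, 100])  # NCHW input matching img_size
feat = model(x)
print(feat.shape)                  # expected [1, 192, 1, 25]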
ppocr/modeling/heads/__init__.py  (view file @ f6532a0e)

@@ -32,6 +32,7 @@ def build_head(config):
     from .rec_sar_head import SARHead
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
+    from .rec_multi_head import MultiHead

     # cls head
     from .cls_head import ClsHead

@@ -44,7 +45,8 @@ def build_head(config):
     support_dict = [
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
-        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead'
+        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
+        'MultiHead'
     ]

     #table head
ppocr/modeling/heads/rec_multi_head.py  (new file, 0 → 100644, view file @ f6532a0e)

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F

from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead


class MultiHead(nn.Layer):
    def __init__(self, in_channels, out_channels_list, **kwargs):
        super().__init__()
        self.head_list = kwargs.pop('head_list')
        self.gtc_head = 'sar'
        assert len(self.head_list) >= 2
        for idx, head_name in enumerate(self.head_list):
            name = list(head_name)[0]
            if name == 'SARHead':
                # sar head
                sar_args = self.head_list[idx][name]
                self.sar_head = eval(name)(in_channels=in_channels, \
                    out_channels=out_channels_list['SARLabelDecode'], **sar_args)
            elif name == 'CTCHead':
                # ctc neck
                self.encoder_reshape = Im2Seq(in_channels)
                neck_args = self.head_list[idx][name]['Neck']
                encoder_type = neck_args.pop('name')
                self.encoder = encoder_type
                self.ctc_encoder = SequenceEncoder(in_channels=in_channels, \
                    encoder_type=encoder_type, **neck_args)
                # ctc head
                head_args = self.head_list[idx][name]['Head']
                self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels, \
                    out_channels=out_channels_list['CTCLabelDecode'], **head_args)
            else:
                raise NotImplementedError(
                    '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, targets=None):
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder, targets)
        head_out = dict()
        head_out['ctc'] = ctc_out
        head_out['ctc_neck'] = ctc_encoder
        # eval mode
        if not self.training:
            return ctc_out
        if self.gtc_head == 'sar':
            sar_out = self.sar_head(x, targets[1:])
            head_out['sar'] = sar_out
            return head_out
        else:
            return head_out
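A minimal sketch (self-contained, an illustrative stand-in rather than the repo class) of the self.training gate in MultiHead.forward: training returns a dict with both branches plus the CTC neck, while eval returns only the CTC logits that CTCLabelDecode consumes.

import paddle.nn as nn

class GatedHead(nn.Layer):       # stand-in, not the repo class
    def forward(self, ctc_out, sar_out):
        head_out = {'ctc': ctc_out, 'sar': sar_out}
        if not self.training:    # eval mode: CTC branch only
            return ctc_out
        return head_out

head = GatedHead()
head.eval()
assert head(1, 2) == 1
head.train()
assert head(1, 2) == {'ctc': 1, 'sar': 2}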