Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
4cddec73
Unverified
Commit
4cddec73
authored
Apr 27, 2022
by
littletomatodonkey
Committed by
GitHub
Apr 27, 2022
Browse files
add rotnet code (#6065)
* add rotnet code * add config * fix infer for ssl * rm unused code
parent
3c9200c6
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
283 additions
and
29 deletions
+283
-29
configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml
configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml
+99
-0
deploy/slim/quantization/export_model.py
deploy/slim/quantization/export_model.py
+48
-20
deploy/slim/quantization/quant.py
deploy/slim/quantization/quant.py
+40
-2
doc/doc_ch/layout_datasets.md
doc/doc_ch/layout_datasets.md
+1
-1
ppocr/data/__init__.py
ppocr/data/__init__.py
+1
-0
ppocr/data/collate_fn.py
ppocr/data/collate_fn.py
+14
-0
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+1
-0
ppocr/data/imaug/ssl_img_aug.py
ppocr/data/imaug/ssl_img_aug.py
+60
-0
ppocr/postprocess/cls_postprocess.py
ppocr/postprocess/cls_postprocess.py
+12
-3
tools/export_model.py
tools/export_model.py
+5
-3
tools/infer_cls.py
tools/infer_cls.py
+2
-0
No files found.
configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml
0 → 100644
View file @
4cddec73
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
100
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec_ppocr_v3_rotnet
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
pretrained_model
:
null
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
max_text_length
:
25
infer_mode
:
false
use_space_char
:
true
save_res_path
:
./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Cosine
learning_rate
:
0.001
regularizer
:
name
:
L2
factor
:
1.0e-05
Architecture
:
model_type
:
cls
algorithm
:
CLS
Transform
:
null
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
last_conv_stride
:
[
1
,
2
]
last_pool_type
:
avg
Neck
:
Head
:
name
:
ClsHead
class_dim
:
4
Loss
:
name
:
ClsLoss
main_indicator
:
acc
PostProcess
:
name
:
ClsPostProcess
Metric
:
name
:
ClsMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/train_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
RecAug
:
use_tia
:
False
-
RandAugment
:
-
SSLRotateResize
:
image_shape
:
[
3
,
48
,
320
]
-
KeepKeys
:
keep_keys
:
[
"
image"
,
"
label"
]
loader
:
collate_fn
:
"
SSLRotateCollate"
shuffle
:
true
batch_size_per_card
:
32
drop_last
:
true
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/val_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
SSLRotateResize
:
image_shape
:
[
3
,
48
,
320
]
-
KeepKeys
:
keep_keys
:
[
"
image"
,
"
label"
]
loader
:
collate_fn
:
"
SSLRotateCollate"
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
64
num_workers
:
8
profiler_options
:
null
deploy/slim/quantization/export_model.py
View file @
4cddec73
...
@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric
...
@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric
import
tools.program
as
program
import
tools.program
as
program
from
paddleslim.dygraph.quant
import
QAT
from
paddleslim.dygraph.quant
import
QAT
from
ppocr.data
import
build_dataloader
from
ppocr.data
import
build_dataloader
from
tools.export_model
import
export_single_model
def
export_single_model
(
quanter
,
model
,
infer_shape
,
save_path
,
logger
):
quanter
.
save_quantized_model
(
model
,
save_path
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
]
+
infer_shape
,
dtype
=
'float32'
)
])
logger
.
info
(
'inference QAT model is saved to {}'
.
format
(
save_path
))
def
main
():
def
main
():
...
@@ -84,17 +74,54 @@ def main():
...
@@ -84,17 +74,54 @@ def main():
config
[
'Global'
])
config
[
'Global'
])
# build model
# build model
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels'
]
=
char_num
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
-
1
].
keys
())[
0
]
==
'DistillationSARLoss'
config
[
'Loss'
][
'loss_config_list'
][
-
1
][
'DistillationSARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
1
].
keys
())[
0
]
==
'SARLoss'
if
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
is
None
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
=
{
'ignore_index'
:
char_num
+
1
}
else
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
# for SAR model
config
[
'Loss'
][
'ignore_index'
]
=
char_num
-
1
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
# get QAT model
# get QAT model
...
@@ -120,21 +147,22 @@ def main():
...
@@ -120,21 +147,22 @@ def main():
for
k
,
v
in
metric
.
items
():
for
k
,
v
in
metric
.
items
():
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
infer_shape
=
[
3
,
32
,
100
]
if
model_type
==
"rec"
else
[
3
,
640
,
640
]
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
arch_config
=
config
[
"Architecture"
]
arch_config
=
config
[
"Architecture"
]
arch_config
=
config
[
"Architecture"
]
if
arch_config
[
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
if
arch_config
[
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
archs
=
list
(
arch_config
[
"Models"
].
values
())
for
idx
,
name
in
enumerate
(
model
.
model_name_list
):
for
idx
,
name
in
enumerate
(
model
.
model_name_list
):
model
.
model_list
[
idx
].
eval
()
model
.
model_list
[
idx
].
eval
()
sub_model_save_path
=
os
.
path
.
join
(
save_path
,
name
,
"inference"
)
sub_model_save_path
=
os
.
path
.
join
(
save_path
,
name
,
"inference"
)
export_single_model
(
quanter
,
model
.
model_list
[
idx
],
infer_shape
,
export_single_model
(
model
.
model_list
[
idx
],
archs
[
idx
]
,
sub_model_save_path
,
logger
)
sub_model_save_path
,
logger
,
quanter
)
else
:
else
:
save_path
=
os
.
path
.
join
(
save_path
,
"inference"
)
save_path
=
os
.
path
.
join
(
save_path
,
"inference"
)
model
.
eval
()
export_single_model
(
model
,
arch_config
,
save_path
,
logger
,
quanter
)
export_single_model
(
quanter
,
model
,
infer_shape
,
save_path
,
logger
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
deploy/slim/quantization/quant.py
View file @
4cddec73
...
@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer):
...
@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer):
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
if
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels'
]
=
char_num
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'DistillationSARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
-
1
].
keys
())[
0
]
==
'DistillationSARLoss'
config
[
'Loss'
][
'loss_config_list'
][
-
1
][
'DistillationSARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
elif
config
[
'Architecture'
][
'Head'
][
'name'
]
==
'MultiHead'
:
# for multi head
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
char_num
=
char_num
-
2
# update SARLoss params
assert
list
(
config
[
'Loss'
][
'loss_config_list'
][
1
].
keys
())[
0
]
==
'SARLoss'
if
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
is
None
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
]
=
{
'ignore_index'
:
char_num
+
1
}
else
:
config
[
'Loss'
][
'loss_config_list'
][
1
][
'SARLoss'
][
'ignore_index'
]
=
char_num
+
1
out_channels_list
=
{}
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
config
[
'Architecture'
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
else
:
# base rec model
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
if
config
[
'PostProcess'
][
'name'
]
==
'SARLabelDecode'
:
# for SAR model
config
[
'Loss'
][
'ignore_index'
]
=
char_num
-
1
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
pre_best_model_dict
=
dict
()
pre_best_model_dict
=
dict
()
...
...
doc/doc_ch/layout_datasets.md
View file @
4cddec73
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#### 2、CDLA数据集
#### 2、CDLA数据集
-
**数据来源**
:https://github.com/buptlihang/CDLA
-
**数据来源**
:https://github.com/buptlihang/CDLA
-
**数据简介**
:
publaynet数
据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是:
`Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`
。部分图像以及标注框可视化如下所示。
-
**数据简介**
:
CDLA
据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是:
`Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`
。部分图像以及标注框可视化如下所示。
<div
align=
"center"
>
<div
align=
"center"
>
<img
src=
"../datasets/CDLA_demo/val_0633.jpg"
width=
"500"
>
<img
src=
"../datasets/CDLA_demo/val_0633.jpg"
width=
"500"
>
...
...
ppocr/data/__init__.py
View file @
4cddec73
...
@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
...
@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
use_shared_memory
=
loader_config
[
'use_shared_memory'
]
use_shared_memory
=
loader_config
[
'use_shared_memory'
]
else
:
else
:
use_shared_memory
=
True
use_shared_memory
=
True
if
mode
==
"Train"
:
if
mode
==
"Train"
:
# Distribute data to multiple cards
# Distribute data to multiple cards
batch_sampler
=
DistributedBatchSampler
(
batch_sampler
=
DistributedBatchSampler
(
...
...
ppocr/data/collate_fn.py
View file @
4cddec73
...
@@ -56,3 +56,17 @@ class ListCollator(object):
...
@@ -56,3 +56,17 @@ class ListCollator(object):
for
idx
in
to_tensor_idxs
:
for
idx
in
to_tensor_idxs
:
data_dict
[
idx
]
=
paddle
.
to_tensor
(
data_dict
[
idx
])
data_dict
[
idx
]
=
paddle
.
to_tensor
(
data_dict
[
idx
])
return
list
(
data_dict
.
values
())
return
list
(
data_dict
.
values
())
class
SSLRotateCollate
(
object
):
"""
bach: [
[(4*3xH*W), (4,)]
[(4*3xH*W), (4,)]
...
]
"""
def
__call__
(
self
,
batch
):
output
=
[
np
.
concatenate
(
d
,
axis
=
0
)
for
d
in
zip
(
*
batch
)]
return
output
ppocr/data/imaug/__init__.py
View file @
4cddec73
...
@@ -24,6 +24,7 @@ from .make_pse_gt import MakePseGt
...
@@ -24,6 +24,7 @@ from .make_pse_gt import MakePseGt
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
,
SVTRRecResizeImg
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
,
SVTRRecResizeImg
from
.ssl_img_aug
import
SSLRotateResize
from
.randaugment
import
RandAugment
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.copy_paste
import
CopyPaste
from
.ColorJitter
import
ColorJitter
from
.ColorJitter
import
ColorJitter
...
...
ppocr/data/imaug/ssl_img_aug.py
0 → 100644
View file @
4cddec73
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
cv2
import
numpy
as
np
import
random
from
PIL
import
Image
from
.rec_img_aug
import
resize_norm_img
class
SSLRotateResize
(
object
):
def
__init__
(
self
,
image_shape
,
padding
=
False
,
select_all
=
True
,
mode
=
"train"
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
padding
=
padding
self
.
select_all
=
select_all
self
.
mode
=
mode
def
__call__
(
self
,
data
):
img
=
data
[
"image"
]
data
[
"image_r90"
]
=
cv2
.
rotate
(
img
,
cv2
.
ROTATE_90_CLOCKWISE
)
data
[
"image_r180"
]
=
cv2
.
rotate
(
data
[
"image_r90"
],
cv2
.
ROTATE_90_CLOCKWISE
)
data
[
"image_r270"
]
=
cv2
.
rotate
(
data
[
"image_r180"
],
cv2
.
ROTATE_90_CLOCKWISE
)
images
=
[]
for
key
in
[
"image"
,
"image_r90"
,
"image_r180"
,
"image_r270"
]:
images
.
append
(
resize_norm_img
(
data
.
pop
(
key
),
image_shape
=
self
.
image_shape
,
padding
=
self
.
padding
)[
0
])
data
[
"image"
]
=
np
.
stack
(
images
,
axis
=
0
)
data
[
"label"
]
=
np
.
array
(
list
(
range
(
4
)))
if
not
self
.
select_all
:
data
[
"image"
]
=
data
[
"image"
][
0
::
2
]
# just choose 0 and 180
data
[
"label"
]
=
data
[
"label"
][
0
:
2
]
# label needs to be continuous
if
self
.
mode
==
"test"
:
data
[
"image"
]
=
data
[
"image"
][
0
]
data
[
"label"
]
=
data
[
"label"
][
0
]
return
data
ppocr/postprocess/cls_postprocess.py
View file @
4cddec73
...
@@ -17,17 +17,26 @@ import paddle
...
@@ -17,17 +17,26 @@ import paddle
class
ClsPostProcess
(
object
):
class
ClsPostProcess
(
object
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
def
__init__
(
self
,
label_list
,
**
kwargs
):
def
__init__
(
self
,
label_list
=
None
,
key
=
None
,
**
kwargs
):
super
(
ClsPostProcess
,
self
).
__init__
()
super
(
ClsPostProcess
,
self
).
__init__
()
self
.
label_list
=
label_list
self
.
label_list
=
label_list
self
.
key
=
key
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
self
.
key
is
not
None
:
preds
=
preds
[
self
.
key
]
label_list
=
self
.
label_list
if
label_list
is
None
:
label_list
=
{
idx
:
idx
for
idx
in
range
(
preds
.
shape
[
-
1
])}
if
isinstance
(
preds
,
paddle
.
Tensor
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds
=
preds
.
numpy
()
pred_idxs
=
preds
.
argmax
(
axis
=
1
)
pred_idxs
=
preds
.
argmax
(
axis
=
1
)
decode_out
=
[(
self
.
label_list
[
idx
],
preds
[
i
,
idx
])
decode_out
=
[(
label_list
[
idx
],
preds
[
i
,
idx
])
for
i
,
idx
in
enumerate
(
pred_idxs
)]
for
i
,
idx
in
enumerate
(
pred_idxs
)]
if
label
is
None
:
if
label
is
None
:
return
decode_out
return
decode_out
label
=
[(
self
.
label_list
[
idx
],
1.0
)
for
idx
in
label
]
label
=
[(
label_list
[
idx
],
1.0
)
for
idx
in
label
]
return
decode_out
,
label
return
decode_out
,
label
tools/export_model.py
View file @
4cddec73
...
@@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger
...
@@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
def
export_single_model
(
model
,
arch_config
,
save_path
,
logger
):
def
export_single_model
(
model
,
arch_config
,
save_path
,
logger
,
quanter
=
None
):
if
arch_config
[
"algorithm"
]
==
"SRN"
:
if
arch_config
[
"algorithm"
]
==
"SRN"
:
max_text_length
=
arch_config
[
"Head"
][
"max_text_length"
]
max_text_length
=
arch_config
[
"Head"
][
"max_text_length"
]
other_shape
=
[
other_shape
=
[
...
@@ -95,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger):
...
@@ -95,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger):
shape
=
[
None
]
+
infer_shape
,
dtype
=
"float32"
)
shape
=
[
None
]
+
infer_shape
,
dtype
=
"float32"
)
])
])
paddle
.
jit
.
save
(
model
,
save_path
)
if
quanter
is
None
:
paddle
.
jit
.
save
(
model
,
save_path
)
else
:
quanter
.
save_quantized_model
(
model
,
save_path
)
logger
.
info
(
"inference model is saved to {}"
.
format
(
save_path
))
logger
.
info
(
"inference model is saved to {}"
.
format
(
save_path
))
return
return
...
@@ -125,7 +128,6 @@ def main():
...
@@ -125,7 +128,6 @@ def main():
char_num
=
char_num
-
2
char_num
=
char_num
-
2
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'CTCLabelDecode'
]
=
char_num
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
out_channels_list
[
'SARLabelDecode'
]
=
char_num
+
2
loss_list
=
config
[
'Loss'
][
'loss_config_list'
]
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
config
[
'Architecture'
][
'Models'
][
key
][
'Head'
][
'out_channels_list'
]
=
out_channels_list
'out_channels_list'
]
=
out_channels_list
else
:
else
:
...
...
tools/infer_cls.py
View file @
4cddec73
...
@@ -57,6 +57,8 @@ def main():
...
@@ -57,6 +57,8 @@ def main():
continue
continue
elif
op_name
==
'KeepKeys'
:
elif
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
elif
op_name
==
"SSLRotateResize"
:
op
[
op_name
][
"mode"
]
=
"test"
transforms
.
append
(
op
)
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
global_config
[
'infer_mode'
]
=
True
ops
=
create_operators
(
transforms
,
global_config
)
ops
=
create_operators
(
transforms
,
global_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment