Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
bde8cad0
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "b7fcf14f6eefdefeebb7616641767c0004c0d640"
Unverified
Commit
bde8cad0
authored
Aug 08, 2022
by
topduke
Committed by
GitHub
Aug 08, 2022
Browse files
add svtr ch model (#7134)
parent
05b6d296
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
191 additions
and
9 deletions
+191
-9
configs/rec/rec_svtrnet.yml
configs/rec/rec_svtrnet.yml
+3
-5
configs/rec/rec_svtrnet_ch.yml
configs/rec/rec_svtrnet_ch.yml
+155
-0
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+2
-1
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+15
-0
tools/export_model.py
tools/export_model.py
+16
-3
No files found.
configs/rec/rec_svtrnet.yml
View file @
bde8cad0
...
@@ -83,8 +83,7 @@ Train:
...
@@ -83,8 +83,7 @@ Train:
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
-
SVTRRecResizeImg
:
character_dict_path
:
image_shape
:
[
3
,
64
,
256
]
image_shape
:
[
3
,
64
,
256
]
padding
:
False
padding
:
False
-
KeepKeys
:
-
KeepKeys
:
...
@@ -98,14 +97,13 @@ Train:
...
@@ -98,14 +97,13 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
LMDBDataSet
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/val
id
ation/
data_dir
:
./train_data/data_lmdb_release/
e
val
u
ation/
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
CTCLabelEncode
:
# Class handling label
-
CTCLabelEncode
:
# Class handling label
-
RecResizeImg
:
-
SVTRRecResizeImg
:
character_dict_path
:
image_shape
:
[
3
,
64
,
256
]
image_shape
:
[
3
,
64
,
256
]
padding
:
False
padding
:
False
-
KeepKeys
:
-
KeepKeys
:
...
...
configs/rec/rec_svtrnet_ch.yml
0 → 100644
View file @
bde8cad0
Global
:
use_gpu
:
true
epoch_num
:
100
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec/svtr_ch_all/
save_epoch_step
:
10
eval_batch_step
:
-
0
-
2000
cal_metric_during_train
:
true
pretrained_model
:
null
checkpoints
:
null
save_inference_dir
:
null
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
max_text_length
:
25
infer_mode
:
false
use_space_char
:
true
save_res_path
:
./output/rec/predicts_svtr_tiny_ch_all.txt
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.99
epsilon
:
8.0e-08
weight_decay
:
0.05
no_weight_decay_name
:
norm pos_embed
one_dim_param_no_weight_decay
:
true
lr
:
name
:
Cosine
learning_rate
:
0.0005
warmup_epoch
:
2
Architecture
:
model_type
:
rec
algorithm
:
SVTR
Transform
:
null
Backbone
:
name
:
SVTRNet
img_size
:
-
32
-
320
out_char_num
:
40
out_channels
:
96
patch_merging
:
Conv
embed_dim
:
-
64
-
128
-
256
depth
:
-
3
-
6
-
3
num_heads
:
-
2
-
4
-
8
mixer
:
-
Local
-
Local
-
Local
-
Local
-
Local
-
Local
-
Global
-
Global
-
Global
-
Global
-
Global
-
Global
local_mixer
:
-
-
7
-
11
-
-
7
-
11
-
-
7
-
11
last_stage
:
true
prenorm
:
false
Neck
:
name
:
SequenceEncoder
encoder_type
:
reshape
Head
:
name
:
CTCHead
Loss
:
name
:
CTCLoss
PostProcess
:
name
:
CTCLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/train_list.txt
ext_op_transform_idx
:
1
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
RecConAug
:
prob
:
0.5
ext_data_num
:
2
image_shape
:
-
32
-
320
-
3
-
RecAug
:
null
-
CTCLabelEncode
:
null
-
SVTRRecResizeImg
:
image_shape
:
-
3
-
32
-
320
padding
:
true
-
KeepKeys
:
keep_keys
:
-
image
-
label
-
length
loader
:
shuffle
:
true
batch_size_per_card
:
256
drop_last
:
true
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/val_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
CTCLabelEncode
:
null
-
SVTRRecResizeImg
:
image_shape
:
-
3
-
32
-
320
padding
:
true
-
KeepKeys
:
keep_keys
:
-
image
-
label
-
length
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
256
num_workers
:
2
profiler_options
:
null
ppocr/data/imaug/__init__.py
View file @
bde8cad0
...
@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
...
@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from
.make_pse_gt
import
MakePseGt
from
.make_pse_gt
import
MakePseGt
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
from
.rec_img_aug
import
RecAug
,
RecConAug
,
RecResizeImg
,
ClsResizeImg
,
\
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
SRNRecResizeImg
,
NRTRRecResizeImg
,
SARRecResizeImg
,
PRENResizeImg
,
\
SVTRRecResizeImg
from
.ssl_img_aug
import
SSLRotateResize
from
.ssl_img_aug
import
SSLRotateResize
from
.randaugment
import
RandAugment
from
.randaugment
import
RandAugment
from
.copy_paste
import
CopyPaste
from
.copy_paste
import
CopyPaste
...
...
ppocr/data/imaug/rec_img_aug.py
View file @
bde8cad0
...
@@ -207,6 +207,21 @@ class PRENResizeImg(object):
...
@@ -207,6 +207,21 @@ class PRENResizeImg(object):
return
data
return
data
class
SVTRRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
padding
=
True
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
padding
=
padding
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
norm_img
,
valid_ratio
=
resize_norm_img
(
img
,
self
.
image_shape
,
self
.
padding
)
data
[
'image'
]
=
norm_img
data
[
'valid_ratio'
]
=
valid_ratio
return
data
def
resize_norm_img_sar
(
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
def
resize_norm_img_sar
(
img
,
image_shape
,
width_downsample_ratio
=
0.25
):
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
imgC
,
imgH
,
imgW_min
,
imgW_max
=
image_shape
h
=
img
.
shape
[
0
]
h
=
img
.
shape
[
0
]
...
...
tools/export_model.py
View file @
bde8cad0
...
@@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger
...
@@ -31,7 +31,12 @@ from ppocr.utils.logging import get_logger
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
def
export_single_model
(
model
,
arch_config
,
save_path
,
logger
,
quanter
=
None
):
def
export_single_model
(
model
,
arch_config
,
save_path
,
logger
,
input_shape
=
None
,
quanter
=
None
):
if
arch_config
[
"algorithm"
]
==
"SRN"
:
if
arch_config
[
"algorithm"
]
==
"SRN"
:
max_text_length
=
arch_config
[
"Head"
][
"max_text_length"
]
max_text_length
=
arch_config
[
"Head"
][
"max_text_length"
]
other_shape
=
[
other_shape
=
[
...
@@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
...
@@ -64,7 +69,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
else
:
else
:
other_shape
=
[
other_shape
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
64
,
256
]
,
dtype
=
"float32"
),
shape
=
[
None
]
+
input_shape
,
dtype
=
"float32"
),
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
==
"PREN"
:
elif
arch_config
[
"algorithm"
]
==
"PREN"
:
...
@@ -157,6 +162,13 @@ def main():
...
@@ -157,6 +162,13 @@ def main():
arch_config
=
config
[
"Architecture"
]
arch_config
=
config
[
"Architecture"
]
if
arch_config
[
"algorithm"
]
==
"SVTR"
and
arch_config
[
"Head"
][
"name"
]
!=
'MultiHead'
:
input_shape
=
config
[
"Eval"
][
"dataset"
][
"transforms"
][
-
2
][
'SVTRRecResizeImg'
][
'image_shape'
]
else
:
input_shape
=
None
if
arch_config
[
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
if
arch_config
[
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
archs
=
list
(
arch_config
[
"Models"
].
values
())
archs
=
list
(
arch_config
[
"Models"
].
values
())
for
idx
,
name
in
enumerate
(
model
.
model_name_list
):
for
idx
,
name
in
enumerate
(
model
.
model_name_list
):
...
@@ -165,7 +177,8 @@ def main():
...
@@ -165,7 +177,8 @@ def main():
sub_model_save_path
,
logger
)
sub_model_save_path
,
logger
)
else
:
else
:
save_path
=
os
.
path
.
join
(
save_path
,
"inference"
)
save_path
=
os
.
path
.
join
(
save_path
,
"inference"
)
export_single_model
(
model
,
arch_config
,
save_path
,
logger
)
export_single_model
(
model
,
arch_config
,
save_path
,
logger
,
input_shape
=
input_shape
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment