Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
4950c845
Commit
4950c845
authored
Nov 17, 2020
by
WenmuZhou
Browse files
添加方向分类器
parent
931d138b
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
399 additions
and
27 deletions
+399
-27
ppocr/metrics/cls_metric.py
ppocr/metrics/cls_metric.py
+46
-0
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+1
-1
ppocr/modeling/backbones/rec_mobilenet_v3.py
ppocr/modeling/backbones/rec_mobilenet_v3.py
+0
-10
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+4
-1
ppocr/modeling/heads/cls_head.py
ppocr/modeling/heads/cls_head.py
+52
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+5
-2
ppocr/postprocess/cls_postprocess.py
ppocr/postprocess/cls_postprocess.py
+33
-0
tools/infer/predict_cls.py
tools/infer/predict_cls.py
+151
-0
tools/infer/utility.py
tools/infer/utility.py
+26
-13
tools/infer_cls.py
tools/infer_cls.py
+81
-0
No files found.
ppocr/metrics/cls_metric.py
0 → 100644
View file @
4950c845
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ClsMetric(object):
    """Accumulate classification accuracy across evaluation batches.

    ``__call__`` scores a single batch and adds its counts to the running
    totals; ``get_metric`` reports the accuracy over everything seen since
    the last reset and then clears the totals.
    """

    def __init__(self, main_indicator='acc', **kwargs):
        """
        Args:
            main_indicator: name of the headline metric. It is only stored
                on the instance; this class does not read it back.
            **kwargs: accepted and ignored.
        """
        self.main_indicator = main_indicator
        self.reset()

    @staticmethod
    def _ratio(correct_num, all_num):
        """Return correct_num / all_num, or 0.0 when nothing was counted."""
        if all_num == 0:
            return 0.0
        return correct_num / all_num

    def __call__(self, pred_label, *args, **kwargs):
        """Score one batch and fold it into the running totals.

        Args:
            pred_label: a ``(preds, labels)`` pair. Every item of ``preds``
                and of ``labels`` must be a 2-element pair whose first
                element is the class label; the second element is ignored.

        Returns:
            dict: ``{'acc': accuracy_of_this_batch}``. An empty batch yields
            0.0 instead of raising ZeroDivisionError.
        """
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        for (pred, _pred_conf), (target, _) in zip(preds, labels):
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        return {'acc': self._ratio(correct_num, all_num), }

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
            }
        The accumulated counters are reset afterwards. When no sample has
        been accumulated the accuracy is reported as 0.0.
        """
        acc = self._ratio(self.correct_num, self.all_num)
        self.reset()
        return {'acc': acc}

    def reset(self):
        """Clear the running correct/total counters."""
        self.correct_num = 0
        self.all_num = 0
ppocr/modeling/backbones/__init__.py
View file @
4950c845
...
...
@@ -20,7 +20,7 @@ def build_backbone(config, model_type):
from
.det_mobilenet_v3
import
MobileNetV3
from
.det_resnet_vd
import
ResNet
support_dict
=
[
'MobileNetV3'
,
'ResNet'
,
'ResNet_SAST'
]
elif
model_type
==
'rec'
:
elif
model_type
==
'rec'
or
model_type
==
'cls'
:
from
.rec_mobilenet_v3
import
MobileNetV3
from
.rec_resnet_vd
import
ResNet
support_dict
=
[
'MobileNetV3'
,
'ResNet'
,
'ResNet_FPN'
]
...
...
ppocr/modeling/backbones/rec_mobilenet_v3.py
View file @
4950c845
...
...
@@ -136,13 +136,3 @@ class MobileNetV3(nn.Layer):
x
=
self
.
conv2
(
x
)
x
=
self
.
pool
(
x
)
return
x
if
__name__
==
'__main__'
:
import
paddle
paddle
.
disable_static
()
x
=
paddle
.
zeros
((
1
,
3
,
32
,
320
))
x
=
paddle
.
to_variable
(
x
)
net
=
MobileNetV3
(
model_name
=
'small'
,
small_stride
=
[
1
,
2
,
2
,
2
])
y
=
net
(
x
)
print
(
y
.
shape
)
ppocr/modeling/heads/__init__.py
View file @
4950c845
...
...
@@ -21,7 +21,10 @@ def build_head(config):
# rec head
from
.rec_ctc_head
import
CTCHead
support_dict
=
[
'DBHead'
,
'CTCHead'
]
# cls head
from
.cls_head
import
ClsHead
support_dict
=
[
'DBHead'
,
'CTCHead'
,
'ClsHead'
]
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'head only support {}'
.
format
(
...
...
ppocr/modeling/heads/cls_head.py
0 → 100644
View file @
4950c845
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
import
paddle
from
paddle
import
nn
,
ParamAttr
import
paddle.nn.functional
as
F
class ClsHead(nn.Layer):
    """
    Classification head: adaptive average pooling to 1x1 followed by a
    single fully connected layer.

    Args:
        in_channels: channel count of the incoming feature map
        class_dim: number of output classes
        **kwargs: accepted and ignored
    """

    def __init__(self, in_channels, class_dim, **kwargs):
        super(ClsHead, self).__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)

        # Symmetric uniform init whose bound shrinks with the fan-in.
        bound = 1.0 / math.sqrt(in_channels * 1.0)
        weight_attr = ParamAttr(
            name="fc_0.w_0",
            initializer=nn.initializer.Uniform(-bound, bound))
        bias_attr = ParamAttr(name="fc_0.b_0")
        self.fc = nn.Linear(
            in_channels,
            class_dim,
            weight_attr=weight_attr,
            bias_attr=bias_attr, )

    def forward(self, x):
        """Return raw fc outputs while training, softmax scores otherwise."""
        pooled = self.pool(x)
        # Drop the pooled spatial dims, keeping (dim 0, dim 1).
        flat = paddle.reshape(pooled, shape=[pooled.shape[0], pooled.shape[1]])
        logits = self.fc(flat)
        if self.training:
            return logits
        return F.softmax(logits, axis=1)
ppocr/postprocess/__init__.py
View file @
4950c845
...
...
@@ -25,8 +25,11 @@ __all__ = ['build_post_process']
def
build_post_process
(
config
,
global_config
=
None
):
from
.db_postprocess
import
DBPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
]
from
.cls_postprocess
import
ClsPostProcess
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
]
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/postprocess/cls_postprocess.py
0 → 100644
View file @
4950c845
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
class ClsPostProcess(object):
    """ Turn per-class scores into (label, score) pairs """

    def __init__(self, label_list, **kwargs):
        """
        Args:
            label_list: sequence indexed by class index that yields the
                label reported for that class.
            **kwargs: accepted and ignored.
        """
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Pick the highest-scoring class of every row of ``preds``.

        Args:
            preds: score matrix; a ``paddle.Tensor`` is converted with
                ``.numpy()`` first. The argmax is taken along axis 1.
            label: optional iterable of class indices to decode as well.

        Returns:
            A list of ``(label, score)`` pairs when ``label`` is None,
            otherwise a ``(decoded_preds, decoded_labels)`` tuple in which
            every decoded label carries the score 1.0.
        """
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        names = self.label_list
        best = preds.argmax(axis=1)
        decode_out = []
        for row, cls_idx in enumerate(best):
            decode_out.append((names[cls_idx], preds[row, cls_idx]))

        if label is None:
            return decode_out
        decoded_label = [(names[cls_idx], 1.0) for cls_idx in label]
        return decode_out, decoded_label
tools/infer/predict_cls.py
0 → 100755
View file @
4950c845
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
cv2
import
copy
import
numpy
as
np
import
math
import
time
import
paddle.fluid
as
fluid
import
tools.infer.utility
as
utility
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
class TextClassifier(object):
    """Classify the orientation of text-line images and flip the ones that
    are confidently predicted as rotated by 180 degrees."""

    def __init__(self, args):
        """
        Args:
            args: parsed command-line namespace. Reads ``cls_image_shape``
                (comma-separated "C,H,W" string), ``cls_batch_num``,
                ``cls_thresh``, ``use_zero_copy_run`` and ``label_list``.
        """
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        # Use the classifier's own batch size; fall back to the recognizer's
        # only for namespaces that do not define it.
        batch_num = getattr(args, 'cls_batch_num', None)
        if batch_num is None:
            batch_num = args.rec_batch_num
        self.cls_batch_num = batch_num
        self.cls_thresh = args.cls_thresh
        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        }
        self.postprocess_op = build_post_process(postprocess_params)
        # NOTE(review): `logger` is a module-level global that this file only
        # binds under `if __name__ == "__main__"`; importing this class from
        # another module will raise NameError here unless it is bound first.
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'cls', logger)

    def resize_norm_img(self, img):
        """Resize ``img`` to the configured height, normalize it to [-1, 1]
        and right-pad it with zeros to the configured width.

        Returns:
            float32 array of shape (imgC, imgH, imgW).
        """
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        # Keep the aspect ratio, but never exceed the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            # HWC -> CHW
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Classify every image of ``img_list``.

        Args:
            img_list: list of image arrays; the input list is deep-copied
                and never modified.

        Returns:
            tuple ``(img_list, cls_res, predict_time)``: the (possibly
            rotated) image copies, one ``[label, score]`` entry per image in
            input order, and the accumulated seconds spent in inference plus
            post-processing.
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the cls process
        indices = np.argsort(np.array(width_list))

        cls_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.cls_batch_num
        predict_time = 0.0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = [
                self.resize_norm_img(img_list[indices[ino]])[np.newaxis, :]
                for ino in range(beg_img_no, end_img_no)
            ]
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()

            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])
            prob_out = self.output_tensors[0].copy_to_cpu()
            # Keep the batch output separate from `cls_res`: the latter is
            # indexed by original image position and must survive all batches.
            batch_res = self.postprocess_op(prob_out)
            predict_time += time.time() - starttime

            for rno, (label, score) in enumerate(batch_res):
                # Map the position inside the sorted batch back to the
                # position of the image in the input list.
                img_idx = indices[beg_img_no + rno]
                cls_res[img_idx] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    img_list[img_idx] = cv2.rotate(img_list[img_idx],
                                                   cv2.ROTATE_180)
        return img_list, cls_res, predict_time
def main(args):
    """Classify every readable image under ``args.image_dir`` and print the
    prediction for each file plus the total predict time.

    Files that cannot be loaded are logged and skipped. If the classifier
    raises, the error is reported and the process exits with status 1.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            # Not handled by the gif reader: load it with OpenCV instead.
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as e:
        print(e)
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "Failed to run the text direction classifier. "
            "Please check --cls_model_dir and --cls_image_shape ")
        # sys.exit is always available (the `exit` builtin comes from `site`)
        # and a non-zero status lets callers detect the failure.
        sys.exit(1)
    for ino, image_file in enumerate(valid_image_file_list):
        print("Predicts of %s:%s" % (image_file, cls_res[ino]))
    print("Total predict time for %d images, cost: %.3f" %
          (len(img_list), predict_time))
if __name__ == "__main__":
    # `logger` is bound here as a module-level global; TextClassifier.__init__
    # and main() both read it by name, so it must exist before main() runs.
    logger = get_logger()
    main(utility.parse_args())
tools/infer/utility.py
View file @
4950c845
...
...
@@ -29,48 +29,61 @@ def parse_args():
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
argparse
.
ArgumentParser
()
#params for prediction engine
#
params for prediction engine
parser
.
add_argument
(
"--use_gpu"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--ir_optim"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--use_tensorrt"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--gpu_mem"
,
type
=
int
,
default
=
8000
)
#params for text detector
#
params for text detector
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_algorithm"
,
type
=
str
,
default
=
'DB'
)
parser
.
add_argument
(
"--det_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--det_limit_side_len"
,
type
=
float
,
default
=
960
)
parser
.
add_argument
(
"--det_limit_type"
,
type
=
str
,
default
=
'max'
)
parser
.
add_argument
(
"--det_max_side_len"
,
type
=
float
,
default
=
960
)
#DB parmas
#
DB parmas
parser
.
add_argument
(
"--det_db_thresh"
,
type
=
float
,
default
=
0.3
)
parser
.
add_argument
(
"--det_db_box_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
2.0
)
parser
.
add_argument
(
"--det_db_unclip_ratio"
,
type
=
float
,
default
=
1.6
)
#EAST parmas
#
EAST parmas
parser
.
add_argument
(
"--det_east_score_thresh"
,
type
=
float
,
default
=
0.8
)
parser
.
add_argument
(
"--det_east_cover_thresh"
,
type
=
float
,
default
=
0.1
)
parser
.
add_argument
(
"--det_east_nms_thresh"
,
type
=
float
,
default
=
0.2
)
#SAST parmas
#
SAST parmas
parser
.
add_argument
(
"--det_sast_score_thresh"
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
"--det_sast_nms_thresh"
,
type
=
float
,
default
=
0.2
)
parser
.
add_argument
(
"--det_sast_polygon"
,
type
=
bool
,
default
=
False
)
#params for text recognizer
#
params for text recognizer
parser
.
add_argument
(
"--rec_algorithm"
,
type
=
str
,
default
=
'CRNN'
)
parser
.
add_argument
(
"--rec_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--rec_image_shape"
,
type
=
str
,
default
=
"3, 32, 320"
)
parser
.
add_argument
(
"--rec_char_type"
,
type
=
str
,
default
=
'ch'
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--rec_batch_num"
,
type
=
int
,
default
=
6
)
parser
.
add_argument
(
"--max_text_length"
,
type
=
int
,
default
=
25
)
parser
.
add_argument
(
"--rec_char_dict_path"
,
type
=
str
,
default
=
"./ppocr/utils/ppocr_keys_v1.txt"
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_zero_copy_run"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--use_space_char"
,
type
=
str2bool
,
default
=
True
)
parser
.
add_argument
(
"--vis_font_path"
,
type
=
str
,
default
=
"./doc/simfang.ttf"
)
# params for text classifier
parser
.
add_argument
(
"--use_angle_cls"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--cls_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--cls_image_shape"
,
type
=
str
,
default
=
"3, 48, 192"
)
parser
.
add_argument
(
"--label_list"
,
type
=
list
,
default
=
[
'0'
,
'180'
])
parser
.
add_argument
(
"--cls_batch_num"
,
type
=
int
,
default
=
30
)
parser
.
add_argument
(
"--cls_thresh"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--enable_mkldnn"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_zero_copy_run"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--use_pdserving"
,
type
=
str2bool
,
default
=
False
)
return
parser
.
parse_args
()
...
...
tools/infer_cls.py
0 → 100755
View file @
4950c845
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
def main():
    """Run the configured model on every image found under
    ``config['Global']['infer_img']`` and log the post-processed results.

    Reads the module-level ``config`` and ``logger`` globals.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    model = build_model(config['Architecture'])
    init_model(config, model, logger)

    # create data ops: drop label-related ops and keep only the image key
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for img_path in get_image_file_list(global_config['infer_img']):
        logger.info("infer_img: {}".format(img_path))
        with open(img_path, 'rb') as f:
            raw_bytes = f.read()
        batch = transform({'image': raw_bytes}, ops)

        # Add a leading batch axis before handing the image to the model.
        images = paddle.to_tensor(np.expand_dims(batch[0], axis=0))
        post_result = post_process_class(model(images))
        for result in post_result:
            logger.info('\t result: {}'.format(result))
    logger.info("success!")
if __name__ == '__main__':
    # These names are bound at module level on purpose: main() takes no
    # arguments and reads `config` and `logger` as globals.
    config, device, logger, vdl_writer = program.preprocess()
    main()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment