Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a5f75115
Commit
a5f75115
authored
Jun 05, 2021
by
WenmuZhou
Browse files
mv download func to ppocr/utils/network.py
parent
20466055
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
86 deletions
+106
-86
paddleocr.py
paddleocr.py
+40
-86
ppocr/utils/network.py
ppocr/utils/network.py
+66
-0
No files found.
paddleocr.py
View file @
a5f75115
...
...
@@ -21,15 +21,13 @@ sys.path.append(os.path.join(__dir__, ''))
import
cv2
import
numpy
as
np
from
pathlib
import
Path
import
tarfile
import
requests
from
tqdm
import
tqdm
from
tools.infer
import
predict_system
from
ppocr.utils.logging
import
get_logger
logger
=
get_logger
()
from
ppocr.utils.utility
import
check_and_read_gif
,
get_image_file_list
from
ppocr.utils.network
import
maybe_download
,
download_with_progressbar
from
tools.infer.utility
import
draw_ocr
,
init_args
,
str2bool
__all__
=
[
'PaddleOCR'
]
...
...
@@ -123,50 +121,6 @@ SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR
=
os
.
path
.
expanduser
(
"~/.paddleocr/"
)
def
download_with_progressbar
(
url
,
save_path
):
response
=
requests
.
get
(
url
,
stream
=
True
)
total_size_in_bytes
=
int
(
response
.
headers
.
get
(
'content-length'
,
0
))
block_size
=
1024
# 1 Kibibyte
progress_bar
=
tqdm
(
total
=
total_size_in_bytes
,
unit
=
'iB'
,
unit_scale
=
True
)
with
open
(
save_path
,
'wb'
)
as
file
:
for
data
in
response
.
iter_content
(
block_size
):
progress_bar
.
update
(
len
(
data
))
file
.
write
(
data
)
progress_bar
.
close
()
if
total_size_in_bytes
==
0
or
progress_bar
.
n
!=
total_size_in_bytes
:
logger
.
error
(
"Something went wrong while downloading models"
)
sys
.
exit
(
0
)
def
maybe_download
(
model_storage_directory
,
url
):
# using custom model
tar_file_name_list
=
[
'inference.pdiparams'
,
'inference.pdiparams.info'
,
'inference.pdmodel'
]
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
model_storage_directory
,
'inference.pdiparams'
)
)
or
not
os
.
path
.
exists
(
os
.
path
.
join
(
model_storage_directory
,
'inference.pdmodel'
)):
tmp_path
=
os
.
path
.
join
(
model_storage_directory
,
url
.
split
(
'/'
)[
-
1
])
print
(
'download {} to {}'
.
format
(
url
,
tmp_path
))
os
.
makedirs
(
model_storage_directory
,
exist_ok
=
True
)
download_with_progressbar
(
url
,
tmp_path
)
with
tarfile
.
open
(
tmp_path
,
'r'
)
as
tarObj
:
for
member
in
tarObj
.
getmembers
():
filename
=
None
for
tar_file_name
in
tar_file_name_list
:
if
tar_file_name
in
member
.
name
:
filename
=
tar_file_name
if
filename
is
None
:
continue
file
=
tarObj
.
extractfile
(
member
)
with
open
(
os
.
path
.
join
(
model_storage_directory
,
filename
),
'wb'
)
as
f
:
f
.
write
(
file
.
read
())
os
.
remove
(
tmp_path
)
def
parse_args
(
mMain
=
True
):
import
argparse
parser
=
init_args
()
...
...
@@ -194,10 +148,10 @@ class PaddleOCR(predict_system.TextSystem):
args:
**kwargs: other params show in paddleocr --help
"""
postprocess_
params
=
parse_args
(
mMain
=
False
)
postprocess_
params
.
__dict__
.
update
(
**
kwargs
)
self
.
use_angle_cls
=
postprocess_
params
.
use_angle_cls
lang
=
postprocess_
params
.
lang
params
=
parse_args
(
mMain
=
False
)
params
.
__dict__
.
update
(
**
kwargs
)
self
.
use_angle_cls
=
params
.
use_angle_cls
lang
=
params
.
lang
latin_lang
=
[
'af'
,
'az'
,
'bs'
,
'cs'
,
'cy'
,
'da'
,
'de'
,
'es'
,
'et'
,
'fr'
,
'ga'
,
'hr'
,
'hu'
,
'id'
,
'is'
,
'it'
,
'ku'
,
'la'
,
'lt'
,
'lv'
,
'mi'
,
'ms'
,
...
...
@@ -229,40 +183,40 @@ class PaddleOCR(predict_system.TextSystem):
else
:
det_lang
=
"en"
use_inner_dict
=
False
if
postprocess_
params
.
rec_char_dict_path
is
None
:
if
params
.
rec_char_dict_path
is
None
:
use_inner_dict
=
True
postprocess_
params
.
rec_char_dict_path
=
model_urls
[
'rec'
][
lang
][
params
.
rec_char_dict_path
=
model_urls
[
'rec'
][
lang
][
'dict_path'
]
# init model dir
if
postprocess_
params
.
det_model_dir
is
None
:
postprocess_
params
.
det_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
VERSION
,
if
params
.
det_model_dir
is
None
:
params
.
det_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
VERSION
,
'det'
,
det_lang
)
if
postprocess_
params
.
rec_model_dir
is
None
:
postprocess_
params
.
rec_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
VERSION
,
if
params
.
rec_model_dir
is
None
:
params
.
rec_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
VERSION
,
'rec'
,
lang
)
if
postprocess_params
.
cls_model_dir
is
None
:
postprocess_params
.
cls_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
'cls'
)
print
(
postprocess_params
)
if
params
.
cls_model_dir
is
None
:
params
.
cls_model_dir
=
os
.
path
.
join
(
BASE_DIR
,
'cls'
)
# download model
maybe_download
(
postprocess_
params
.
det_model_dir
,
maybe_download
(
params
.
det_model_dir
,
model_urls
[
'det'
][
det_lang
])
maybe_download
(
postprocess_
params
.
rec_model_dir
,
maybe_download
(
params
.
rec_model_dir
,
model_urls
[
'rec'
][
lang
][
'url'
])
maybe_download
(
postprocess_
params
.
cls_model_dir
,
model_urls
[
'cls'
])
maybe_download
(
params
.
cls_model_dir
,
model_urls
[
'cls'
])
if
postprocess_
params
.
det_algorithm
not
in
SUPPORT_DET_MODEL
:
if
params
.
det_algorithm
not
in
SUPPORT_DET_MODEL
:
logger
.
error
(
'det_algorithm must in {}'
.
format
(
SUPPORT_DET_MODEL
))
sys
.
exit
(
0
)
if
postprocess_
params
.
rec_algorithm
not
in
SUPPORT_REC_MODEL
:
if
params
.
rec_algorithm
not
in
SUPPORT_REC_MODEL
:
logger
.
error
(
'rec_algorithm must in {}'
.
format
(
SUPPORT_REC_MODEL
))
sys
.
exit
(
0
)
if
use_inner_dict
:
postprocess_
params
.
rec_char_dict_path
=
str
(
Path
(
__file__
).
parent
/
postprocess_
params
.
rec_char_dict_path
)
params
.
rec_char_dict_path
=
str
(
Path
(
__file__
).
parent
/
params
.
rec_char_dict_path
)
print
(
params
)
# init det_model and rec_model
super
().
__init__
(
postprocess_
params
)
super
().
__init__
(
params
)
def
ocr
(
self
,
img
,
det
=
True
,
rec
=
True
,
cls
=
True
):
"""
...
...
ppocr/utils/network.py
0 → 100644
View file @
a5f75115
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
tarfile
import
requests
from
tqdm
import
tqdm
from
ppocr.utils.logging
import
get_logger
def
download_with_progressbar
(
url
,
save_path
):
logger
=
get_logger
()
response
=
requests
.
get
(
url
,
stream
=
True
)
total_size_in_bytes
=
int
(
response
.
headers
.
get
(
'content-length'
,
0
))
block_size
=
1024
# 1 Kibibyte
progress_bar
=
tqdm
(
total
=
total_size_in_bytes
,
unit
=
'iB'
,
unit_scale
=
True
)
with
open
(
save_path
,
'wb'
)
as
file
:
for
data
in
response
.
iter_content
(
block_size
):
progress_bar
.
update
(
len
(
data
))
file
.
write
(
data
)
progress_bar
.
close
()
if
total_size_in_bytes
==
0
or
progress_bar
.
n
!=
total_size_in_bytes
:
logger
.
error
(
"Something went wrong while downloading models"
)
sys
.
exit
(
0
)
def
maybe_download
(
model_storage_directory
,
url
):
# using custom model
tar_file_name_list
=
[
'inference.pdiparams'
,
'inference.pdiparams.info'
,
'inference.pdmodel'
]
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
model_storage_directory
,
'inference.pdiparams'
)
)
or
not
os
.
path
.
exists
(
os
.
path
.
join
(
model_storage_directory
,
'inference.pdmodel'
)):
tmp_path
=
os
.
path
.
join
(
model_storage_directory
,
url
.
split
(
'/'
)[
-
1
])
print
(
'download {} to {}'
.
format
(
url
,
tmp_path
))
os
.
makedirs
(
model_storage_directory
,
exist_ok
=
True
)
download_with_progressbar
(
url
,
tmp_path
)
with
tarfile
.
open
(
tmp_path
,
'r'
)
as
tarObj
:
for
member
in
tarObj
.
getmembers
():
filename
=
None
for
tar_file_name
in
tar_file_name_list
:
if
tar_file_name
in
member
.
name
:
filename
=
tar_file_name
if
filename
is
None
:
continue
file
=
tarObj
.
extractfile
(
member
)
with
open
(
os
.
path
.
join
(
model_storage_directory
,
filename
),
'wb'
)
as
f
:
f
.
write
(
file
.
read
())
os
.
remove
(
tmp_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment