Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
fe46b77e
Commit
fe46b77e
authored
Dec 15, 2020
by
MissPenguin
Browse files
fix conflicts
parents
c9a8cd83
9ad5c6b2
Changes
55
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
798 additions
and
0 deletions
+798
-0
StyleTextRec/engine/style_samplers.py
StyleTextRec/engine/style_samplers.py
+62
-0
StyleTextRec/engine/synthesisers.py
StyleTextRec/engine/synthesisers.py
+71
-0
StyleTextRec/engine/text_drawers.py
StyleTextRec/engine/text_drawers.py
+57
-0
StyleTextRec/engine/writers.py
StyleTextRec/engine/writers.py
+71
-0
StyleTextRec/examples/corpus/example.txt
StyleTextRec/examples/corpus/example.txt
+2
-0
StyleTextRec/examples/image_list.txt
StyleTextRec/examples/image_list.txt
+2
-0
StyleTextRec/examples/style_images/1.jpg
StyleTextRec/examples/style_images/1.jpg
+0
-0
StyleTextRec/examples/style_images/2.jpg
StyleTextRec/examples/style_images/2.jpg
+0
-0
StyleTextRec/fonts/ch_standard.ttf
StyleTextRec/fonts/ch_standard.ttf
+0
-0
StyleTextRec/fonts/en_standard.ttf
StyleTextRec/fonts/en_standard.ttf
+0
-0
StyleTextRec/fonts/ko_standard.ttf
StyleTextRec/fonts/ko_standard.ttf
+0
-0
StyleTextRec/tools/__init__.py
StyleTextRec/tools/__init__.py
+0
-0
StyleTextRec/tools/synth_dataset.py
StyleTextRec/tools/synth_dataset.py
+23
-0
StyleTextRec/tools/synth_image.py
StyleTextRec/tools/synth_image.py
+82
-0
StyleTextRec/utils/__init__.py
StyleTextRec/utils/__init__.py
+0
-0
StyleTextRec/utils/config.py
StyleTextRec/utils/config.py
+224
-0
StyleTextRec/utils/load_params.py
StyleTextRec/utils/load_params.py
+27
-0
StyleTextRec/utils/logging.py
StyleTextRec/utils/logging.py
+65
-0
StyleTextRec/utils/math_functions.py
StyleTextRec/utils/math_functions.py
+45
-0
StyleTextRec/utils/sys_funcs.py
StyleTextRec/utils/sys_funcs.py
+67
-0
No files found.
StyleTextRec/engine/style_samplers.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
random
import
cv2
class DatasetSampler(object):
    """Cycle through the style images listed in a label file.

    The list is shuffled on construction and reshuffled every time a full
    pass over it has been completed.
    """

    def __init__(self, config):
        """Read the sampler settings and load the label file.

        Args:
            config (dict): reads ``StyleSampler.image_home``,
                ``StyleSampler.label_file``, ``StyleSampler.with_label`` and
                ``Global.image_height``.

        Raises:
            AssertionError: if the label file contains no entries.
        """
        self.image_home = config["StyleSampler"]["image_home"]
        label_file = config["StyleSampler"]["label_file"]
        self.dataset_with_label = config["StyleSampler"]["with_label"]
        self.height = config["Global"]["image_height"]
        self.index = 0
        # Labels may contain non-ASCII text, so do not rely on the locale's
        # default encoding.
        with open(label_file, "r", encoding="utf-8") as f:
            label_raw = f.read()
        # Keep every non-empty line.  Dropping the last split element
        # unconditionally would lose the final entry of a file that does not
        # end with a newline.
        self.path_label_list = [
            line for line in label_raw.split("\n") if line
        ]
        assert len(self.path_label_list) > 0
        random.shuffle(self.path_label_list)

    def sample(self):
        """Return the next style image, resized to the configured height.

        Returns:
            dict: ``{"image": ndarray}`` plus a ``"label"`` entry when the
            dataset has labels and the label is non-empty.

        Raises:
            IOError: if the image file cannot be read.
        """
        if self.index >= len(self.path_label_list):
            # One full pass done: start over in a new random order.
            random.shuffle(self.path_label_list)
            self.index = 0
        entry = self.path_label_list[self.index]
        if self.dataset_with_label:
            # Split on the first tab only so a label may itself contain tabs.
            rel_image_path, label = entry.split('\t', 1)
        else:
            rel_image_path = entry
            label = None
        img_path = "{}/{}".format(self.image_home, rel_image_path)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("failed to read style image: {}".format(img_path))
        origin_height = image.shape[0]
        ratio = self.height / origin_height
        width = int(image.shape[1] * ratio)
        # Use the target height directly: int(origin_height * ratio) can be
        # truncated to height - 1 by floating point rounding.
        image = cv2.resize(image, (width, self.height))

        self.index += 1
        if label:
            return {"image": image, "label": label}
        else:
            return {"image": image}
def duplicate_image(image, width):
    """Repeat ``image`` along its second axis and crop it to ``width`` columns.

    Args:
        image: array indexed as ``[row, column, channel]``.
        width (int): number of columns in the result.

    Returns:
        The tiled image cropped to the first ``width`` columns.
    """
    # One more copy than the integer quotient always covers ``width`` columns.
    repeats = width // image.shape[1] + 1
    tiled = np.tile(image, reps=[1, repeats, 1])
    return tiled[:, :width, :]
StyleTextRec/engine/synthesisers.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
utils.config
import
ArgsParser
,
load_config
,
override_config
from
utils.logging
import
get_logger
from
engine
import
style_samplers
,
corpus_generators
,
text_drawers
,
predictors
,
writers
class ImageSynthesiser(object):
    """Synthesise a styled text image from a corpus string and a style image.

    Construction parses the command line, loads the config file (applying any
    ``-o`` overrides), prepares the output directory and logger, and builds
    the text drawer and the predictor named by ``Predictor.method``.
    """

    def __init__(self):
        self.FLAGS = ArgsParser().parse_args()
        self.config = load_config(self.FLAGS.config)
        self.config = override_config(self.config, options=self.FLAGS.override)
        self.output_dir = self.config["Global"]["output_dir"]
        # makedirs(exist_ok=True) also creates missing parent directories and
        # does not fail if another process creates the directory first.
        os.makedirs(self.output_dir, exist_ok=True)
        self.logger = get_logger(
            log_file='{}/predict.log'.format(self.output_dir))

        self.text_drawer = text_drawers.StdTextDrawer(self.config)

        # The predictor class is looked up by name in the predictors module.
        predictor_method = self.config["Predictor"]["method"]
        assert predictor_method is not None
        self.predictor = getattr(predictors, predictor_method)(self.config)

    def synth_image(self, corpus, style_input, language="en"):
        """Draw ``corpus`` and run the predictor on it with the style image.

        Args:
            corpus (str): text to render.
            style_input: style image passed through to the predictor.
            language (str): language key handed to the text drawer.

        Returns:
            Whatever ``self.predictor.predict`` returns.
        """
        corpus, text_input = self.text_drawer.draw_text(corpus, language)
        synth_result = self.predictor.predict(style_input, text_input)
        return synth_result
class DatasetSynthesiser(ImageSynthesiser):
    """Generate a whole labelled dataset of synthesised images.

    Extends ImageSynthesiser with a corpus generator, a style image sampler
    and a writer that stores the images and their labels.
    """

    def __init__(self):
        super().__init__()
        self.tag = self.FLAGS.tag
        self.output_num = self.config["Global"]["output_num"]

        # The corpus generator class is chosen by name from the config.
        generator_name = self.config["CorpusGenerator"]["method"]
        generator_cls = getattr(corpus_generators, generator_name)
        self.corpus_generator = generator_cls(self.config)

        # NOTE(review): the configured sampler method must be set, but
        # DatasetSampler is always the class that gets instantiated.
        sampler_name = self.config["StyleSampler"]["method"]
        assert sampler_name is not None
        self.style_sampler = style_samplers.DatasetSampler(self.config)
        self.writer = writers.SimpleWriter(self.config, self.tag)

    def synth_dataset(self):
        """Synthesise ``output_num`` images, then save and merge the labels."""
        for _ in range(self.output_num):
            style_input = self.style_sampler.sample()["image"]
            language, label = self.corpus_generator.generate()
            # draw_text may shorten the text, so keep the label it returns.
            label, text_input = self.text_drawer.draw_text(label, language)

            result = self.predictor.predict(style_input, text_input)
            self.writer.save_image(result["fake_fusion"], label)
        self.writer.save_label()
        self.writer.merge_label()
StyleTextRec/engine/text_drawers.py
0 → 100644
View file @
fe46b77e
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
numpy
as
np
from
utils.logging
import
get_logger
class StdTextDrawer(object):
    """Render a corpus string as black text on a gray strip.

    One font is loaded per language listed under ``TextDrawer.fonts``.
    """

    def __init__(self, config):
        self.logger = get_logger()
        global_cfg = config["Global"]
        self.max_width = global_cfg["image_width"]
        # Reference characters used to measure the rendered height of a font.
        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = global_cfg["image_height"]
        self.font_dict = {}
        self.load_fonts(config["TextDrawer"]["fonts"])
        self.support_languages = list(self.font_dict)

    def load_fonts(self, fonts_config):
        """Load one font per language into ``self.font_dict``.

        Args:
            fonts_config (dict): maps a language key to a font file path.
        """
        for language, font_path in fonts_config.items():
            size = self.get_valid_height(font_path)
            self.font_dict[language] = ImageFont.truetype(font_path, size)

    def get_valid_height(self, font_path):
        """Return a font size whose reference text fits in the image height.

        The font is measured at ``height - 4``; if the reference characters
        render taller than that, the size is scaled down proportionally.
        """
        limit = self.height - 4
        probe = ImageFont.truetype(font_path, limit)
        rendered_height = probe.getsize(self.char_list)[1]
        if rendered_height > limit:
            return int(limit**2 / rendered_height)
        return limit

    def draw_text(self, corpus, language="en", crop=True):
        """Draw ``corpus`` and return it with the rendered image.

        Args:
            corpus (str): text to draw.
            language (str): font key; unsupported keys fall back to "en".
            crop (bool): when true the canvas is capped at the configured
                image width (plus a 4 pixel margin).

        Returns:
            tuple: ``(corpus, image)`` where ``corpus`` is cut after the
            character that reached the canvas width, and ``image`` is a
            uint8 array cropped to the final pen position.
        """
        if language not in self.support_languages:
            self.logger.warning(
                "language {} not supported, use en instead.".format(language))
            language = "en"

        # Budget one image-height of width per character.
        budget = len(corpus) * self.height
        if crop:
            budget = min(self.max_width, budget)
        width = budget + 4

        canvas = Image.new(
            "RGB", (width, self.height), color=(127, 127, 127))
        pen = ImageDraw.Draw(canvas)

        cursor = 2
        font = self.font_dict[language]
        for idx, glyph in enumerate(corpus):
            advance = font.getsize(glyph)[0]
            pen.text((cursor, 2), glyph, fill=(0, 0, 0), font=font)
            cursor += advance
            if cursor >= width:
                # Canvas is full: keep the text up to this character only.
                corpus = corpus[:idx + 1]
                self.logger.warning(
                    "corpus length exceed limit: {}".format(corpus))
                break

        pixels = np.array(canvas).astype(np.uint8)
        return corpus, pixels[:, 0:cursor, :]
StyleTextRec/engine/writers.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
glob
from
utils.logging
import
get_logger
class SimpleWriter(object):
    """Write synthesised images and their labels below the output directory.

    Images go to ``<output_dir>/images/<tag>/<n>.png``; labels are collected
    in memory and written to ``<output_dir>/label/<tag>_label.txt``.
    """

    def __init__(self, config, tag):
        """Store the output directory from ``config`` and the worker ``tag``."""
        self.logger = get_logger()
        self.output_dir = config["Global"]["output_dir"]
        self.counter = 0
        self.label_dict = {}
        self.tag = tag
        self.label_file_index = 0

    def save_image(self, image, text_input_label):
        """Write ``image`` as the next numbered PNG and remember its label.

        The label file is rewritten after every 100 saved images.
        """
        image_home = os.path.join(self.output_dir, "images", self.tag)
        # exist_ok avoids a failure when another worker creates the shared
        # parent directories at the same time.
        os.makedirs(image_home, exist_ok=True)

        file_name = "{}.png".format(self.counter)
        image_path = os.path.join(image_home, file_name)
        # TODO: support continuing a previous synthesis run.
        cv2.imwrite(image_path, image)
        self.logger.info("generate image: {}".format(image_path))

        # Labels are keyed by the path relative to the images directory.
        self.label_dict[os.path.join(self.tag, file_name)] = text_input_label

        self.counter += 1
        if self.counter % 100 == 0:
            self.save_label()

    def save_label(self):
        """Write all labels collected so far to ``<tag>_label.txt``."""
        label_home = os.path.join(self.output_dir, "label")
        # Also creates output_dir itself if it does not exist yet.
        os.makedirs(label_home, exist_ok=True)
        label_raw = "".join(
            "{}\t{}\n".format(image_path, label)
            for image_path, label in self.label_dict.items())
        label_file_path = os.path.join(label_home,
                                       "{}_label.txt".format(self.tag))
        # Labels can be non-ASCII text; write them as UTF-8 regardless of the
        # locale's default encoding.
        with open(label_file_path, "w", encoding="utf-8") as f:
            f.write(label_raw)
        self.label_file_index += 1

    def merge_label(self):
        """Concatenate every ``*_label.txt`` into ``<output_dir>/label.txt``."""
        label_file_regex = os.path.join(self.output_dir, "label",
                                        "*_label.txt")
        parts = []
        # Sort for a reproducible merge order; glob order is arbitrary.
        for label_file_i in sorted(glob.glob(label_file_regex)):
            with open(label_file_i, "r", encoding="utf-8") as f:
                parts.append(f.read())
        label_file_path = os.path.join(self.output_dir, "label.txt")
        with open(label_file_path, "w", encoding="utf-8") as f:
            f.write("".join(parts))
StyleTextRec/examples/corpus/example.txt
0 → 100644
View file @
fe46b77e
PaddleOCR
飞桨文字识别
StyleTextRec/examples/image_list.txt
0 → 100644
View file @
fe46b77e
style_images/1.jpg NEATNESS
style_images/2.jpg 锁店君和宾馆
StyleTextRec/examples/style_images/1.jpg
0 → 100644
View file @
fe46b77e
2.55 KB
StyleTextRec/examples/style_images/2.jpg
0 → 100644
View file @
fe46b77e
3.83 KB
StyleTextRec/fonts/ch_standard.ttf
0 → 100755
View file @
fe46b77e
File added
StyleTextRec/fonts/en_standard.ttf
0 → 100755
View file @
fe46b77e
File added
StyleTextRec/fonts/ko_standard.ttf
0 → 100755
View file @
fe46b77e
File added
StyleTextRec/tools/__init__.py
0 → 100644
View file @
fe46b77e
StyleTextRec/tools/synth_dataset.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
engine.synthesisers
import
DatasetSynthesiser
def synth_dataset():
    """Build a DatasetSynthesiser and run its dataset synthesis."""
    DatasetSynthesiser().synth_dataset()
# Run the dataset synthesis only when this file is executed as a script.
if __name__ == '__main__':
    synth_dataset()
StyleTextRec/tools/synth_image.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
sys
import
glob
# Put this directory and its parent on sys.path BEFORE importing the project
# packages below; doing it after the imports would leave them unresolvable
# when the script is started from another working directory.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.config import ArgsParser
from engine.synthesisers import ImageSynthesiser
def synth_image():
    """Synthesise one image from the command-line options.

    Reads ``--style_image``, ``--text_corpus`` and ``--language``, runs the
    synthesiser and writes ``fake_fusion.jpg``, ``fake_text.jpg`` and
    ``fake_bg.jpg`` to the current working directory.

    Raises:
        IOError: if the style image cannot be read.
    """
    args = ArgsParser().parse_args()
    image_synthesiser = ImageSynthesiser()
    style_image_path = args.style_image
    img = cv2.imread(style_image_path)
    if img is None:
        # cv2.imread returns None on failure; stop here with a clear message
        # instead of passing None on to the predictor.
        raise IOError(
            "failed to read style image: {}".format(style_image_path))

    synth_result = image_synthesiser.synth_image(args.text_corpus, img,
                                                 args.language)
    # Write each predictor output under its own file name.
    for key in ("fake_fusion", "fake_text", "fake_bg"):
        cv2.imwrite("{}.jpg".format(key), synth_result[key])
def batch_synth_images():
    """Synthesise every corpus line against every style image.

    The corpus file holds one ``text<TAB>language`` pair per line.  For each
    (corpus, style image) pair the results are written twice, once under a
    corpus-major prefix (``c<i>_s<j>_``) and once under a style-major prefix
    (``s<j>_c<i>_``).
    """
    image_synthesiser = ImageSynthesiser()

    corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
    style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
    save_path = "./output_data/"
    # cv2.imwrite does not create directories, so make sure the target exists.
    os.makedirs(save_path, exist_ok=True)

    corpus_list = []
    with open(corpus_file, "rb") as fin:
        for raw_line in fin:
            # Strip CR as well so CRLF files do not leave "\r" in the language.
            line = raw_line.decode("utf-8").rstrip("\r\n")
            if not line:
                # Blank lines cannot be unpacked into (corpus, language).
                continue
            corpus_list.append(line.split("\t"))

    # Sort so the style numbering is stable between runs.
    style_img_list = sorted(glob.glob("{}/*.jpg".format(style_data_dir)))
    corpus_num = len(corpus_list)
    style_img_num = len(style_img_list)
    for cno, (corpus, lang) in enumerate(corpus_list):
        for sno, style_img_path in enumerate(style_img_list):
            img = cv2.imread(style_img_path)
            synth_result = image_synthesiser.synth_image(corpus, img, lang)
            fake_fusion = synth_result["fake_fusion"]
            fake_text = synth_result["fake_text"]
            fake_bg = synth_result["fake_bg"]
            prefixes = (
                "%s/c%d_s%d_" % (save_path, cno, sno),
                "%s/s%d_c%d_" % (save_path, sno, cno),
            )
            for prefix in prefixes:
                cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
                cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
                cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
                cv2.imwrite("%s_input_style.jpg" % prefix, img)
            # Progress: current corpus index / total, style index / total.
            print(cno, corpus_num, sno, style_img_num)
if __name__ == '__main__':
    # Single-image synthesis is the script's entry point; call
    # batch_synth_images() here instead to run the batch variant.
    synth_image()
StyleTextRec/utils/__init__.py
0 → 100644
View file @
fe46b77e
StyleTextRec/utils/config.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
yaml
import
os
from
argparse
import
ArgumentParser
,
RawDescriptionHelpFormatter
def override(dl, ks, v):
    """
    Recursively replace one value inside a nested dict/list structure.

    Args:
        dl(dict or list): dict or list to be modified in place
        ks(list): list of keys (list positions are given as numeric strings)
        v(str): new value; it is evaluated as a Python expression when
            possible and stored as the raw string otherwise

    Raises:
        AssertionError: if ``dl`` is not a list/dict, ``ks`` is empty, a list
            index is out of range, or an intermediate dict key is missing.
    """
    # Imported locally because this module does not import logging at the top.
    import logging

    def str2num(v):
        # NOTE(review): eval() runs arbitrary expressions and resolves bare
        # names (a value such as "os" becomes the module object).  Only pass
        # trusted option strings; consider ast.literal_eval.
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('lenght of keys should larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # A missing leaf key is allowed: it is added with a warning.
            if ks[0] not in dl:
                logging.getLogger(__name__).warning(
                    'A new filed ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            # Intermediate keys must already exist.
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
    """
    Apply ``key=value`` overrides to a config dict.

    Args:
        config(dict): dict to be modified in place
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]

    Returns:
        config(dict): the same dict, after the overrides were applied
    """
    if options is None:
        return config

    for opt in options:
        assert isinstance(opt, str), (
            "option({}) should be a str".format(opt))
        assert "=" in opt, (
            "option({}) should contain a ="
            "to distinguish between key and value".format(opt))
        # Exactly one "=" is accepted, so partition yields the whole pair.
        assert opt.count('=') == 1, ("there can be only a = in the option")
        dotted_key, _, value = opt.partition('=')
        override(config, dotted_key.split('.'), value)

    return config
class ArgsParser(ArgumentParser):
    """Command-line parser for the synthesis tools; ``--config`` is required."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')
        # Each of the options below gets its own description; all three used
        # to repeat the help text of --tag.
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style image")
        self.add_argument(
            "--text_corpus",
            default="PaddleOCR",
            help="text corpus to synthesise")
        self.add_argument(
            "--language", default="en", help="language of the text corpus")

    def parse_args(self, argv=None):
        """Parse ``argv`` and insist that a config file was given.

        Raises:
            AssertionError: if ``--config`` is missing.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args
def load_config(file_path):
    """
    Read a yml/yaml config file.

    Args:
        file_path (str): path of the config file; must end in .yml or .yaml.

    Returns: the parsed config

    Raises:
        AssertionError: if the file extension is not .yml or .yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ('.yml', '.yaml'), "only support yaml files for now"
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load config files from trusted sources.
    with open(file_path, 'rb') as stream:
        return yaml.load(stream, Loader=yaml.Loader)
def gen_config():
    """Write the built-in base configuration to ``config.yml``.

    The file is created in the current working directory and overwritten if
    it already exists.
    """

    def discriminator(name):
        # A fresh dict per call, so the YAML dump contains no shared aliases.
        # Fields: input_nc, ndf, netD, norm, init_type.
        return {
            "name": name,
            "input_nc": 6,
            "ndf": 64,
            "netD": "basic",
            "norm": "none",
            "init_type": "xavier",
        }

    global_section = {
        "algorithm": "SRNet",
        "use_gpu": True,
        "start_epoch": 1,
        "stage1_epoch_num": 100,
        "stage2_epoch_num": 100,
        "log_smooth_window": 20,
        "print_batch_step": 2,
        "save_model_dir": "./output/SRNet",
        "use_visualdl": False,
        "save_epoch_step": 10,
        "vgg_pretrain": "./pretrained/VGG19_pretrained",
        "vgg_load_static_pretrain": True,
    }

    architecture_section = {
        "model_type": "data_aug",
        "algorithm": "SRNet",
        "net_g": {
            "name": "srnet_net_g",
            "encode_dim": 64,
            "norm": "batch",
            "use_dropout": False,
            "init_type": "xavier",
            "init_gain": 0.02,
            "use_dilation": 1,
        },
        "bg_discriminator": discriminator("srnet_bg_discriminator"),
        "fusion_discriminator": discriminator("srnet_fusion_discriminator"),
    }

    loss_section = {
        "lamb": 10,
        "perceptual_lamb": 1,
        "muvar_lamb": 50,
        "style_lamb": 500,
    }

    optimizer_section = {
        "name": "Adam",
        "learning_rate": {
            "name": "lambda",
            "lr": 0.0002,
            "lr_decay_iters": 50,
        },
        "beta1": 0.5,
        "beta2": 0.999,
    }

    transforms = [
        {
            "DecodeImage": {
                "to_rgb": True,
                "to_np": False,
                "channel_first": False,
            }
        },
        {
            "NormalizeImage": {
                "scale": 1. / 255.,
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
                "order": None,
            }
        },
        {
            "ToCHWImage": None
        },
    ]

    train_section = {
        "batch_size_per_card": 8,
        "num_workers_per_card": 4,
        "dataset": {
            "delimiter": "\t",
            "data_dir": "/",
            "label_file": "tmp/label.txt",
            "transforms": transforms,
        },
    }

    base_config = {
        "Global": global_section,
        "Architecture": architecture_section,
        "Loss": loss_section,
        "Optimizer": optimizer_section,
        "Train": train_section,
    }
    with open("config.yml", "w") as out:
        yaml.dump(base_config, out)
# Running this module as a script writes the base config to config.yml.
if __name__ == '__main__':
    gen_config()
StyleTextRec/utils/load_params.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
__all__
=
[
'load_dygraph_pretrain'
]
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    """Load pretrained parameters from ``<path>.pdparams`` into ``model``.

    Args:
        model: object exposing ``set_state_dict``.
        logger: logger used to report the loaded path.
        path (str): parameter file path without the ``.pdparams`` suffix.
        load_static_weights (bool): accepted for interface compatibility;
            not used by this function.

    Raises:
        ValueError: if ``path`` is None or ``<path>.pdparams`` does not exist.
    """
    if path is None:
        # The default would otherwise fail on None + str with a TypeError.
        raise ValueError("Model pretrain path must be specified.")
    params_file = path + '.pdparams'
    if not os.path.exists(params_file):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    param_state_dict = paddle.load(params_file)
    model.set_state_dict(param_state_dict)
    logger.info("load pretrained model from {}".format(path))
    return
StyleTextRec/utils/logging.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
logging
import
functools
import
paddle.distributed
as
dist
# Names of the loggers that get_logger has already configured.
logger_initialized = {}


@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Calls are cached, so repeating a call with the same arguments returns the
    same logger without running the setup again.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # A name that starts with an already initialized logger name is returned
    # as is, without adding handlers of its own.
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    is_rank_zero = dist.get_rank() == 0
    if log_file is not None and is_rank_zero:
        log_file_folder = os.path.split(log_file)[0]
        # A bare file name has no directory part, and os.makedirs('') raises.
        if log_file_folder:
            os.makedirs(log_file_folder, exist_ok=True)
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if is_rank_zero:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    return logger
StyleTextRec/utils/math_functions.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
def compute_mean_covariance(img):
    """Return the per-channel mean and channel covariance of ``img``.

    Args:
        img: tensor indexed as (batch, channel, height, width).

    Returns:
        tuple: ``(mu, covariance)``; ``mu`` is the mean over the two spatial
        axes with those axes kept, ``covariance`` is the batched product of
        the centred image with its transpose, divided by the pixel count.
    """
    dims = img.shape
    batch_size, channel_num = dims[0], dims[1]
    num_pixels = dims[2] * dims[3]

    # batch_size * channel_num * 1 * 1
    mu = img.mean(2, keepdim=True).mean(3, keepdim=True)

    # Centre the image, then flatten the spatial axes:
    # batch_size * channel_num * num_pixels
    centred = (img - mu.expand_as(img)).reshape(
        [batch_size, channel_num, num_pixels])

    # (b, c, n) x (b, n, c) -> batch_size * channel_num * channel_num
    covariance = paddle.bmm(centred, centred.transpose([0, 2, 1]))
    return mu, covariance / num_pixels
def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
    """Return ``1 - 2 * intersection / union`` over the masked elements.

    Args:
        y_true_cls: ground-truth values.
        y_pred_cls: predicted values.
        training_mask: weights multiplied into both inputs before summing.

    Returns:
        The loss value produced by the sums above.
    """
    overlap = paddle.sum(y_true_cls * y_pred_cls * training_mask)
    true_total = paddle.sum(y_true_cls * training_mask)
    pred_total = paddle.sum(y_pred_cls * training_mask)
    # The small constant keeps the denominator away from zero.
    denominator = true_total + pred_total + 1e-5
    return 1. - (2 * overlap / denominator)
StyleTextRec/utils/sys_funcs.py
0 → 100644
View file @
fe46b77e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
import
errno
import
paddle
def get_check_global_params(mode):
    """Return the list of global parameter names to check for ``mode``.

    Args:
        mode (str): "train_eval" and "test" add batch-size parameters; any
            other value returns only the common names.

    Returns:
        list: a new list of parameter names.
    """
    params = [
        'use_gpu',
        'max_text_length',
        'image_shape',
        'image_shape',
        'character_type',
        'loss_type',
    ]
    if mode == "train_eval":
        params.extend(
            ['train_batch_size_per_card', 'test_batch_size_per_card'])
    elif mode == "test":
        params.append('test_batch_size_per_card')
    return params
def check_gpu(use_gpu):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    Args:
        use_gpu (bool): the configured use_gpu flag; nothing is checked when
            it is false.

    Exits the process with status 1 when paddle is not compiled with CUDA or
    when the check itself fails.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"

    if use_gpu:
        try:
            if not paddle.is_compiled_with_cuda():
                print(err)
                sys.exit(1)
        # Catch Exception rather than everything: a bare except would also
        # catch the SystemExit raised just above and print a second,
        # misleading message.
        except Exception:
            print("Fail to check gpu state.")
            sys.exit(1)
def
_mkdir_if_not_exist
(
path
,
logger
):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if
not
os
.
path
.
exists
(
path
):
try
:
os
.
makedirs
(
path
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
and
os
.
path
.
isdir
(
path
):
logger
.
warning
(
'be happy if some process has already created {}'
.
format
(
path
))
else
:
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment