Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
d8719969
Unverified
Commit
d8719969
authored
Feb 22, 2021
by
littletomatodonkey
Committed by
GitHub
Feb 22, 2021
Browse files
improve style text infer process (#2055)
* improve style text * fix dead loop
parent
6a42745f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
81 additions
and
25 deletions
+81
-25
StyleText/engine/predictors.py
StyleText/engine/predictors.py
+24
-1
StyleText/engine/synthesisers.py
StyleText/engine/synthesisers.py
+13
-7
StyleText/engine/text_drawers.py
StyleText/engine/text_drawers.py
+44
-17
No files found.
StyleText/engine/predictors.py
View file @
d8719969
...
@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
...
@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
self
.
std
=
config
[
"Predictor"
][
"std"
]
self
.
std
=
config
[
"Predictor"
][
"std"
]
self
.
expand_result
=
config
[
"Predictor"
][
"expand_result"
]
self
.
expand_result
=
config
[
"Predictor"
][
"expand_result"
]
def
predict
(
self
,
style_input
,
text_input
):
def
reshape_to_same_height
(
self
,
img_list
):
h
=
img_list
[
0
].
shape
[
0
]
for
idx
in
range
(
1
,
len
(
img_list
)):
new_w
=
round
(
1.0
*
img_list
[
idx
].
shape
[
1
]
/
img_list
[
idx
].
shape
[
0
]
*
h
)
img_list
[
idx
]
=
cv2
.
resize
(
img_list
[
idx
],
(
new_w
,
h
))
return
img_list
def
predict_single_image
(
self
,
style_input
,
text_input
):
style_input
=
self
.
rep_style_input
(
style_input
,
text_input
)
style_input
=
self
.
rep_style_input
(
style_input
,
text_input
)
tensor_style_input
=
self
.
preprocess
(
style_input
)
tensor_style_input
=
self
.
preprocess
(
style_input
)
tensor_text_input
=
self
.
preprocess
(
text_input
)
tensor_text_input
=
self
.
preprocess
(
text_input
)
...
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
...
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
"fake_bg"
:
fake_bg
,
"fake_bg"
:
fake_bg
,
}
}
def
predict
(
self
,
style_input
,
text_input_list
):
if
not
isinstance
(
text_input_list
,
(
tuple
,
list
)):
return
self
.
predict_single_image
(
style_input
,
text_input_list
)
synth_result_list
=
[]
for
text_input
in
text_input_list
:
synth_result
=
self
.
predict_single_image
(
style_input
,
text_input
)
synth_result_list
.
append
(
synth_result
)
for
key
in
synth_result
:
res
=
[
r
[
key
]
for
r
in
synth_result_list
]
res
=
self
.
reshape_to_same_height
(
res
)
synth_result
[
key
]
=
np
.
concatenate
(
res
,
axis
=
1
)
return
synth_result
def
preprocess
(
self
,
img
):
def
preprocess
(
self
,
img
):
img
=
(
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
img
=
(
img
.
astype
(
'float32'
)
*
self
.
scale
-
self
.
mean
)
/
self
.
std
img_height
,
img_width
,
channel
=
img
.
shape
img_height
,
img_width
,
channel
=
img
.
shape
...
...
StyleText/engine/synthesisers.py
View file @
d8719969
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
import
os
import
numpy
as
np
import
cv2
from
utils.config
import
ArgsParser
,
load_config
,
override_config
from
utils.config
import
ArgsParser
,
load_config
,
override_config
from
utils.logging
import
get_logger
from
utils.logging
import
get_logger
...
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
...
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
self
.
predictor
=
getattr
(
predictors
,
predictor_method
)(
self
.
config
)
self
.
predictor
=
getattr
(
predictors
,
predictor_method
)(
self
.
config
)
def
synth_image
(
self
,
corpus
,
style_input
,
language
=
"en"
):
def
synth_image
(
self
,
corpus
,
style_input
,
language
=
"en"
):
corpus
,
text_input
=
self
.
text_drawer
.
draw_text
(
corpus
,
language
)
corpus_list
,
text_input_list
=
self
.
text_drawer
.
draw_text
(
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input
)
corpus
,
language
,
style_input_width
=
style_input
.
shape
[
1
])
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input_list
)
return
synth_result
return
synth_result
...
@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
...
@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
for
i
in
range
(
self
.
output_num
):
for
i
in
range
(
self
.
output_num
):
style_data
=
self
.
style_sampler
.
sample
()
style_data
=
self
.
style_sampler
.
sample
()
style_input
=
style_data
[
"image"
]
style_input
=
style_data
[
"image"
]
corpus_language
,
text_input_label
=
self
.
corpus_generator
.
generate
(
corpus_language
,
text_input_label
=
self
.
corpus_generator
.
generate
()
)
text_input_label_list
,
text_input_list
=
self
.
text_drawer
.
draw_text
(
text_input_label
,
text_input
=
self
.
text_drawer
.
draw_text
(
text_input_label
,
text_input_label
,
corpus_language
)
corpus_language
,
style_input_width
=
style_input
.
shape
[
1
])
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input
)
text_input_label
=
""
.
join
(
text_input_label_list
)
synth_result
=
self
.
predictor
.
predict
(
style_input
,
text_input_list
)
fake_fusion
=
synth_result
[
"fake_fusion"
]
fake_fusion
=
synth_result
[
"fake_fusion"
]
self
.
writer
.
save_image
(
fake_fusion
,
text_input_label
)
self
.
writer
.
save_image
(
fake_fusion
,
text_input_label
)
self
.
writer
.
save_label
()
self
.
writer
.
save_label
()
...
...
StyleText/engine/text_drawers.py
View file @
d8719969
from
PIL
import
Image
,
ImageDraw
,
ImageFont
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
numpy
as
np
import
numpy
as
np
import
cv2
from
utils.logging
import
get_logger
from
utils.logging
import
get_logger
...
@@ -28,7 +29,11 @@ class StdTextDrawer(object):
...
@@ -28,7 +29,11 @@ class StdTextDrawer(object):
else
:
else
:
return
int
((
self
.
height
-
4
)
**
2
/
font_height
)
return
int
((
self
.
height
-
4
)
**
2
/
font_height
)
def
draw_text
(
self
,
corpus
,
language
=
"en"
,
crop
=
True
):
def
draw_text
(
self
,
corpus
,
language
=
"en"
,
crop
=
True
,
style_input_width
=
None
):
if
language
not
in
self
.
support_languages
:
if
language
not
in
self
.
support_languages
:
self
.
logger
.
warning
(
self
.
logger
.
warning
(
"language {} not supported, use en instead."
.
format
(
language
))
"language {} not supported, use en instead."
.
format
(
language
))
...
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
...
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
width
=
min
(
self
.
max_width
,
len
(
corpus
)
*
self
.
height
)
+
4
width
=
min
(
self
.
max_width
,
len
(
corpus
)
*
self
.
height
)
+
4
else
:
else
:
width
=
len
(
corpus
)
*
self
.
height
+
4
width
=
len
(
corpus
)
*
self
.
height
+
4
bg
=
Image
.
new
(
"RGB"
,
(
width
,
self
.
height
),
color
=
(
127
,
127
,
127
))
draw
=
ImageDraw
.
Draw
(
bg
)
if
style_input_width
is
not
None
:
width
=
min
(
width
,
style_input_width
)
char_x
=
2
font
=
self
.
font_dict
[
language
]
corpus_list
=
[]
for
i
,
char_i
in
enumerate
(
corpus
):
text_input_list
=
[]
char_size
=
font
.
getsize
(
char_i
)[
0
]
draw
.
text
((
char_x
,
2
),
char_i
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
while
len
(
corpus
)
!=
0
:
char_x
+=
char_size
bg
=
Image
.
new
(
"RGB"
,
(
width
,
self
.
height
),
color
=
(
127
,
127
,
127
))
if
char_x
>=
width
:
draw
=
ImageDraw
.
Draw
(
bg
)
corpus
=
corpus
[
0
:
i
+
1
]
char_x
=
2
self
.
logger
.
warning
(
"corpus length exceed limit: {}"
.
format
(
font
=
self
.
font_dict
[
language
]
corpus
))
i
=
0
while
i
<
len
(
corpus
):
char_i
=
corpus
[
i
]
char_size
=
font
.
getsize
(
char_i
)[
0
]
# split when char_x exceeds char size and index is not 0 (at least 1 char should be wroten on the image)
if
char_x
+
char_size
>=
width
and
i
!=
0
:
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
text_input
[:,
0
:
char_x
,
:]
corpus_list
.
append
(
corpus
[
0
:
i
])
text_input_list
.
append
(
text_input
)
corpus
=
corpus
[
i
:]
break
draw
.
text
((
char_x
,
2
),
char_i
,
fill
=
(
0
,
0
,
0
),
font
=
font
)
char_x
+=
char_size
i
+=
1
# the whole text is shorter than style input
if
i
==
len
(
corpus
):
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
text_input
=
text_input
[:,
0
:
char_x
,
:]
corpus_list
.
append
(
corpus
[
0
:
i
])
text_input_list
.
append
(
text_input
)
corpus
=
corpus
[
i
:]
break
break
text_input
=
np
.
array
(
bg
).
astype
(
np
.
uint8
)
return
corpus_list
,
text_input_list
text_input
=
text_input
[:,
0
:
char_x
,
:]
return
corpus
,
text_input
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment