Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
f4b62551
Commit
f4b62551
authored
May 18, 2022
by
littletomatodonkey
Browse files
add support for svtr static training (#6328)
parent
1bb03b4d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
14 deletions
+32
-14
ppocr/modeling/architectures/__init__.py
ppocr/modeling/architectures/__init__.py
+22
-4
ppocr/modeling/heads/rec_sar_head.py
ppocr/modeling/heads/rec_sar_head.py
+10
-10
No files found.
ppocr/modeling/architectures/__init__.py
View file @
f4b62551
...
@@ -40,11 +40,29 @@ def apply_to_static(model, config, logger):
...
@@ -40,11 +40,29 @@ def apply_to_static(model, config, logger):
return
model
return
model
assert
"image_shape"
in
config
[
assert
"image_shape"
in
config
[
"Global"
],
"image_shape must be assigned for static training mode..."
"Global"
],
"image_shape must be assigned for static training mode..."
supported_list
=
[
"DB"
]
supported_list
=
[
"DB"
,
"SVTR"
]
assert
config
[
"Architecture"
][
if
config
[
"Architecture"
][
"algorithm"
]
in
[
"Distillation"
]:
"algorithm"
]
in
supported_list
,
f
"algorithms that supports static training must in in
{
supported_list
}
but got
{
config
[
'Architecture'
][
'algorithm'
]
}
"
algo
=
list
(
config
[
"Architecture"
][
"Models"
].
values
())[
0
][
"algorithm"
]
else
:
algo
=
config
[
"Architecture"
][
"algorithm"
]
assert
algo
in
supported_list
,
f
"algorithms that supports static training must in in
{
supported_list
}
but got
{
algo
}
"
specs
=
[
InputSpec
(
[
None
]
+
config
[
"Global"
][
"image_shape"
],
dtype
=
'float32'
)
]
if
algo
==
"SVTR"
:
specs
.
append
([
InputSpec
(
[
None
,
config
[
"Global"
][
"max_text_length"
]],
dtype
=
'int64'
),
InputSpec
(
[
None
,
config
[
"Global"
][
"max_text_length"
]],
dtype
=
'int64'
),
InputSpec
(
[
None
],
dtype
=
'int64'
),
InputSpec
(
[
None
],
dtype
=
'float64'
)
])
specs
=
[
InputSpec
([
None
]
+
config
[
"Global"
][
"image_shape"
])]
model
=
to_static
(
model
,
input_spec
=
specs
)
model
=
to_static
(
model
,
input_spec
=
specs
)
logger
.
info
(
"Successfully to apply @to_static with specs: {}"
.
format
(
specs
))
logger
.
info
(
"Successfully to apply @to_static with specs: {}"
.
format
(
specs
))
return
model
return
model
ppocr/modeling/heads/rec_sar_head.py
View file @
f4b62551
...
@@ -83,7 +83,7 @@ class SAREncoder(nn.Layer):
...
@@ -83,7 +83,7 @@ class SAREncoder(nn.Layer):
def
forward
(
self
,
feat
,
img_metas
=
None
):
def
forward
(
self
,
feat
,
img_metas
=
None
):
if
img_metas
is
not
None
:
if
img_metas
is
not
None
:
assert
len
(
img_metas
[
0
])
==
feat
.
shape
[
0
]
assert
len
(
img_metas
[
0
])
==
paddle
.
shape
(
feat
)
[
0
]
valid_ratios
=
None
valid_ratios
=
None
if
img_metas
is
not
None
and
self
.
mask
:
if
img_metas
is
not
None
and
self
.
mask
:
...
@@ -98,9 +98,10 @@ class SAREncoder(nn.Layer):
...
@@ -98,9 +98,10 @@ class SAREncoder(nn.Layer):
if
valid_ratios
is
not
None
:
if
valid_ratios
is
not
None
:
valid_hf
=
[]
valid_hf
=
[]
T
=
holistic_feat
.
shape
[
1
]
T
=
paddle
.
shape
(
holistic_feat
)[
1
]
for
i
in
range
(
len
(
valid_ratios
)):
for
i
in
range
(
paddle
.
shape
(
valid_ratios
)[
0
]):
valid_step
=
min
(
T
,
math
.
ceil
(
T
*
valid_ratios
[
i
]))
-
1
valid_step
=
paddle
.
minimum
(
T
,
paddle
.
ceil
(
valid_ratios
[
i
]
*
T
).
astype
(
'int32'
))
-
1
valid_hf
.
append
(
holistic_feat
[
i
,
valid_step
,
:])
valid_hf
.
append
(
holistic_feat
[
i
,
valid_step
,
:])
valid_hf
=
paddle
.
stack
(
valid_hf
,
axis
=
0
)
valid_hf
=
paddle
.
stack
(
valid_hf
,
axis
=
0
)
else
:
else
:
...
@@ -247,13 +248,14 @@ class ParallelSARDecoder(BaseDecoder):
...
@@ -247,13 +248,14 @@ class ParallelSARDecoder(BaseDecoder):
# bsz * (seq_len + 1) * h * w * attn_size
# bsz * (seq_len + 1) * h * w * attn_size
attn_weight
=
self
.
conv1x1_2
(
attn_weight
)
attn_weight
=
self
.
conv1x1_2
(
attn_weight
)
# bsz * (seq_len + 1) * h * w * 1
# bsz * (seq_len + 1) * h * w * 1
bsz
,
T
,
h
,
w
,
c
=
attn_weight
.
shape
bsz
,
T
,
h
,
w
,
c
=
paddle
.
shape
(
attn_weight
)
assert
c
==
1
assert
c
==
1
if
valid_ratios
is
not
None
:
if
valid_ratios
is
not
None
:
# cal mask of attention weight
# cal mask of attention weight
for
i
in
range
(
len
(
valid_ratios
)):
for
i
in
range
(
paddle
.
shape
(
valid_ratios
)[
0
]):
valid_width
=
min
(
w
,
math
.
ceil
(
w
*
valid_ratios
[
i
]))
valid_width
=
paddle
.
minimum
(
w
,
paddle
.
ceil
(
valid_ratios
[
i
]
*
w
).
astype
(
"int32"
))
if
valid_width
<
w
:
if
valid_width
<
w
:
attn_weight
[
i
,
:,
:,
valid_width
:,
:]
=
float
(
'-inf'
)
attn_weight
[
i
,
:,
:,
valid_width
:,
:]
=
float
(
'-inf'
)
...
@@ -288,7 +290,7 @@ class ParallelSARDecoder(BaseDecoder):
...
@@ -288,7 +290,7 @@ class ParallelSARDecoder(BaseDecoder):
img_metas: [label, valid_ratio]
img_metas: [label, valid_ratio]
'''
'''
if
img_metas
is
not
None
:
if
img_metas
is
not
None
:
assert
len
(
img_metas
[
0
])
==
feat
.
shape
[
0
]
assert
paddle
.
shape
(
img_metas
[
0
])
[
0
]
==
paddle
.
shape
(
feat
)
[
0
]
valid_ratios
=
None
valid_ratios
=
None
if
img_metas
is
not
None
and
self
.
mask
:
if
img_metas
is
not
None
and
self
.
mask
:
...
@@ -302,7 +304,6 @@ class ParallelSARDecoder(BaseDecoder):
...
@@ -302,7 +304,6 @@ class ParallelSARDecoder(BaseDecoder):
# bsz * (seq_len + 1) * C
# bsz * (seq_len + 1) * C
out_dec
=
self
.
_2d_attention
(
out_dec
=
self
.
_2d_attention
(
in_dec
,
feat
,
out_enc
,
valid_ratios
=
valid_ratios
)
in_dec
,
feat
,
out_enc
,
valid_ratios
=
valid_ratios
)
# bsz * (seq_len + 1) * num_classes
return
out_dec
[:,
1
:,
:]
# bsz * seq_len * num_classes
return
out_dec
[:,
1
:,
:]
# bsz * seq_len * num_classes
...
@@ -395,7 +396,6 @@ class SARHead(nn.Layer):
...
@@ -395,7 +396,6 @@ class SARHead(nn.Layer):
if
self
.
training
:
if
self
.
training
:
label
=
targets
[
0
]
# label
label
=
targets
[
0
]
# label
label
=
paddle
.
to_tensor
(
label
,
dtype
=
'int64'
)
final_out
=
self
.
decoder
(
final_out
=
self
.
decoder
(
feat
,
holistic_feat
,
label
,
img_metas
=
targets
)
feat
,
holistic_feat
,
label
,
img_metas
=
targets
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment