Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
f1048e29
Unverified
Commit
f1048e29
authored
Oct 20, 2020
by
dyning
Committed by
GitHub
Oct 20, 2020
Browse files
Merge pull request #970 from WenmuZhou/dygraph
解决crnn训练时对labels进行合并的bug
parents
52b40f36
a88ce7a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
57 deletions
+6
-57
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+4
-56
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+2
-1
No files found.
ppocr/modeling/necks/rnn.py
View file @
f1048e29
...
@@ -21,18 +21,6 @@ from paddle import nn
...
@@ -21,18 +21,6 @@ from paddle import nn
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
class
EncoderWithReshape
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
super
().
__init__
()
self
.
out_channels
=
in_channels
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
x
=
x
.
reshape
((
B
,
C
,
-
1
))
x
=
x
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
return
x
class
Im2Seq
(
nn
.
Layer
):
class
Im2Seq
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
...
@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
...
@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
assert
H
==
1
x
=
x
.
reshape
((
B
,
-
1
,
W
))
x
=
x
.
transpose
((
0
,
2
,
3
,
1
))
x
=
x
.
transpose
((
0
,
2
,
1
))
# (NTC)(batch, width, channels)
x
=
x
.
reshape
((
-
1
,
C
))
return
x
return
x
...
@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
...
@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
def
__init__
(
self
,
in_channels
,
hidden_size
):
def
__init__
(
self
,
in_channels
,
hidden_size
):
super
(
EncoderWithRNN
,
self
).
__init__
()
super
(
EncoderWithRNN
,
self
).
__init__
()
self
.
out_channels
=
hidden_size
*
2
self
.
out_channels
=
hidden_size
*
2
# self.lstm1_fw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out1_b'),
# )
# self.lstm1_bw = nn.LSTMCell(
# in_channels,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st1_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st1_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st1_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st1_out2_b'),
# )
# self.lstm2_fw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc1_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc1_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out1_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out1_b'),
# )
# self.lstm2_bw = nn.LSTMCell(
# hidden_size,
# hidden_size,
# weight_ih_attr=ParamAttr(name='lstm_st2_fc2_w'),
# bias_ih_attr=ParamAttr(name='lstm_st2_fc2_b'),
# weight_hh_attr=ParamAttr(name='lstm_st2_out2_w'),
# bias_hh_attr=ParamAttr(name='lstm_st2_out2_b'),
# )
self
.
lstm
=
nn
.
LSTM
(
self
.
lstm
=
nn
.
LSTM
(
in_channels
,
hidden_size
,
direction
=
'bidirectional'
,
num_layers
=
2
)
in_channels
,
hidden_size
,
direction
=
'bidirectional'
,
num_layers
=
2
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
# fw_x, _ = self.lstm1_fw(x)
# fw_x, _ = self.lstm2_fw(fw_x)
#
# # bw
# bw_x, _ = self.lstm1_bw(x)
# bw_x, _ = self.lstm2_bw(bw_x)
# x = paddle.concat([fw_x, bw_x], axis=2)
x
,
_
=
self
.
lstm
(
x
)
x
,
_
=
self
.
lstm
(
x
)
return
x
return
x
...
@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
...
@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
class
SequenceEncoder
(
nn
.
Layer
):
class
SequenceEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
super
(
SequenceEncoder
,
self
).
__init__
()
super
(
SequenceEncoder
,
self
).
__init__
()
self
.
encoder_reshape
=
EncoderWithReshape
(
in_channels
)
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
if
encoder_type
==
'reshape'
:
if
encoder_type
==
'reshape'
:
self
.
only_reshape
=
True
self
.
only_reshape
=
True
else
:
else
:
support_encoder_dict
=
{
support_encoder_dict
=
{
'reshape'
:
EncoderWithReshape
,
'reshape'
:
Im2Seq
,
'fc'
:
EncoderWithFC
,
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
'rnn'
:
EncoderWithRNN
}
}
...
...
ppocr/postprocess/rec_postprocess.py
View file @
f1048e29
...
@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
...
@@ -70,6 +70,7 @@ class BaseRecLabelDecode(object):
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
if
text_index
[
batch_idx
][
idx
]
in
ignored_tokens
:
continue
continue
if
is_remove_duplicate
:
if
is_remove_duplicate
:
# only for predict
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
if
idx
>
0
and
text_index
[
batch_idx
][
idx
-
1
]
==
text_index
[
batch_idx
][
idx
]:
batch_idx
][
idx
]:
continue
continue
...
@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
...
@@ -107,7 +108,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
)
if
label
is
None
:
if
label
is
None
:
return
text
return
text
label
=
self
.
decode
(
label
)
label
=
self
.
decode
(
label
,
is_remove_duplicate
=
False
)
return
text
,
label
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
def
add_special_char
(
self
,
dict_character
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment