Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
ca9ea622
"tests/git@developer.sourcefind.cn:chenpangpang/diffusers.git" did not exist on "35db2fdea91dad8842d6d083d68e396c81b9e771"
Commit
ca9ea622
authored
Oct 20, 2020
by
WenmuZhou
Browse files
添加im2seq实现
parent
bdad0cef
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
56 deletions
+4
-56
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+4
-56
No files found.
ppocr/modeling/necks/rnn.py
View file @
ca9ea622
...
@@ -21,18 +21,6 @@ from paddle import nn
...
@@ -21,18 +21,6 @@ from paddle import nn
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
from
ppocr.modeling.heads.rec_ctc_head
import
get_para_bias_attr
class EncoderWithReshape(nn.Layer):
    """Flatten a (B, C, H, W) feature map into an (N, T, C) sequence.

    Both spatial axes are merged into a single time axis and the channel
    axis is moved last, producing (batch, H*W, channels) for downstream
    sequence encoders. No learnable parameters are involved.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        # Pure reshaping: the channel count passes through unchanged.
        self.out_channels = in_channels

    def forward(self, x):
        batch, channels, _, _ = x.shape
        # Merge H and W into one sequence axis: (B, C, H*W).
        seq = x.reshape((batch, channels, -1))
        # (NTC)(batch, width, channels)
        return seq.transpose([0, 2, 1])
class
Im2Seq
(
nn
.
Layer
):
class
Im2Seq
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
...
@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
...
@@ -40,9 +28,8 @@ class Im2Seq(nn.Layer):
def forward(self, x):
    """Collapse a single-row feature map into an (N, T, C) sequence.

    Args:
        x: 4-D tensor of shape (B, C, 1, W); the height axis must be
           exactly 1 (the backbone is expected to have pooled it away).

    Returns:
        Tensor of shape (B, W, C) — (batch, time, channels), the NTC
        layout consumed by the sequence encoders/heads.

    Raises:
        ValueError: if the input height is not 1.
    """
    B, C, H, W = x.shape
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let an H != 1 input silently produce a
    # wrongly laid-out sequence from the reshape below.
    if H != 1:
        raise ValueError(
            "Im2Seq expects input height 1, got shape {}".format(x.shape))
    x = x.reshape((B, -1, W))   # (B, C*1, W)
    x = x.transpose((0, 2, 1))  # (NTC)(batch, width, channels)
    return x
...
@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
...
@@ -50,49 +37,10 @@ class EncoderWithRNN(nn.Layer):
def __init__(self, in_channels, hidden_size):
    """Two-layer bidirectional LSTM sequence encoder.

    Args:
        in_channels: feature size of each time step fed into the LSTM.
        hidden_size: hidden units per direction; since the bidirectional
            output concatenates both directions, the encoder emits
            2 * hidden_size features per step.
    """
    super(EncoderWithRNN, self).__init__()
    # Forward and backward hidden states are concatenated feature-wise.
    self.out_channels = hidden_size * 2
    # Single fused bi-LSTM stack replacing four hand-wired LSTMCells.
    self.lstm = nn.LSTM(
        in_channels,
        hidden_size,
        direction='bidirectional',
        num_layers=2)
def forward(self, x):
    """Run the stacked bi-LSTM over the input sequence.

    Args:
        x: tensor of shape (batch, time, in_channels).

    Returns:
        Tensor of shape (batch, time, 2 * hidden_size); the LSTM's final
        hidden/cell states are discarded.
    """
    output, _ = self.lstm(x)
    return output
...
@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
...
@@ -118,13 +66,13 @@ class EncoderWithFC(nn.Layer):
class
SequenceEncoder
(
nn
.
Layer
):
class
SequenceEncoder
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
encoder_type
,
hidden_size
=
48
,
**
kwargs
):
super
(
SequenceEncoder
,
self
).
__init__
()
super
(
SequenceEncoder
,
self
).
__init__
()
self
.
encoder_reshape
=
EncoderWithReshape
(
in_channels
)
self
.
encoder_reshape
=
Im2Seq
(
in_channels
)
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
self
.
out_channels
=
self
.
encoder_reshape
.
out_channels
if
encoder_type
==
'reshape'
:
if
encoder_type
==
'reshape'
:
self
.
only_reshape
=
True
self
.
only_reshape
=
True
else
:
else
:
support_encoder_dict
=
{
support_encoder_dict
=
{
'reshape'
:
EncoderWithReshape
,
'reshape'
:
Im2Seq
,
'fc'
:
EncoderWithFC
,
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
'rnn'
:
EncoderWithRNN
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment