Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
df4a2f6a
Commit
df4a2f6a
authored
Sep 07, 2021
by
andyjpaddle
Browse files
update rec_sar_head
parent
073fad37
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
10 deletions
+10
-10
ppocr/losses/rec_sar_loss.py
ppocr/losses/rec_sar_loss.py
+1
-1
ppocr/modeling/heads/rec_sar_head.py
ppocr/modeling/heads/rec_sar_head.py
+9
-9
No files found.
ppocr/losses/rec_sar_loss.py
View file @
df4a2f6a
...
...
@@ -9,7 +9,7 @@ from paddle import nn
class
SARLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
(
SARLoss
,
self
).
__init__
()
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"mean"
,
ignore_index
=
9
2
)
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"mean"
,
ignore_index
=
9
6
)
def
forward
(
self
,
predicts
,
batch
):
predict
=
predicts
[:,
:
-
1
,
:]
# ignore last index of outputs to be in same seq_len with targets
...
...
ppocr/modeling/heads/rec_sar_head.py
View file @
df4a2f6a
...
...
@@ -118,8 +118,7 @@ class BaseDecoder(nn.Layer):
class
ParallelSARDecoder
(
BaseDecoder
):
"""
Args:
num_classes (int): Output class number.
channels (list[int]): Network layer channels.
out_channels (int): Output class number.
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
dec_drop_rnn (float): Dropout of RNN layer in decoder.
...
...
@@ -137,7 +136,7 @@ class ParallelSARDecoder(BaseDecoder):
"""
def
__init__
(
self
,
num_classes
=
93
,
# 90 + unknown + start + padding
out_channels
,
# 90 + unknown + start + padding
enc_bi_rnn
=
False
,
dec_bi_rnn
=
False
,
dec_drop_rnn
=
0.0
,
...
...
@@ -148,8 +147,6 @@ class ParallelSARDecoder(BaseDecoder):
pred_dropout
=
0.1
,
max_text_length
=
30
,
mask
=
True
,
start_idx
=
91
,
padding_idx
=
92
,
# 92
pred_concat
=
True
,
**
kwargs
):
super
().
__init__
()
...
...
@@ -157,7 +154,8 @@ class ParallelSARDecoder(BaseDecoder):
self
.
num_classes
=
num_classes
self
.
enc_bi_rnn
=
enc_bi_rnn
self
.
d_k
=
d_k
self
.
start_idx
=
start_idx
self
.
start_idx
=
out_channels
-
2
self
.
padding_idx
=
out_channels
-
1
self
.
max_seq_len
=
max_text_length
self
.
mask
=
mask
self
.
pred_concat
=
pred_concat
...
...
@@ -191,7 +189,7 @@ class ParallelSARDecoder(BaseDecoder):
# Decoder input embedding
self
.
embedding
=
nn
.
Embedding
(
self
.
num_classes
,
encoder_rnn_out_size
,
padding_idx
=
padding_idx
)
self
.
num_classes
,
encoder_rnn_out_size
,
padding_idx
=
self
.
padding_idx
)
# Prediction layer
self
.
pred_dropout
=
nn
.
Dropout
(
pred_dropout
)
...
...
@@ -330,6 +328,7 @@ class ParallelSARDecoder(BaseDecoder):
class
SARHead
(
nn
.
Layer
):
def
__init__
(
self
,
out_channels
,
enc_bi_rnn
=
False
,
enc_drop_rnn
=
0.1
,
enc_gru
=
False
,
...
...
@@ -351,7 +350,8 @@ class SARHead(nn.Layer):
# decoder module
self
.
decoder
=
ParallelSARDecoder
(
enc_bi_rnn
=
enc_bi_rnn
,
out_channels
=
out_channels
,
enc_bi_rnn
=
enc_bi_rnn
,
dec_bi_rnn
=
dec_bi_rnn
,
dec_drop_rnn
=
dec_drop_rnn
,
dec_gru
=
dec_gru
,
...
...
@@ -375,4 +375,4 @@ class SARHead(nn.Layer):
# (bsz, seq_len, num_classes)
return
final_out
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment