OpenDAS / Fairseq · Commit 0ef2856c

Don't compute unnecessary attention averages during training

Authored Jul 12, 2018 by Myle Ott
Parent: c37fc8fd
Showing 4 changed files with 7 additions and 8 deletions:

  fairseq/models/fconv.py           +1 -1
  fairseq/models/fconv_self_att.py  +1 -1
  fairseq/models/lstm.py            +4 -1
  fairseq/models/transformer.py     +1 -5
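Across all four models the fix is the same pattern: gate the attention-average bookkeeping on nn.Module's training flag, so the averages are only accumulated at inference time when a caller has asked for them. A minimal runnable sketch of the pattern (TinyDecoder is hypothetical, not part of fairseq):

    import torch
    import torch.nn as nn

    class TinyDecoder(nn.Module):
        """Hypothetical decoder illustrating the guard; not fairseq code."""
        def __init__(self, num_attn_layers=3):
            super().__init__()
            self.num_attn_layers = num_attn_layers
            self.need_attn = True  # in fairseq this is toggled for generation

        def forward(self, per_layer_attn):
            avg_attn_scores = None
            for attn_scores in per_layer_attn:
                # nn.Module.training is flipped by .train() / .eval()
                if not self.training and self.need_attn:
                    attn_scores = attn_scores / self.num_attn_layers
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        avg_attn_scores.add_(attn_scores)
            return avg_attn_scores

    decoder = TinyDecoder()
    attn = [torch.rand(2, 5, 7) for _ in range(3)]
    decoder.train()
    assert decoder(attn) is None       # training: averaging skipped entirely
    decoder.eval()
    assert decoder(attn) is not None   # inference: averages still available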
fairseq/models/fconv.py
@@ -468,7 +468,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                if self.need_attn:
+                if not self.training and self.need_attn:
                     attn_scores = attn_scores / num_attn_layers
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
fairseq/models/fconv_self_att.py
@@ -389,7 +389,7 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if self.need_attn:
+                if not self.training and self.need_attn:
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
                     else:
fairseq/models/lstm.py
@@ -396,7 +396,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)

         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2) if self.need_attn else None
+        if not self.training and self.need_attn:
+            attn_scores = attn_scores.transpose(0, 2)
+        else:
+            attn_scores = None

         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
fairseq/models/transformer.py
@@ -193,7 +193,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
-        self.need_attn = True

         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -267,9 +266,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
         return state_dict

-    def make_generation_fast_(self, need_attn=False, **kwargs):
-        self.need_attn = need_attn
-

 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
@@ -369,7 +365,7 @@ class TransformerDecoderLayer(nn.Module):
                 key_padding_mask=encoder_padding_mask,
                 incremental_state=incremental_state,
                 static_kv=True,
-                need_weights=self.need_attn,
+                need_weights=(not self.training and self.need_attn),
             )
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = residual + x
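In the transformer, the same guard moves into the need_weights argument of the encoder-decoder attention call, so the attention module can skip materializing averaged attention weights during training. A hedged sketch of that call pattern, using torch.nn.MultiheadAttention as a stand-in for fairseq's own MultiheadAttention (an assumption; fairseq's class exposes a compatible need_weights flag):

    import torch
    import torch.nn as nn

    # Stand-in for fairseq's MultiheadAttention (an assumption); PyTorch's
    # built-in module also accepts need_weights in forward().
    mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
    q = k = v = torch.rand(5, 3, 8)  # (seq_len, batch, embed_dim)
    need_attn = True

    mha.train()
    _, weights = mha(q, k, v, need_weights=(not mha.training and need_attn))
    print(weights)        # None: no attention weights computed in training

    mha.eval()
    _, weights = mha(q, k, v, need_weights=(not mha.training and need_attn))
    print(weights.shape)  # torch.Size([3, 5, 5]): weights available in eval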