Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
fc830685
Commit
fc830685
authored
Apr 05, 2018
by
alexeib
Committed by
Myle Ott
Jun 15, 2018
Browse files
smarter way to avoid applying encoder key mask
parent
b2374e52
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
8 deletions
+9
-8
fairseq/models/transformer.py
fairseq/models/transformer.py
+2
-0
fairseq/modules/multihead_attention.py
fairseq/modules/multihead_attention.py
+7
-8
No files found.
fairseq/models/transformer.py
View file @
fc830685
...
@@ -137,6 +137,8 @@ class TransformerEncoder(FairseqEncoder):
...
@@ -137,6 +137,8 @@ class TransformerEncoder(FairseqEncoder):
# compute padding mask
# compute padding mask
encoder_padding_mask
=
src_tokens
.
eq
(
self
.
padding_idx
)
encoder_padding_mask
=
src_tokens
.
eq
(
self
.
padding_idx
)
if
not
encoder_padding_mask
.
any
():
encoder_padding_mask
=
None
# encoder layers
# encoder layers
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
...
...
fairseq/modules/multihead_attention.py
View file @
fc830685
...
@@ -122,9 +122,8 @@ class MultiheadAttention(nn.Module):
...
@@ -122,9 +122,8 @@ class MultiheadAttention(nn.Module):
assert
query
.
size
()
==
key
.
size
(),
\
assert
query
.
size
()
==
key
.
size
(),
\
'mask_future_timesteps only applies to self-attention'
'mask_future_timesteps only applies to self-attention'
attn_weights
+=
self
.
buffered_mask
(
attn_weights
).
unsqueeze
(
0
)
attn_weights
+=
self
.
buffered_mask
(
attn_weights
).
unsqueeze
(
0
)
if
key_padding_mask
is
not
None
and
incremental_state
is
None
:
if
key_padding_mask
is
not
None
:
# don't attend to padding symbols
# don't attend to padding symbols
if
utils
.
item
(
key_padding_mask
.
max
())
>
0
:
attn_weights
=
attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
=
attn_weights
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
)
attn_weights
=
attn_weights
.
masked_fill
(
attn_weights
=
attn_weights
.
masked_fill
(
key_padding_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
),
key_padding_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment