Commit fc830685 authored by alexeib, committed by Myle Ott

smarter way to avoid applying encoder key mask

parent b2374e52
@@ -137,6 +137,8 @@ class TransformerEncoder(FairseqEncoder):
         # compute padding mask
         encoder_padding_mask = src_tokens.eq(self.padding_idx)
+        if not encoder_padding_mask.any():
+            encoder_padding_mask = None
         # encoder layers
         for layer in self.layers:
......
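
The encoder-side change amounts to dropping the padding mask altogether when no position in the batch is padded, so downstream code receives None instead of an all-False mask. Below is a minimal standalone sketch of that pattern, assuming a hypothetical toy batch and a padding_idx of 1; neither is taken from the commit.

```python
# Minimal sketch of the encoder-side pattern; the toy batches and
# padding_idx below are illustrative only.
import torch

padding_idx = 1

def compute_encoder_padding_mask(src_tokens):
    # True wherever the token is the padding symbol
    encoder_padding_mask = src_tokens.eq(padding_idx)
    # if nothing in the batch is padded, drop the mask entirely so
    # downstream attention layers can skip the masking work
    if not encoder_padding_mask.any():
        encoder_padding_mask = None
    return encoder_padding_mask

# batch with no padding -> the mask collapses to None
print(compute_encoder_padding_mask(torch.tensor([[4, 5, 6], [7, 8, 9]])))
# None

# batch with one padded position -> the boolean mask is kept
print(compute_encoder_padding_mask(torch.tensor([[4, 5, 6], [7, 8, 1]])))
# tensor([[False, False, False],
#         [False, False,  True]])
```

With the mask reduced to None, callers can simply test `key_padding_mask is not None` rather than inspecting the mask's contents on every forward pass, which is what the attention-side hunk below relies on.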
@@ -122,9 +122,8 @@ class MultiheadAttention(nn.Module):
             assert query.size() == key.size(), \
                 'mask_future_timesteps only applies to self-attention'
             attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
-        if key_padding_mask is not None and incremental_state is None:
+        if key_padding_mask is not None:
             # don't attend to padding symbols
-            if utils.item(key_padding_mask.max()) > 0:
-                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-                attn_weights = attn_weights.masked_fill(
-                    key_padding_mask.unsqueeze(1).unsqueeze(2),
......
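
On the attention side, the `utils.item(key_padding_mask.max()) > 0` check (a per-call device-to-host sync) is no longer needed, since the encoder now passes None whenever there is nothing to mask. The following is a simplified sketch of the resulting masking logic, not fairseq's actual MultiheadAttention module; the shapes and names follow the diff, but the standalone function is hypothetical.

```python
# Simplified sketch of key-padding masking over attention weights.
import torch
import torch.nn.functional as F

def apply_key_padding_mask(attn_weights, key_padding_mask, bsz, num_heads):
    # attn_weights: (bsz * num_heads, tgt_len, src_len)
    # key_padding_mask: (bsz, src_len), True at padding positions, or None
    tgt_len, src_len = attn_weights.size(1), attn_weights.size(2)
    if key_padding_mask is not None:
        # don't attend to padding symbols; no data-dependent .max() check
        # (and no host/device sync) is required, because the caller passes
        # None when the batch contains no padding
        attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)
    return F.softmax(attn_weights, dim=-1)

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3
weights = torch.zeros(bsz * num_heads, tgt_len, src_len)
mask = torch.tensor([[False, False, False],
                     [False, False, True]])   # second sentence ends in padding
attn = apply_key_padding_mask(weights, mask, bsz, num_heads)
print(attn[-1, 0])  # tensor([0.5000, 0.5000, 0.0000]) -- padded key gets no weight
```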