Commit fc830685 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

smarter way to avoid applying encoder key mask

parent b2374e52
...@@ -137,6 +137,8 @@ class TransformerEncoder(FairseqEncoder): ...@@ -137,6 +137,8 @@ class TransformerEncoder(FairseqEncoder):
# compute padding mask # compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx) encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
encoder_padding_mask = None
# encoder layers # encoder layers
for layer in self.layers: for layer in self.layers:
......
...@@ -122,15 +122,14 @@ class MultiheadAttention(nn.Module): ...@@ -122,15 +122,14 @@ class MultiheadAttention(nn.Module):
assert query.size() == key.size(), \ assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention' 'mask_future_timesteps only applies to self-attention'
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0) attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if key_padding_mask is not None and incremental_state is None: if key_padding_mask is not None:
# don't attend to padding symbols # don't attend to padding symbols
if utils.item(key_padding_mask.max()) > 0: attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.masked_fill(
attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2),
key_padding_mask.unsqueeze(1).unsqueeze(2), -math.inf,
-math.inf, )
) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1) attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
......
Markdown is supported
0% — Loading — or attach a file by drag and drop.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment