Commit 0ef2856c authored by Myle Ott

Don't compute unnecessary attention averages during training

parent c37fc8fd
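The change is the same in every decoder below: per-layer attention averaging now runs only when the module is in eval mode and attention weights were actually requested via need_attn, so training steps no longer pay for averages that are only consumed at generation time. A minimal sketch of the pattern, assuming a toy ToyAttnDecoder built on PyTorch's stock nn.MultiheadAttention rather than the actual fairseq attention modules:

```python
import torch
import torch.nn as nn


class ToyAttnDecoder(nn.Module):
    """Toy decoder illustrating the train/eval gating; not the fairseq code."""

    def __init__(self, dim=8, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, num_heads=1) for _ in range(num_layers)]
        )
        self.need_attn = True  # illustrative default; see note below

    def forward(self, x, memory):
        avg_attn_scores = None
        num_attn_layers = len(self.layers)
        for layer in self.layers:
            # only request (and average) attention weights outside of training
            need_weights = not self.training and self.need_attn
            x, attn_scores = layer(x, memory, memory, need_weights=need_weights)
            if need_weights:
                attn_scores = attn_scores / num_attn_layers
                if avg_attn_scores is None:
                    avg_attn_scores = attn_scores
                else:
                    avg_attn_scores = avg_attn_scores + attn_scores
        return x, avg_attn_scores


dec = ToyAttnDecoder()
tgt, mem = torch.randn(5, 3, 8), torch.randn(7, 3, 8)

dec.train()
_, attn = dec(tgt, mem)
assert attn is None        # nothing accumulated during training

dec.eval()
_, attn = dec(tgt, mem)
assert attn is not None    # averaged attention available at generation time
```

Here self.need_attn defaults to True purely for illustration; in the fconv and LSTM decoders below it is toggled through make_generation_fast_(need_attn=...) before generation.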
@@ -468,7 +468,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                if self.need_attn:
+                if not self.training and self.need_attn:
                     attn_scores = attn_scores / num_attn_layers
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
...
@@ -389,7 +389,7 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if self.need_attn:
+                if not self.training and self.need_attn:
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
                     else:
...
@@ -396,7 +396,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)

         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2) if self.need_attn else None
+        if not self.training and self.need_attn:
+            attn_scores = attn_scores.transpose(0, 2)
+        else:
+            attn_scores = None

         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
...
@@ -193,7 +193,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
-        self.need_attn = True

         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -267,9 +266,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
         return state_dict

-    def make_generation_fast_(self, need_attn=False, **kwargs):
-        self.need_attn = need_attn
-

 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
@@ -369,7 +365,7 @@ class TransformerDecoderLayer(nn.Module):
                 key_padding_mask=encoder_padding_mask,
                 incremental_state=incremental_state,
                 static_kv=True,
-                need_weights=self.need_attn,
+                need_weights=(not self.training and self.need_attn),
             )
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = residual + x
...
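In TransformerDecoderLayer the gating happens at the attention call itself via need_weights. A small, self-contained illustration of that behaviour, assuming PyTorch's stock nn.MultiheadAttention (which returns None in place of the weights when need_weights=False) rather than fairseq's own MultiheadAttention:

```python
import torch
import torch.nn as nn

# Isolated, hypothetical check of the layer-level gating: ask the attention
# module for weights only when not training and attention was requested.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1)
query, memory = torch.randn(5, 3, 8), torch.randn(7, 3, 8)

for training, need_attn in [(True, True), (False, True), (False, False)]:
    attn.train(training)
    _, weights = attn(
        query, memory, memory,
        need_weights=(not training and need_attn),
    )
    # weights are returned only in eval mode with attention requested
    assert (weights is not None) == (not training and need_attn)
```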