Commit 0ef2856c authored by Myle Ott

Don't compute unnecessary attention averages during training

parent c37fc8fd
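The same gating appears in every hunk below: per-layer attention averages are only built when the decoder is out of training mode and attention output was actually requested (`self.need_attn`, toggled through `make_generation_fast_`). A minimal, hypothetical sketch of that pattern follows; `ToyConvDecoder` and `dot_product_attention` are invented names, not the fairseq modules:

```python
import torch
import torch.nn.functional as F


def dot_product_attention(query, keys, values):
    # Toy attention: scores have shape (batch, tgt_len, src_len).
    scores = F.softmax(torch.bmm(query, keys.transpose(1, 2)), dim=-1)
    return torch.bmm(scores, values), scores


class ToyConvDecoder(torch.nn.Module):
    """Hypothetical stand-in for the FConv/LSTM decoders in this commit."""

    def __init__(self, num_attn_layers=3):
        super().__init__()
        self.num_attn_layers = num_attn_layers
        self.need_attn = True  # normally toggled via make_generation_fast_()

    def forward(self, x, encoder_out):
        avg_attn_scores = None
        for _ in range(self.num_attn_layers):
            x, attn_scores = dot_product_attention(x, encoder_out, encoder_out)
            if not self.training and self.need_attn:
                # The running average is only built when it will be consumed,
                # e.g. for alignments during generation; training skips it.
                attn_scores = attn_scores / self.num_attn_layers
                if avg_attn_scores is None:
                    avg_attn_scores = attn_scores
                else:
                    avg_attn_scores.add_(attn_scores)
        return x, avg_attn_scores


decoder = ToyConvDecoder().eval()   # inference: averaged scores are returned
x, attn = decoder(torch.randn(2, 5, 8), torch.randn(2, 7, 8))
decoder.train()                     # training: attn comes back as None
```

In `train()` mode the division by `num_attn_layers` and the running sum are skipped entirely and `avg_attn_scores` stays `None`; after `eval()`, the averaged scores come back for use during generation.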
@@ -468,7 +468,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                if self.need_attn:
+                if not self.training and self.need_attn:
                     attn_scores = attn_scores / num_attn_layers
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
@@ -389,7 +389,7 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if self.need_attn:
+                if not self.training and self.need_attn:
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
                     else:
@@ -396,7 +396,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)

         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2) if self.need_attn else None
+        if not self.training and self.need_attn:
+            attn_scores = attn_scores.transpose(0, 2)
+        else:
+            attn_scores = None

         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
@@ -193,7 +193,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
-        self.need_attn = True

         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -267,9 +266,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
         return state_dict

-    def make_generation_fast_(self, need_attn=False, **kwargs):
-        self.need_attn = need_attn
-

 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
@@ -369,7 +365,7 @@ class TransformerDecoderLayer(nn.Module):
                 key_padding_mask=encoder_padding_mask,
                 incremental_state=incremental_state,
                 static_kv=True,
-                need_weights=self.need_attn,
+                need_weights=(not self.training and self.need_attn),
             )
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = residual + x
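For the transformer, the gating is pushed into the decoder layer itself: `need_attn` now lives where the weights are requested, and `need_weights` is forced to `False` while training so multi-head attention never materializes the weight matrix. A hedged sketch of that layer-level pattern, with torch's built-in `nn.MultiheadAttention` standing in for fairseq's attention module and `ToyTransformerDecoderLayer` as a hypothetical name:

```python
import torch
import torch.nn as nn


class ToyTransformerDecoderLayer(nn.Module):
    """Hypothetical sketch of a decoder layer: the encoder-attention weights
    are only requested when not training and need_attn has been enabled."""

    def __init__(self, dim=16, heads=4, dropout=0.1):
        super().__init__()
        self.encoder_attn = nn.MultiheadAttention(dim, heads)
        self.dropout = dropout
        self.need_attn = True

    def forward(self, x, encoder_out, encoder_padding_mask=None):
        residual = x
        x, attn = self.encoder_attn(
            query=x, key=encoder_out, value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            # In training mode the weight matrix is never materialized.
            need_weights=(not self.training and self.need_attn),
        )
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        return x, attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


layer = ToyTransformerDecoderLayer().eval()
x = torch.randn(5, 2, 16)        # (tgt_len, batch, dim) -- seq-first layout
enc = torch.randn(7, 2, 16)
out, attn = layer(x, enc)        # attn is (2, 5, 7); None while layer.training
```

A caller that wants alignments still opts in through `make_generation_fast_(need_attn=True)` and then switches to `eval()`; during training the second return value is simply `None`.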