Commit 89e19d42 authored by Alexei Baevski, committed by Myle Ott

disable printing alignment by default (for perf) and add a flag to enable it

parent f472d141
@@ -14,7 +14,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def __init__(self, dictionary):
         super().__init__(dictionary)
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
         raise NotImplementedError
     def reorder_incremental_state(self, incremental_state, new_order):
......
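Every decoder now takes the extra need_attn keyword and is expected to return (logits, attn), leaving attn as None when the flag is off. A minimal sketch of a custom incremental decoder honoring that contract (this class, its layers, and the tensor shapes are illustrative, not part of the commit):

import torch.nn as nn
from fairseq.models import FairseqIncrementalDecoder

class SketchDecoder(FairseqIncrementalDecoder):
    """Hypothetical decoder showing the need_attn contract."""

    def __init__(self, dictionary, embed_dim=512):
        super().__init__(dictionary)
        self.embed_tokens = nn.Embedding(len(dictionary), embed_dim)
        self.output_proj = nn.Linear(embed_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
        x = self.embed_tokens(prev_output_tokens)
        # ... attend over encoder_out and update incremental_state here ...
        attn = None
        if need_attn:
            # only materialize the (bsz x tgtlen x srclen) attention tensor when asked for
            attn = x.new_zeros(x.size(0), x.size(1), encoder_out['encoder_out'].size(0))
        return self.output_proj(x), attn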
@@ -104,9 +104,9 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)
-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, need_attn):
         encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out, need_attn=need_attn)
         return decoder_out
     def max_positions(self):
......
@@ -417,7 +417,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
-    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None, need_attn=False):
         if encoder_out_dict is not None:
             encoder_out = encoder_out_dict['encoder_out']
             encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -466,6 +466,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x = self._transpose_if_training(x, incremental_state)
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
+                if need_attn:
                     attn_scores = attn_scores / num_attn_layers
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
......
@@ -352,7 +352,7 @@ class FConvDecoder(FairseqDecoder):
             self.pretrained_decoder.fc2.register_forward_hook(save_output())
-    def forward(self, prev_output_tokens, encoder_out_dict):
+    def forward(self, prev_output_tokens, encoder_out_dict, need_attn=False):
         encoder_out = encoder_out_dict['encoder']['encoder_out']
         trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None
@@ -388,6 +388,7 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
+                if need_attn:
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
                     else:
......
@@ -320,7 +320,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
-    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None, need_attn=False):
         encoder_out = encoder_out_dict['encoder_out']
         encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -391,7 +391,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2)
+        attn_scores = attn_scores.transpose(0, 2) if need_attn else None
         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
......
@@ -215,7 +215,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,
@@ -340,7 +340,7 @@ class TransformerDecoderLayer(nn.Module):
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, need_attn=False):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
         x, _ = self.self_attn(
@@ -364,6 +364,7 @@ class TransformerDecoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             incremental_state=incremental_state,
             static_kv=True,
+            need_weights=need_attn,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
......
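The need_weights switch is where the saving comes from: with it off, the attention module never has to hand the per-position weight matrix back to the caller. As a rough illustration of the same switch on PyTorch's built-in nn.MultiheadAttention (not the fairseq multihead attention module used above):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
q = k = v = torch.randn(5, 3, 8)           # (seq_len, batch, embed_dim)

out, weights = mha(q, k, v, need_weights=True)
print(weights.shape)                        # torch.Size([3, 5, 5])

out, weights = mha(q, k, v, need_weights=False)
print(weights)                              # None: no attention weights are returned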
@@ -290,6 +290,8 @@ def add_generation_args(parser):
                        help='sample from top K likely next words instead of all words')
     group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                        help='temperature for random sampling')
+    group.add_argument('--print-alignment', action='store_true',
+                       help='if set, uses attention feedback to compute and print alignment to source tokens')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override model args at generation that were used during model training')
     return group
......
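A quick way to see the new default (a sketch, assuming fairseq is importable and that add_generation_args can be attached to a bare argparse parser as shown above):

import argparse
from fairseq import options

parser = argparse.ArgumentParser()
options.add_generation_args(parser)

print(parser.parse_args([]).print_alignment)                     # False: off by default
print(parser.parse_args(['--print-alignment']).print_alignment)  # True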
@@ -54,7 +54,7 @@ class SequenceGenerator(object):
     def generate_batched_itr(
         self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-        cuda=False, timer=None, prefix_size=0,
+        cuda=False, timer=None, prefix_size=0, with_attention=False,
     ):
         """Iterate over a batched dataset and yield individual translations.
         Args:
@@ -81,6 +81,7 @@ class SequenceGenerator(object):
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
+                    with_attention=with_attention,
                 )
             if timer is not None:
                 timer.stop(sum(len(h[0]['tokens']) for h in hypos))
@@ -90,12 +91,12 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
         """Generate a batch of translations."""
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens, with_attention)
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -128,6 +129,7 @@ class SequenceGenerator(object):
         tokens[:, 0] = self.eos
         attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
         attn_buf = attn.clone()
+        nonpad_idxs = src_tokens.ne(self.pad) if with_attention else None
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -220,10 +222,13 @@ class SequenceGenerator(object):
                 def get_hypo():
+                    if with_attention:
                         # remove padding tokens from attn scores
-                    nonpad_idxs = src_tokens[sent].ne(self.pad)
-                    hypo_attn = attn_clone[i][nonpad_idxs]
+                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                         _, alignment = hypo_attn.max(dim=0)
+                    else:
+                        hypo_attn = None
+                        alignment = None
                     return {
                         'tokens': tokens_clone[i],
@@ -271,7 +276,7 @@ class SequenceGenerator(object):
                     encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
             probs, avg_attn_scores = self._decode(
-                tokens[:, :step + 1], encoder_outs, incremental_states)
+                tokens[:, :step + 1], encoder_outs, incremental_states, with_attention)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -286,6 +291,7 @@ class SequenceGenerator(object):
             probs[:, self.unk] -= self.unk_penalty # apply unk penalty
             # Record attention scores
+            if avg_attn_scores is not None:
                 attn[:, :, step + 1].copy_(avg_attn_scores)
             cand_scores = buffer('cand_scores', type_of=scores)
@@ -492,14 +498,16 @@ class SequenceGenerator(object):
         return finalized
-    def _decode(self, tokens, encoder_outs, incremental_states):
+    def _decode(self, tokens, encoder_outs, incremental_states, with_attention):
         if len(self.models) == 1:
-            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
+            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True,
+                                    with_attention=with_attention, )
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
+            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False,
+                                           with_attention=with_attention, )
             if avg_probs is None:
                 avg_probs = probs
             else:
@@ -515,12 +523,13 @@ class SequenceGenerator(object):
         avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
-    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
+    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs, with_attention):
         with torch.no_grad():
             if incremental_states[model] is not None:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model],
+                                                 need_attn=with_attention))
             else:
-                decoder_out = list(model.decoder(tokens, encoder_out))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=None, need_attn=with_attention))
             decoder_out[0] = decoder_out[0][:, -1, :]
             attn = decoder_out[1]
             if attn is not None:
......
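Callers opt in per call. A sketch of the two paths (the models, dictionary, and input tensors are assumed to exist already, and the constructor arguments are abbreviated):

generator = SequenceGenerator(models, tgt_dict, beam_size=5)

# default: no attention feedback, hypotheses carry no alignment
hypos = generator.generate(src_tokens, src_lengths)
assert hypos[0][0]['alignment'] is None

# opt in: decoders are called with need_attn=True and each hypothesis
# gets an argmax alignment (one source position per target token)
hypos = generator.generate(src_tokens, src_lengths, with_attention=True)
print(hypos[0][0]['alignment'])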
@@ -88,6 +88,7 @@ def main(args):
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
             cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
+            with_attention=args.print_alignment,
         )
         wps_meter = TimeMeter()
@@ -115,7 +116,7 @@ def main(args):
                 hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                     hypo_tokens=hypo['tokens'].int().cpu(),
                     src_str=src_str,
-                    alignment=hypo['alignment'].int().cpu(),
+                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                     align_dict=align_dict,
                     tgt_dict=tgt_dict,
                     remove_bpe=args.remove_bpe,
@@ -130,6 +131,8 @@ def main(args):
                             hypo['positional_scores'].tolist(),
                         ))
                     ))
+                    if args.print_alignment:
+                        print('A-{}\t{}'.format(
                             sample_id,
                             ' '.join(map(lambda x: str(utils.item(x)), alignment))
......
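With the flag enabled, generate.py prints one extra 'A-<id>' line per hypothesis next to the usual H-/P- lines: a tab, then the space-joined source position that each target token attends to most. A small post-processing sketch (a hypothetical helper, not part of fairseq) for reading those lines back:

def read_alignments(output_lines):
    """Collect {sample_id: [src_pos, ...]} from generate.py output.

    Assumes only the 'A-<id>\t<space-separated ints>' format printed above
    when --print-alignment is set.
    """
    alignments = {}
    for line in output_lines:
        if line.startswith('A-'):
            key, rest = line.split('\t', 1)
            alignments[int(key[2:])] = [int(tok) for tok in rest.split()]
    return alignments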