Commit 89e19d42 authored by Alexei Baevski, committed by Myle Ott

disable printing alignment by default (for perf) and add a flag to enable it

parent f472d141
@@ -14,7 +14,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def __init__(self, dictionary):
         super().__init__(dictionary)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
         raise NotImplementedError
 
     def reorder_incremental_state(self, incremental_state, new_order):
...
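With this base-class change, any decoder subclassing FairseqIncrementalDecoder is expected to accept a `need_attn` keyword and may skip attention bookkeeping when it is False, which is where the performance win in the commit message comes from. Below is a minimal, illustrative sketch of a conforming decoder, not code from this commit: the toy embedding/projection layers and the uniform dummy attention are assumptions, and the import path is assumed for this era of fairseq.

```python
import torch.nn as nn

from fairseq.models import FairseqIncrementalDecoder  # import path assumed


class ToyDecoder(FairseqIncrementalDecoder):
    """Illustrative decoder that honors the new need_attn flag."""

    def __init__(self, dictionary, embed_dim=32):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim)
        self.proj = nn.Linear(embed_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
        x = self.embed(prev_output_tokens)   # B x T x C
        logits = self.proj(x)                # B x T x V
        attn = None
        if need_attn:
            # only materialize attention when the caller asks for it;
            # a uniform dummy distribution stands in for real attention scores here,
            # assuming the encoder returns a dict with a T x B x C 'encoder_out' tensor
            src_len = encoder_out['encoder_out'].size(0)
            attn = logits.new_full((x.size(0), x.size(1), src_len), 1.0 / src_len)
        return logits, attn
```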
@@ -104,9 +104,9 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)
 
-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, need_attn):
         encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out, need_attn=need_attn)
         return decoder_out
 
     def max_positions(self):
...
@@ -417,7 +417,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None, need_attn=False):
         if encoder_out_dict is not None:
             encoder_out = encoder_out_dict['encoder_out']
             encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -466,11 +466,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x = self._transpose_if_training(x, incremental_state)
 
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                attn_scores = attn_scores / num_attn_layers
-                if avg_attn_scores is None:
-                    avg_attn_scores = attn_scores
-                else:
-                    avg_attn_scores.add_(attn_scores)
+                if need_attn:
+                    attn_scores = attn_scores / num_attn_layers
+                    if avg_attn_scores is None:
+                        avg_attn_scores = attn_scores
+                    else:
+                        avg_attn_scores.add_(attn_scores)
 
                 x = self._transpose_if_training(x, incremental_state)
...
@@ -352,7 +352,7 @@ class FConvDecoder(FairseqDecoder):
             self.pretrained_decoder.fc2.register_forward_hook(save_output())
 
-    def forward(self, prev_output_tokens, encoder_out_dict):
+    def forward(self, prev_output_tokens, encoder_out_dict, need_attn=False):
         encoder_out = encoder_out_dict['encoder']['encoder_out']
         trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None
@@ -388,10 +388,11 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if avg_attn_scores is None:
-                    avg_attn_scores = attn_scores
-                else:
-                    avg_attn_scores.add_(attn_scores)
+                if need_attn:
+                    if avg_attn_scores is None:
+                        avg_attn_scores = attn_scores
+                    else:
+                        avg_attn_scores.add_(attn_scores)
 
             if selfattention is not None:
                 x = selfattention(x)
...
@@ -320,7 +320,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None, need_attn=False):
         encoder_out = encoder_out_dict['encoder_out']
         encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -391,7 +391,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)
 
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2)
+        attn_scores = attn_scores.transpose(0, 2) if need_attn else None
 
         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
...
@@ -215,7 +215,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,
@@ -340,7 +340,7 @@ class TransformerDecoderLayer(nn.Module):
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, need_attn=False):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
         x, _ = self.self_attn(
@@ -364,6 +364,7 @@ class TransformerDecoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             incremental_state=incremental_state,
             static_kv=True,
+            need_weights=need_attn,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
...
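In the transformer layer the flag is forwarded to the encoder attention as `need_weights`, so attention weights are only returned when an alignment is actually wanted. fairseq's own MultiheadAttention is not part of this diff; as a hedged illustration of the same kind of switch, `torch.nn.MultiheadAttention` behaves analogously:

```python
import torch
import torch.nn as nn

# torch.nn.MultiheadAttention is used here only to illustrate the need_weights switch;
# the module fairseq actually calls is its own MultiheadAttention, which is not shown in this diff.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
q = k = v = torch.randn(5, 2, 16)  # (tgt_len, batch, embed_dim)

out, weights = mha(q, k, v, need_weights=True)
print(weights.shape)   # torch.Size([2, 5, 5]) -- averaged attention weights

out, weights = mha(q, k, v, need_weights=False)
print(weights)         # None -- the weights are simply not returned
```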
@@ -290,6 +290,8 @@ def add_generation_args(parser):
                        help='sample from top K likely next words instead of all words')
     group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                        help='temperature for random sampling')
+    group.add_argument('--print-alignment', action='store_true',
+                       help='if set, uses attention feedback to compute and print alignment to source tokens')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override model args at generation that were used during model training')
     return group
...
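Because the new option is a plain argparse `store_true` flag, alignment printing is now off unless the user passes `--print-alignment` explicitly. A small standalone check of that behavior (only the `add_argument` call above is taken from the commit; the parser here is otherwise a throwaway):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--print-alignment', action='store_true',
                    help='if set, uses attention feedback to compute and print alignment to source tokens')

args = parser.parse_args([])
assert args.print_alignment is False   # alignment is no longer computed/printed by default

args = parser.parse_args(['--print-alignment'])
assert args.print_alignment is True    # opting back in restores the old behavior
```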
@@ -54,7 +54,7 @@ class SequenceGenerator(object):
     def generate_batched_itr(
         self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-        cuda=False, timer=None, prefix_size=0,
+        cuda=False, timer=None, prefix_size=0, with_attention=False,
     ):
         """Iterate over a batched dataset and yield individual translations.
 
         Args:
@@ -81,6 +81,7 @@ class SequenceGenerator(object):
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
+                    with_attention=with_attention,
                 )
             if timer is not None:
                 timer.stop(sum(len(h[0]['tokens']) for h in hypos))
@@ -90,12 +91,12 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
         """Generate a batch of translations."""
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens, with_attention)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -128,6 +129,7 @@ class SequenceGenerator(object):
         tokens[:, 0] = self.eos
         attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
         attn_buf = attn.clone()
+        nonpad_idxs = src_tokens.ne(self.pad) if with_attention else None
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -220,10 +222,13 @@ class SequenceGenerator(object):
 
                 def get_hypo():
 
-                    # remove padding tokens from attn scores
-                    nonpad_idxs = src_tokens[sent].ne(self.pad)
-                    hypo_attn = attn_clone[i][nonpad_idxs]
-                    _, alignment = hypo_attn.max(dim=0)
+                    if with_attention:
+                        # remove padding tokens from attn scores
+                        hypo_attn = attn_clone[i][nonpad_idxs[sent]]
+                        _, alignment = hypo_attn.max(dim=0)
+                    else:
+                        hypo_attn = None
+                        alignment = None
 
                     return {
                         'tokens': tokens_clone[i],
@@ -271,7 +276,7 @@ class SequenceGenerator(object):
                 encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
             probs, avg_attn_scores = self._decode(
-                tokens[:, :step + 1], encoder_outs, incremental_states)
+                tokens[:, :step + 1], encoder_outs, incremental_states, with_attention)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -286,7 +291,8 @@ class SequenceGenerator(object):
             probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
             # Record attention scores
-            attn[:, :, step + 1].copy_(avg_attn_scores)
+            if avg_attn_scores is not None:
+                attn[:, :, step + 1].copy_(avg_attn_scores)
 
             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
@@ -492,14 +498,16 @@ class SequenceGenerator(object):
 
         return finalized
 
-    def _decode(self, tokens, encoder_outs, incremental_states):
+    def _decode(self, tokens, encoder_outs, incremental_states, with_attention):
         if len(self.models) == 1:
-            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
+            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True,
+                                     with_attention=with_attention, )
 
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
+            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False,
+                                           with_attention=with_attention, )
             if avg_probs is None:
                 avg_probs = probs
             else:
@@ -515,12 +523,13 @@ class SequenceGenerator(object):
             avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
 
-    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
+    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs, with_attention):
         with torch.no_grad():
             if incremental_states[model] is not None:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model],
+                                                 need_attn=with_attention))
             else:
-                decoder_out = list(model.decoder(tokens, encoder_out))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=None, need_attn=with_attention))
             decoder_out[0] = decoder_out[0][:, -1, :]
             attn = decoder_out[1]
             if attn is not None:
...
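For callers that drive SequenceGenerator directly, alignments are now opt-in: hypotheses come back with `alignment` set to None unless `with_attention` is passed. A hedged sketch of programmatic use, assuming `translator` is a SequenceGenerator already constructed as in generate.py and `src_tokens`/`src_lengths` are prepared batch tensors (neither setup is part of this commit):

```python
# 'translator', 'src_tokens' and 'src_lengths' are assumed to already exist;
# building them (model loading, batching) is unchanged by this commit.
hypos = translator.generate(src_tokens, src_lengths, beam_size=5, with_attention=True)

best = hypos[0][0]        # best-scoring hypothesis for the first sentence in the batch
print(best['tokens'])     # generated token ids
# 'alignment' is a tensor of source positions (one per target token) because
# with_attention=True was passed; with the default of False it would be None.
print(best['alignment'])
```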
@@ -88,6 +88,7 @@ def main(args):
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
             cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
+            with_attention=args.print_alignment,
         )
 
         wps_meter = TimeMeter()
@@ -115,7 +116,7 @@ def main(args):
                 hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                     hypo_tokens=hypo['tokens'].int().cpu(),
                     src_str=src_str,
-                    alignment=hypo['alignment'].int().cpu(),
+                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                     align_dict=align_dict,
                     tgt_dict=tgt_dict,
                     remove_bpe=args.remove_bpe,
@@ -130,10 +131,12 @@ def main(args):
                             hypo['positional_scores'].tolist(),
                         ))
                     ))
-                    print('A-{}\t{}'.format(
-                        sample_id,
-                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
-                    ))
+
+                    if args.print_alignment:
+                        print('A-{}\t{}'.format(
+                            sample_id,
+                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
+                        ))
 
                 # Score only the top hypothesis
                 if has_target and i == 0:
...