Commit bb5f15d1 authored by Myle Ott

Iterate on need_attn and fix tests

parent 498a186d
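
This commit drops the per-call `need_attn` keyword from the decoder `forward()` signatures and instead stores the flag on each decoder as `self.need_attn`, which `make_generation_fast_(need_attn=...)` flips once before generation (wired to `--print-alignment` in the generation scripts). A minimal sketch of that pattern, using a hypothetical `ToyDecoder` rather than any of the real fairseq decoders:

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical decoder illustrating the attribute-based need_attn toggle."""

    def __init__(self, dim=8):
        super().__init__()
        # attention weights are computed by default (training/validation)
        self.need_attn = True
        self.attn = nn.MultiheadAttention(dim, num_heads=1)

    def forward(self, x, memory):
        # need_attn is read from the module instead of being threaded
        # through every forward() call as a keyword argument
        out, attn_weights = self.attn(x, memory, memory, need_weights=self.need_attn)
        return out, attn_weights if self.need_attn else None

    def make_generation_fast_(self, need_attn=False, **kwargs):
        # generation scripts call this once, e.g. need_attn=args.print_alignment
        self.need_attn = need_attn


decoder = ToyDecoder()
decoder.make_generation_fast_(need_attn=False)   # skip attention outputs during beam search
x = memory = torch.randn(5, 2, 8)                # (time, batch, dim)
out, attn = decoder(x, memory)
assert attn is None
```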
@@ -14,7 +14,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def __init__(self, dictionary):
         super().__init__(dictionary)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         raise NotImplementedError
 
     def reorder_incremental_state(self, incremental_state, new_order):
......
@@ -104,9 +104,9 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)
 
-    def forward(self, src_tokens, src_lengths, prev_output_tokens, need_attn=False):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens):
         encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out, need_attn=need_attn)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out)
         return decoder_out
 
     def max_positions(self):
......
@@ -352,6 +352,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.dropout = dropout
         self.normalization_constant = normalization_constant
         self.left_pad = left_pad
+        self.need_attn = True
 
         convolutions = extend_conv_spec(convolutions)
         in_channels = convolutions[0][0]
@@ -417,7 +418,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
         if encoder_out_dict is not None:
             encoder_out = encoder_out_dict['encoder_out']
             encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -467,7 +468,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask)
-                if need_attn:
+                if self.need_attn:
                     attn_scores = attn_scores / num_attn_layers
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
@@ -523,6 +524,9 @@ class FConvDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.version'] = torch.Tensor([1])
         return state_dict
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
     def _embed_tokens(self, tokens, incremental_state):
         if incremental_state is not None:
             # keep only the last token for incremental forward pass
......
@@ -259,6 +259,7 @@ class FConvDecoder(FairseqDecoder):
         self.pretrained_decoder = trained_decoder
         self.dropout = dropout
         self.left_pad = left_pad
+        self.need_attn = True
         in_channels = convolutions[0][0]
 
         def expand_bool_array(val):
@@ -352,7 +353,7 @@ class FConvDecoder(FairseqDecoder):
             self.pretrained_decoder.fc2.register_forward_hook(save_output())
 
-    def forward(self, prev_output_tokens, encoder_out_dict, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out_dict):
         encoder_out = encoder_out_dict['encoder']['encoder_out']
         trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None
@@ -388,7 +389,7 @@ class FConvDecoder(FairseqDecoder):
                 r = x
                 x, attn_scores = attention(attproj(x) + target_embedding, encoder_a, encoder_b)
                 x = x + r
-                if need_attn:
+                if self.need_attn:
                     if avg_attn_scores is None:
                         avg_attn_scores = attn_scores
                     else:
@@ -427,6 +428,9 @@ class FConvDecoder(FairseqDecoder):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
     def _split_encoder_out(self, encoder_out):
         """Split and transpose encoder outputs."""
         # transpose only once to speed up attention layers
......
@@ -298,6 +298,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
         self.share_input_output_embed = share_input_output_embed
+        self.need_attn = True
 
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
@@ -324,7 +325,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
         encoder_out = encoder_out_dict['encoder_out']
         encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -395,7 +396,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(1, 0)
 
         # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
-        attn_scores = attn_scores.transpose(0, 2) if need_attn else None
+        attn_scores = attn_scores.transpose(0, 2) if self.need_attn else None
 
         # project back to size of vocabulary
         if hasattr(self, 'additional_fc'):
@@ -426,6 +427,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
......
@@ -193,6 +193,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
+        self.need_attn = True
 
         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
@@ -215,7 +216,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None, need_attn=False):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,
@@ -266,6 +267,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
         return state_dict
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
@@ -339,8 +343,9 @@ class TransformerDecoderLayer(nn.Module):
         self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
+        self.need_attn = True
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, need_attn=False):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
         x, _ = self.self_attn(
@@ -364,7 +369,7 @@ class TransformerDecoderLayer(nn.Module):
             key_padding_mask=encoder_padding_mask,
             incremental_state=incremental_state,
             static_kv=True,
-            need_weights=need_attn,
+            need_weights=self.need_attn,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -387,6 +392,9 @@ class TransformerDecoderLayer(nn.Module):
         else:
             return x
 
+    def make_generation_fast_(self, need_attn=False, **kwargs):
+        self.need_attn = need_attn
+
 
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
......
@@ -54,7 +54,7 @@ class SequenceGenerator(object):
     def generate_batched_itr(
         self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-        cuda=False, timer=None, prefix_size=0, with_attention=False,
+        cuda=False, timer=None, prefix_size=0,
     ):
         """Iterate over a batched dataset and yield individual translations.
 
         Args:
@@ -81,7 +81,6 @@
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
-                    with_attention=with_attention,
                 )
             if timer is not None:
                 timer.stop(sum(len(h[0]['tokens']) for h in hypos))
@@ -91,12 +90,12 @@
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens, with_attention)
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None, with_attention=False):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -127,9 +126,8 @@
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
-        attn_buf = attn.clone()
-        nonpad_idxs = src_tokens.ne(self.pad) if with_attention else None
+        attn, attn_buf = None, None
+        nonpad_idxs = None
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -193,7 +191,7 @@
             tokens_clone = tokens.index_select(0, bbsz_idx)
             tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
             tokens_clone[:, step] = self.eos
-            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
+            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
 
             # compute scores per token position
             pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
@@ -222,7 +220,7 @@
                 def get_hypo():
 
-                    if with_attention:
+                    if attn_clone is not None:
                         # remove padding tokens from attn scores
                         hypo_attn = attn_clone[i][nonpad_idxs[sent]]
                         _, alignment = hypo_attn.max(dim=0)
@@ -275,8 +273,7 @@
                         model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                     encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
-            probs, avg_attn_scores = self._decode(
-                tokens[:, :step + 1], encoder_outs, incremental_states, with_attention)
+            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -292,6 +289,10 @@
             # Record attention scores
             if avg_attn_scores is not None:
+                if attn is None:
+                    attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+                    attn_buf = attn.clone()
+                    nonpad_idxs = src_tokens.ne(self.pad)
                 attn[:, :, step + 1].copy_(avg_attn_scores)
 
             cand_scores = buffer('cand_scores', type_of=scores)
@@ -423,6 +424,7 @@
                 scores_buf.resize_as_(scores)
                 tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                 tokens_buf.resize_as_(tokens)
+                if attn is not None:
                     attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                     attn_buf.resize_as_(attn)
                 bsz = new_bsz
@@ -479,6 +481,7 @@
             )
 
             # copy attention for active hypotheses
+            if attn is not None:
                 torch.index_select(
                     attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                     out=attn_buf[:, :, :step + 2],
@@ -487,6 +490,7 @@
             # swap buffers
             tokens, tokens_buf = tokens_buf, tokens
             scores, scores_buf = scores_buf, scores
+            if attn is not None:
                 attn, attn_buf = attn_buf, attn
 
             # reorder incremental state in decoder
@@ -498,16 +502,14 @@
 
         return finalized
 
-    def _decode(self, tokens, encoder_outs, incremental_states, with_attention):
+    def _decode(self, tokens, encoder_outs, incremental_states):
         if len(self.models) == 1:
-            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True,
-                                     with_attention=with_attention, )
+            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
 
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False,
-                                           with_attention=with_attention, )
+            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
             if avg_probs is None:
                 avg_probs = probs
             else:
@@ -523,13 +525,12 @@
             avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
 
-    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs, with_attention):
+    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
         with torch.no_grad():
             if incremental_states[model] is not None:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model],
-                                                 need_attn=with_attention))
+                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=incremental_states[model]))
             else:
-                decoder_out = list(model.decoder(tokens, encoder_out, incremental_state=None, need_attn=with_attention))
+                decoder_out = list(model.decoder(tokens, encoder_out))
         decoder_out[0] = decoder_out[0][:, -1, :]
         attn = decoder_out[1]
         if attn is not None:
......
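
In `SequenceGenerator`, the attention buffers are now allocated lazily: `attn`, `attn_buf`, and `nonpad_idxs` start out as `None` and are only created the first time the ensemble actually returns attention scores, so every later use is guarded by `if attn is not None`. A small, self-contained sketch of that lazy-allocation pattern (the helper name and shapes here are illustrative, not the generator's actual code):

```python
import torch


def decode_steps(step_outputs, bsz, srclen, maxlen):
    """Hypothetical loop showing lazy allocation of an attention buffer.

    step_outputs yields either None (decoder skipped attention) or a
    (bsz, srclen) tensor of averaged attention scores for that step.
    """
    attn = None  # allocated only if some step actually returns scores
    for step, avg_attn_scores in enumerate(step_outputs):
        if avg_attn_scores is None:
            continue
        if attn is None:
            # first time we see attention: allocate the full buffer
            attn = avg_attn_scores.new_zeros(bsz, srclen, maxlen + 2)
        attn[:, :, step + 1].copy_(avg_attn_scores)
    return attn  # stays None when attention was never requested


# when need_attn=False the decoder returns None every step and no buffer is built
print(decode_steps([None, None, None], bsz=2, srclen=4, maxlen=3))
```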
@@ -42,7 +42,10 @@ def main(args):
 
     # Optimize ensemble for generation
     for model in models:
-        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+            need_attn=args.print_alignment,
+        )
         if args.fp16:
             model.half()
@@ -88,7 +91,6 @@
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
             cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
-            with_attention=args.print_alignment,
         )
 
         wps_meter = TimeMeter()
......
@@ -81,7 +81,10 @@ def main(args):
     # Optimize ensemble for generation
     for model in models:
-        model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+            need_attn=args.print_alignment,
+        )
         if args.fp16:
             model.half()
@@ -112,13 +115,16 @@
             hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                 hypo_tokens=hypo['tokens'].int().cpu(),
                 src_str=src_str,
-                alignment=hypo['alignment'].int().cpu(),
+                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                 align_dict=align_dict,
                 tgt_dict=tgt_dict,
                 remove_bpe=args.remove_bpe,
             )
             result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
-            result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
+            result.alignments.append(
+                'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
+                if args.print_alignment else None
+            )
         return result
 
     def process_batch(batch):
@@ -152,6 +158,7 @@
             print(result.src_str)
             for hypo, align in zip(result.hypos, result.alignments):
                 print(hypo)
+                if align is not None:
                     print(align)
......
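
Both generation front-ends now decide once, from the command line, whether attention is needed: `--print-alignment` is forwarded into `make_generation_fast_(need_attn=...)`, and an alignment line is only emitted when one was actually produced. A rough, stand-alone sketch of that wiring; `FakeModel` and the hard-coded alignment are stand-ins for illustration only:

```python
import argparse


class FakeModel:
    """Stand-in for an ensemble member; only mirrors the need_attn plumbing."""

    def __init__(self):
        self.need_attn = True

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--print-alignment', action='store_true')
    args = parser.parse_args(argv)

    models = [FakeModel(), FakeModel()]
    for model in models:
        # attention (and hence alignments) is only computed when requested
        model.make_generation_fast_(need_attn=args.print_alignment)

    # downstream, alignment output is guarded the same way
    alignment = [0, 1, 1, 2] if args.print_alignment else None
    if alignment is not None:
        print('A\t' + ' '.join(map(str, alignment)))


if __name__ == '__main__':
    main(['--print-alignment'])
```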
@@ -203,6 +203,7 @@ def generate_main(data_dir, extra_flags=None):
             '--max-len-b', '5',
             '--gen-subset', 'valid',
             '--no-progress-bar',
+            '--print-alignment',
         ] + (extra_flags or []),
     )
......