Commit 9f1b37dd authored by Alexei Baevski, committed by Myle Ott

fix alignment when using uneven batches and left pad

parent 663fd806
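The "uneven batches and left pad" in the title refers to source batches whose sentences have different lengths and are padded on the left, so the real tokens are right-aligned. A toy sketch of what such a batch looks like; the tensors and `pad_idx` below are assumptions for illustration, not fairseq's actual collation code:

```python
import torch

pad_idx = 1
# Two source sentences of uneven length.
sentences = [torch.tensor([11, 12, 13, 14]), torch.tensor([21, 22])]

max_len = max(len(s) for s in sentences)
batch = torch.full((len(sentences), max_len), pad_idx, dtype=torch.long)
for i, s in enumerate(sentences):
    batch[i, max_len - len(s):] = s  # left pad: real tokens are right-aligned

# batch:
# [[11, 12, 13, 14],
#  [ 1,  1, 21, 22]]
```

Every leading pad in a row shifts the positions of the real source tokens, which is what throws the attention-based alignment off in the unpatched code below.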
@@ -9,6 +9,7 @@ import math
 import torch
 from fairseq import utils
+from fairseq.data import LanguagePairDataset
 from fairseq.models import FairseqIncrementalDecoder
@@ -135,11 +136,12 @@ class SequenceGenerator(object):
         cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

         # offset arrays for converting between different indexing schemes
-        bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
+        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
         cand_offsets = torch.arange(0, cand_size).type_as(tokens)

         # helper function for allocating buffers on the fly
         buffers = {}
+
         def buffer(name, type_of=tokens):  # noqa
             if name not in buffers:
                 buffers[name] = type_of.new()
@@ -186,7 +188,7 @@ class SequenceGenerator(object):
             # clone relevant token and attention tensors
             tokens_clone = tokens.index_select(0, bbsz_idx)
-            tokens_clone = tokens_clone[:, 1:step+2]  # skip the first index, which is EOS
+            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
             tokens_clone[:, step] = self.eos
             attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
@@ -198,7 +200,7 @@ class SequenceGenerator(object):
             # normalize sentence-level scores
             if self.normalize_scores:
-                eos_scores /= (step+1)**self.len_penalty
+                eos_scores /= (step + 1) ** self.len_penalty

             cum_unfin = []
             prev = 0
@@ -216,11 +218,17 @@ class SequenceGenerator(object):
                 sents_seen.add((sent, unfin_idx))
                 def get_hypo():
-                    _, alignment = attn_clone[i].max(dim=0)
+                    # remove padding tokens from attn scores
+                    nonpad_idxs = src_tokens[sent].ne(self.pad)
+                    hypo_attn = attn_clone[i][nonpad_idxs]
+                    _, alignment = hypo_attn.max(dim=0)
                     return {
                         'tokens': tokens_clone[i],
                         'score': score,
-                        'attention': attn_clone[i],  # src_len x tgt_len
+                        'attention': hypo_attn,  # src_len x tgt_len
                         'alignment': alignment,
                         'positional_scores': pos_scores[i],
                     }
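This hunk is the substance of the fix: with left padding, `attn_clone[i]` still contains rows for the pad positions of `src_tokens[sent]`, so taking the argmax over the full matrix yields alignment indices offset by the number of leading pads (a pad row can even win the argmax). Masking with `src_tokens[sent].ne(self.pad)` first restricts the attention to real source tokens. A minimal standalone sketch of the same idea; the toy tensors and `pad_idx` are assumptions for illustration only:

```python
import torch

pad_idx = 1
# Left-padded source sentence: two pads, then three real tokens.
src_tokens = torch.tensor([pad_idx, pad_idx, 11, 12, 13])

# Toy attention matrix of shape (src_len, tgt_len); rows 0-1 belong to the pads.
attn = torch.tensor([
    [0.0, 0.0, 0.0],  # pad
    [0.0, 0.0, 0.0],  # pad
    [0.9, 0.1, 0.2],
    [0.1, 0.8, 0.1],
    [0.0, 0.1, 0.7],
])

# Old behaviour: argmax over all rows, indices are offset by the two pads.
_, raw_alignment = attn.max(dim=0)    # tensor([2, 3, 4])

# Patched behaviour: drop pad rows, then take the argmax.
nonpad_idxs = src_tokens.ne(pad_idx)  # tensor([False, False, True, True, True])
hypo_attn = attn[nonpad_idxs]         # shape (3, 3)
_, alignment = hypo_attn.max(dim=0)   # tensor([0, 1, 2]), indexed over real tokens
```

Returning `hypo_attn` instead of `attn_clone[i]` in the hypothesis dict keeps the stored attention consistent with the re-indexed alignment.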
@@ -263,7 +271,7 @@ class SequenceGenerator(object):
                     encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)

             probs, avg_attn_scores = self._decode(
-                tokens[:, :step+1], encoder_outs, incremental_states)
+                tokens[:, :step + 1], encoder_outs, incremental_states)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -272,13 +280,13 @@ class SequenceGenerator(object):
                 scores_buf = scores_buf.type_as(probs)
             elif not self.sampling:
                 # make probs contain cumulative scores for each hypothesis
-                probs.add_(scores[:, step-1].view(-1, 1))
+                probs.add_(scores[:, step - 1].view(-1, 1))

             probs[:, self.pad] = -math.inf  # never select pad
             probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

             # Record attention scores
-            attn[:, :, step+1].copy_(avg_attn_scores)
+            attn[:, :, step + 1].copy_(avg_attn_scores)

             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
@@ -315,7 +323,7 @@ class SequenceGenerator(object):
                     # make scores cumulative
                     cand_scores.add_(
                         torch.gather(
-                            scores[:, step-1].view(bsz, beam_size), dim=1,
+                            scores[:, step - 1].view(bsz, beam_size), dim=1,
                             index=cand_beams,
                         )
                     )
@@ -406,7 +414,7 @@ class SequenceGenerator(object):
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
             torch.add(
-                eos_mask.type_as(cand_offsets)*cand_size,
+                eos_mask.type_as(cand_offsets) * cand_size,
                 cand_offsets[:eos_mask.size(1)],
                 out=active_mask,
             )
@@ -433,12 +441,12 @@ class SequenceGenerator(object):
             # copy tokens and scores for active hypotheses
             torch.index_select(
-                tokens[:, :step+1], dim=0, index=active_bbsz_idx,
-                out=tokens_buf[:, :step+1],
+                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
+                out=tokens_buf[:, :step + 1],
             )
             torch.gather(
                 cand_indices, dim=1, index=active_hypos,
-                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1],
+                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
             )

             if step > 0:
                 torch.index_select(
@@ -452,8 +460,8 @@ class SequenceGenerator(object):
             # copy attention for active hypotheses
             torch.index_select(
-                attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
-                out=attn_buf[:, :, :step+2],
+                attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
+                out=attn_buf[:, :, :step + 2],
             )

             # swap buffers