Commit 9f1b37dd authored by Alexei Baevski, committed by Myle Ott

fix alignment when using uneven batches and left pad

parent 663fd806
@@ -9,6 +9,7 @@ import math
import torch
from fairseq import utils
+from fairseq.data import LanguagePairDataset
from fairseq.models import FairseqIncrementalDecoder
@@ -135,11 +136,12 @@ class SequenceGenerator(object):
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
-bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
+bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
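A note on the two offset tensors in the hunk above: `tokens`, `scores`, and `attn` are laid out with one row per (sentence, beam) pair, so `bbsz_offsets` turns per-sentence beam indices into flat row indices, and `cand_offsets` plays the same role for the `2 * beam_size` candidate columns. A minimal sketch with made-up sizes, not the fairseq API itself:

```python
import torch

bsz, beam_size = 2, 3

# one offset per sentence into the flattened (bsz * beam_size) rows
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1)  # [[0], [3]]

# hypothetical beams chosen for each sentence at some step
cand_beams = torch.tensor([[0, 2, 1],
                           [1, 0, 2]])

# flat row index = sentence * beam_size + beam
cand_bbsz_idx = cand_beams + bbsz_offsets
print(cand_bbsz_idx)  # tensor([[0, 2, 1], [4, 3, 5]])
```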
@@ -186,7 +188,7 @@ class SequenceGenerator(object):
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
-tokens_clone = tokens_clone[:, 1:step+2] # skip the first index, which is EOS
+tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
@@ -198,7 +200,7 @@ class SequenceGenerator(object):
# normalize sentence-level scores
if self.normalize_scores:
-eos_scores /= (step+1)**self.len_penalty
+eos_scores /= (step + 1) ** self.len_penalty
cum_unfin = []
prev = 0
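The normalization a few lines up divides a finished hypothesis's cumulative log-probability by its length raised to `len_penalty`, so longer hypotheses are not punished simply for accumulating more negative log-prob terms. A tiny sketch with hypothetical numbers:

```python
# cumulative log-prob at the step where EOS was emitted; the hypothesis
# then contains step + 1 target tokens
eos_score, step, len_penalty = -6.0, 4, 1.0

normalized = eos_score / (step + 1) ** len_penalty  # -6.0 / 5 = -1.2
```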
@@ -216,11 +218,17 @@ class SequenceGenerator(object):
sents_seen.add((sent, unfin_idx))
def get_hypo():
-_, alignment = attn_clone[i].max(dim=0)
+# remove padding tokens from attn scores
+nonpad_idxs = src_tokens[sent].ne(self.pad)
+hypo_attn = attn_clone[i][nonpad_idxs]
+_, alignment = hypo_attn.max(dim=0)
return {
'tokens': tokens_clone[i],
'score': score,
-'attention': attn_clone[i], # src_len x tgt_len
+'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'positional_scores': pos_scores[i],
}
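This hunk is the substance of the commit. With uneven batches the source side is padded (on the left here), and the attention tensor keeps rows for those pad positions; taking `max(dim=0)` over all rows can therefore align a target token to padding. Filtering rows with `src_tokens[sent].ne(self.pad)` before the argmax fixes that. A self-contained sketch of the before/after behaviour, using made-up tensors rather than the generator's real state:

```python
import torch

pad = 1
# a left-padded source sentence: two pads, then three real tokens
src_tokens = torch.tensor([pad, pad, 7, 8, 9])

# attention over source positions (rows) for three target steps (columns)
attn = torch.tensor([
    [0.70, 0.20, 0.10],   # pad position
    [0.10, 0.50, 0.20],   # pad position
    [0.10, 0.20, 0.10],
    [0.05, 0.05, 0.50],
    [0.05, 0.05, 0.10],
])

# old behaviour: the argmax can land on a pad row
_, old_alignment = attn.max(dim=0)        # tensor([0, 1, 3])

# new behaviour: drop pad rows first, then take the argmax
nonpad_idxs = src_tokens.ne(pad)
hypo_attn = attn[nonpad_idxs]             # (num_real_src_tokens, tgt_len)
_, alignment = hypo_attn.max(dim=0)       # tensor([0, 0, 1]), indices into the unpadded source
```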
@@ -263,7 +271,7 @@ class SequenceGenerator(object):
encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
probs, avg_attn_scores = self._decode(
-tokens[:, :step+1], encoder_outs, incremental_states)
+tokens[:, :step + 1], encoder_outs, incremental_states)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
@@ -272,13 +280,13 @@ class SequenceGenerator(object):
scores_buf = scores_buf.type_as(probs)
elif not self.sampling:
# make probs contain cumulative scores for each hypothesis
-probs.add_(scores[:, step-1].view(-1, 1))
+probs.add_(scores[:, step - 1].view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
# Record attention scores
-attn[:, :, step+1].copy_(avg_attn_scores)
+attn[:, :, step + 1].copy_(avg_attn_scores)
cand_scores = buffer('cand_scores', type_of=scores)
cand_indices = buffer('cand_indices')
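The two masking lines in this hunk keep the beam from ever emitting padding and discourage `<unk>`: setting the pad column to `-inf` removes it from any top-k, while subtracting a fixed penalty merely down-weights unknowns. A small illustration with hypothetical values:

```python
import math
import torch

pad, unk, unk_penalty = 1, 3, 2.0

# cumulative log-probs over a toy vocabulary for a few hypotheses
probs = torch.log_softmax(torch.randn(4, 10), dim=-1)

probs[:, pad] = -math.inf      # pad can never be selected
probs[:, unk] -= unk_penalty   # unk is allowed but penalized
```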
@@ -315,7 +323,7 @@ class SequenceGenerator(object):
# make scores cumulative
cand_scores.add_(
torch.gather(
-scores[:, step-1].view(bsz, beam_size), dim=1,
+scores[:, step - 1].view(bsz, beam_size), dim=1,
index=cand_beams,
)
)
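In this hunk each of the `2 * beam_size` candidates already knows which beam it extends (`cand_beams`), so a `gather` along dim 1 pulls that beam's running score from the previous step and adds it in, making `cand_scores` cumulative. A standalone sketch of the idiom with invented numbers:

```python
import torch

# running scores of each beam at the previous step, shape (bsz, beam_size)
prev_scores = torch.tensor([[-1.0, -2.5]])

# log-probs of the newly proposed tokens, shape (bsz, 2 * beam_size)
cand_scores = torch.tensor([[-0.1, -0.3, -0.2, -0.7]])

# which beam each candidate continues, shape (bsz, 2 * beam_size)
cand_beams = torch.tensor([[0, 0, 1, 1]])

# add the originating beam's cumulative score to each candidate
cand_scores.add_(torch.gather(prev_scores, dim=1, index=cand_beams))
print(cand_scores)  # tensor([[-1.1000, -1.3000, -2.7000, -3.2000]])
```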
@@ -406,7 +414,7 @@ class SequenceGenerator(object):
# After, the min values per row are the top candidate active hypos
active_mask = buffer('active_mask')
torch.add(
-eos_mask.type_as(cand_offsets)*cand_size,
+eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[:eos_mask.size(1)],
out=active_mask,
)
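The `active_mask` construction above is a small indexing trick: candidates are already sorted best-first, and adding `cand_size` to every candidate whose token is EOS pushes finished candidates behind all unfinished ones, so the `beam_size` smallest entries per row (taken with `topk(..., largest=False)` further down) are exactly the best still-active hypotheses. A minimal sketch with hypothetical inputs:

```python
import torch

beam_size = 2
cand_size = 2 * beam_size

# which of the (already best-first sorted) candidates end in EOS
eos_mask = torch.tensor([[1, 0, 0, 1]])
cand_offsets = torch.arange(0, cand_size)

# finished candidates get bumped past every unfinished one
active_mask = eos_mask.type_as(cand_offsets) * cand_size + cand_offsets
# tensor([[4, 1, 2, 7]])

# the beam_size smallest entries identify the best unfinished candidates
_, active_hypos = active_mask.topk(k=beam_size, dim=1, largest=False)
print(active_hypos)  # tensor([[1, 2]])
```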
@@ -433,12 +441,12 @@ class SequenceGenerator(object):
# copy tokens and scores for active hypotheses
torch.index_select(
-tokens[:, :step+1], dim=0, index=active_bbsz_idx,
-out=tokens_buf[:, :step+1],
+tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
+out=tokens_buf[:, :step + 1],
)
torch.gather(
cand_indices, dim=1, index=active_hypos,
-out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1],
+out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
@@ -452,8 +460,8 @@ class SequenceGenerator(object):
# copy attention for active hypotheses
torch.index_select(
-attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
-out=attn_buf[:, :, :step+2],
+attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
+out=attn_buf[:, :, :step + 2],
)
# swap buffers
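The last two hunks reorder the per-hypothesis state: `index_select` copies the rows of the surviving hypotheses (given by `active_bbsz_idx`) into preallocated `*_buf` tensors, and the buffers are then swapped with the originals so the next step sees the reordered tokens, scores, and attention without fresh allocations. A condensed sketch of the same copy-then-swap pattern, with hypothetical shapes and plain slice assignment instead of the `out=` form used in the diff:

```python
import torch

bsz, beam_size, buf_len = 2, 2, 6
step = 2

tokens = torch.arange(bsz * beam_size * buf_len).view(bsz * beam_size, buf_len)
tokens_buf = torch.empty_like(tokens)

# flat rows (sentence * beam_size + beam) of the hypotheses kept at this step
active_bbsz_idx = torch.tensor([1, 0, 3, 2])

# copy the prefixes of the surviving hypotheses into the buffer ...
tokens_buf[:, :step + 1] = tokens[:, :step + 1].index_select(0, active_bbsz_idx)

# ... then swap, so `tokens` holds the reordered state for the next step
tokens, tokens_buf = tokens_buf, tokens
```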