Commit 84754894 authored by Louis Martin, committed by Myle Ott

Add attention matrix to output of SequenceGenerator

parent 376c265f
@@ -108,8 +108,8 @@ class SequenceGenerator(object):
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        align = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(-1)
-        align_buf = align.clone()
+        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+        attn_buf = attn.clone()

         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -177,10 +177,12 @@ class SequenceGenerator(object):
             def get_hypo():
                 hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                 hypo[step] = self.eos
-                alignment = align[idx, 1:step+2].clone()
+                attention = attn[idx, :, 1:step+2].clone()
+                _, alignment = attention.max(dim=0)
                 return {
                     'tokens': hypo,
                     'score': score,
+                    'attention': attention,
                     'alignment': alignment,
                 }

@@ -224,9 +226,8 @@ class SequenceGenerator(object):
             probs.add_(scores.view(-1, 1))
             probs[:, self.pad] = -math.inf  # never select pad

-            # record alignment to source tokens, based on attention
-            _ignore_scores = buffer('_ignore_scores', type_of=scores)
-            avg_attn_scores.topk(1, out=(_ignore_scores, align[:, step+1].unsqueeze(1)))
+            # Record attention scores
+            attn[:, :, step+1].copy_(avg_attn_scores)

             # take the best 2 x beam_size predictions. We'll choose the first
             # beam_size of these which don't predict eos to continue with.
@@ -290,17 +291,17 @@ class SequenceGenerator(object):
             cand_indices.gather(1, active_hypos,
                                 out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])

-            # copy attention/alignment for active hypotheses
-            torch.index_select(align[:, :step+2], dim=0, index=active_bbsz_idx,
-                               out=align_buf[:, :step+2])
+            # copy attention for active hypotheses
+            torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
+                               out=attn_buf[:, :, :step+2])

             # swap buffers
             old_tokens = tokens
             tokens = tokens_buf
             tokens_buf = old_tokens
-            old_align = align
-            align = align_buf
-            align_buf = old_align
+            old_attn = attn
+            attn = attn_buf
+            attn_buf = old_attn

             # reorder incremental state in decoder
             reorder_state = active_bbsz_idx
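For readers following the change, here is a minimal standalone sketch, not part of the commit, of how the new attention buffer gets filled step by step and how get_hypo() reduces it to a hard alignment via attention.max(dim=0). The sizes and the dummy decoder scores are hypothetical, for illustration only.

import torch

# Hypothetical sizes, for illustration only.
bsz, beam_size, src_len, maxlen = 2, 2, 5, 6

# One (src_len x maxlen+2) attention matrix per beam item, as allocated above.
attn = torch.zeros(bsz * beam_size, src_len, maxlen + 2)

for step in range(maxlen + 1):
    # Stand-in for the decoder's averaged attention over source tokens.
    avg_attn_scores = torch.softmax(torch.randn(bsz * beam_size, src_len), dim=1)
    # Mirrors: attn[:, :, step+1].copy_(avg_attn_scores)
    attn[:, :, step + 1].copy_(avg_attn_scores)

# Mirrors get_hypo(): slice out one hypothesis' (src_len x tgt_len) attention
# and reduce it to a hard alignment by taking the argmax over source positions.
idx, step = 0, maxlen
attention = attn[idx, :, 1:step + 2].clone()
_, alignment = attention.max(dim=0)  # one source index per target position
print(attention.size(), alignment)   # torch.Size([5, 7]) tensor([...])

The hard alignment is what callers received before this change; keeping the full matrix in the output additionally exposes the soft attention weights for each target token.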