Commit 4abadbdf authored by Myle Ott, committed by Facebook Github Bot

Fix sampling with beam>1

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/792

Differential Revision: D16591987

Pulled By: myleott

fbshipit-source-id: d27c490ae75f80ded19226b8384f4776485dd694
parent 5b2be870
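For context on what this commit enables, here is a minimal, hedged sketch of drawing multiple independent samples per sentence. It assumes the `sampling` and `sampling_topk` keyword arguments that SequenceGenerator exposed around this time; `tgt_dict`, `model`, and `sample` are placeholders, not names taken from this diff.

    from fairseq.sequence_generator import SequenceGenerator

    # Assumed kwargs: `sampling` / `sampling_topk` (check the SequenceGenerator
    # signature in your fairseq version before relying on them).
    generator = SequenceGenerator(
        tgt_dict,
        beam_size=5,       # with sampling, each "beam" slot is an independent sample
        sampling=True,     # sample from the model distribution instead of beam search
        sampling_topk=10,  # optionally restrict sampling to the 10 most likely tokens
    )
    hypos = generator.generate([model], sample)
    # hypos[i] holds up to 5 finalized samples for sentence i; the blacklist
    # introduced below keeps already-finalized samples from being finalized again.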
@@ -448,9 +448,7 @@ def add_generation_args(parser):
group.add_argument('--match-source-len', default=False, action='store_true',
help=('generations should match the source length'))
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
help='deprecated')
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--no-beamable-mm', action='store_true',
@@ -25,7 +25,7 @@ class Search(object):
self.indices_buf = torch.LongTensor().to(device=t.device)
self.beams_buf = torch.LongTensor().to(device=t.device)
def step(self, step, lprobs, scores, beam_size):
def step(self, step, lprobs, scores):
"""Take a single search step.
Args:
@@ -19,7 +19,6 @@ class SequenceGenerator(object):
max_len_a=0,
max_len_b=200,
min_len=1,
stop_early=True,
normalize_scores=True,
len_penalty=1.,
unk_penalty=0.,
@@ -42,9 +41,6 @@ class SequenceGenerator(object):
ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
stop_early (bool, optional): stop generation immediately after we
finalize beam_size hypotheses, even though longer hypotheses
might have better normalized scores (default: True)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
len_penalty (float, optional): length penalty, where <1.0 favors
@@ -78,7 +74,6 @@ class SequenceGenerator(object):
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
@@ -156,7 +151,7 @@ class SequenceGenerator(object):
# initialize buffers
scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
scores_buf = scores.clone()
tokens = src_tokens.data.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
tokens_buf = tokens.clone()
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn, attn_buf = None, None
@@ -164,10 +159,15 @@ class SequenceGenerator(object):
if prefix_tokens is not None:
partial_prefix_mask_buf = torch.zeros_like(src_lengths).byte()
# The blacklist indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then the blacklist would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
blacklist = src_tokens.new(bsz, beam_size).byte().fill_(0)
# list of completed sentences
finalized = [[] for i in range(bsz)]
finished = [False for i in range(bsz)]
worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
num_remaining_sent = bsz
# number of candidate hypos per step
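To make the blacklist idea above concrete, here is a toy, self-contained illustration (not part of the diff) of how blacklisted sample slots are masked out of the eos candidates, mirroring the `eos_mask[:, :beam_size][blacklist] = 0` line further down. All values are made up, and bool masks are used here for brevity where the diff uses uint8.

    import torch

    beam_size = 5
    # eos landed in the top candidates for sample slots 0, 1 and 3
    eos_mask = torch.tensor([[True, True, False, True, False]])
    # slots 0 and 1 were already finalized on an earlier step, so ignore them
    blacklist = torch.tensor([[True, True, False, False, False]])

    eos_mask[:, :beam_size][blacklist] = False
    print(eos_mask)  # only slot 3 remains True

Only slot 3 is finalized on this step, leaving the remaining slots free to keep sampling.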
@@ -185,7 +185,7 @@ class SequenceGenerator(object):
buffers[name] = type_of.new()
return buffers[name]
def is_finished(sent, step, unfin_idx, unfinalized_scores=None):
def is_finished(sent, step, unfin_idx):
"""
Check whether we've finished generation for a given sentence, by
comparing the worst score among finalized hypotheses to the best
@@ -193,18 +193,10 @@ class SequenceGenerator(object):
"""
assert len(finalized[sent]) <= beam_size
if len(finalized[sent]) == beam_size:
if self.stop_early or step == max_len or unfinalized_scores is None:
return True
# stop if the best unfinalized score is worse than the worst
# finalized one
best_unfinalized_score = unfinalized_scores[unfin_idx].max()
if self.normalize_scores:
best_unfinalized_score /= max_len ** self.len_penalty
if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True
return True
return False
def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
def finalize_hypos(step, bbsz_idx, eos_scores):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
@@ -219,14 +211,13 @@ class SequenceGenerator(object):
indicating which hypotheses to finalize
eos_scores: A vector of the same size as bbsz_idx containing
scores for each hypothesis
unfinalized_scores: A vector containing scores for all
unfinalized hypotheses
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
assert not tokens_clone.eq(self.eos).any()
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None
@@ -278,23 +269,11 @@ class SequenceGenerator(object):
if len(finalized[sent]) < beam_size:
finalized[sent].append(get_hypo())
elif not self.stop_early and score > worst_finalized[sent]['score']:
# replace worst hypo for this sentence with new/better one
worst_idx = worst_finalized[sent]['idx']
if worst_idx is not None:
finalized[sent][worst_idx] = get_hypo()
# find new worst finalized hypo for this sentence
idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
worst_finalized[sent] = {
'score': s['score'],
'idx': idx,
}
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfin_idx, unfinalized_scores):
if not finished[sent] and is_finished(sent, step, unfin_idx):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
@@ -318,6 +297,13 @@ class SequenceGenerator(object):
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
# handle min and max length constraints
if step >= max_len:
lprobs[:, :self.eos] = -math.inf
lprobs[:, self.eos + 1:] = -math.inf
elif step < self.min_len:
lprobs[:, self.eos] = -math.inf
if self.no_repeat_ngram_size > 0:
# for each beam and batch sentence, generate a list of previous ngrams
gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
@@ -339,105 +325,92 @@ class SequenceGenerator(object):
scores_buf = scores_buf.type_as(lprobs)
eos_bbsz_idx = buffer('eos_bbsz_idx')
eos_scores = buffer('eos_scores', type_of=scores)
if step < max_len:
self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0:
def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
return gen_ngrams[bbsz_idx].get(ngram_index, [])
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
else:
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
if prefix_tokens is not None and step < prefix_tokens.size(1):
assert isinstance(self.search, search.BeamSearch) or bsz == 1, \
"currently only BeamSearch supports decoding with prefix_tokens"
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1)
).view(-1, 1).repeat(1, cand_size)
if step > 0:
# save cumulative scores for each hypothesis
cand_scores.add_(scores[:, step - 1].view(bsz, beam_size).repeat(1, 2))
cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
cand_beams = torch.zeros_like(cand_indices)
# handle prefixes of different lengths
# when step == prefix_tokens.size(1), we'll have new free-decoding batches
if prefix_tokens is not None and step <= prefix_tokens.size(1):
if step < prefix_tokens.size(1):
partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
else: # all prefixes finished force-decoding
partial_prefix_mask = torch.ones(bsz).to(prefix_tokens).byte()
if partial_prefix_mask.any():
# track new free-decoding batches, at whose very first step
# only use the first beam to eliminate repeats
prefix_step0_mask = partial_prefix_mask ^ partial_prefix_mask_buf
lprobs.view(bsz, beam_size, -1)[prefix_step0_mask, 1:] = -math.inf
partial_scores, partial_indices, partial_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
partial_prefix_mask_buf = partial_prefix_mask
self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0:
def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
return gen_ngrams[bbsz_idx].get(ngram_index, [])
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
else:
cand_scores, cand_indices, cand_beams = self.search.step(
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
if prefix_tokens is not None and step < prefix_tokens.size(1):
assert isinstance(self.search, search.BeamSearch) or bsz == 1, \
"currently only BeamSearch supports decoding with prefix_tokens"
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1)
).view(-1, 1).repeat(1, cand_size)
if step > 0:
# save cumulative scores for each hypothesis
cand_scores.add_(scores[:, step - 1].view(bsz, beam_size).repeat(1, 2))
cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
cand_beams = torch.zeros_like(cand_indices)
# handle prefixes of different lengths
# when step == prefix_tokens.size(1), we'll have new free-decoding batches
if prefix_tokens is not None and step <= prefix_tokens.size(1):
if step < prefix_tokens.size(1):
partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
else: # all prefixes finished force-decoding
partial_prefix_mask = torch.ones(bsz).to(prefix_tokens).byte()
if partial_prefix_mask.any():
# track new free-decoding batches, at whose very first step
# only use the first beam to eliminate repeats
prefix_step0_mask = partial_prefix_mask ^ partial_prefix_mask_buf
lprobs.view(bsz, beam_size, -1)[prefix_step0_mask, 1:] = -math.inf
partial_scores, partial_indices, partial_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
partial_prefix_mask_buf = partial_prefix_mask
else:
# make probs contain cumulative scores for each hypothesis
lprobs.add_(scores[:, step - 1].unsqueeze(-1))
# finalize all active hypotheses once we hit max_len
# pick the hypothesis with the highest prob of EOS right now
torch.sort(
lprobs[:, self.eos],
descending=True,
out=(eos_scores, eos_bbsz_idx),
cand_scores, cand_indices, cand_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
# finalize hypotheses that end in eos (except for blacklisted ones)
eos_mask = cand_indices.eq(self.eos)
eos_mask[:, :beam_size][blacklist] = 0
# only consider eos when it's among the top beam_size indices
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_bbsz_idx,
)
finalized_sents = set()
if step >= self.min_len:
# only consider eos when it's among the top beam_size indices
if eos_bbsz_idx.numel() > 0:
torch.masked_select(
cand_bbsz_idx[:, :beam_size],
cand_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_bbsz_idx,
out=eos_scores,
)
if eos_bbsz_idx.numel() > 0:
torch.masked_select(
cand_scores[:, :beam_size],
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents)
finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
@@ -462,6 +435,7 @@ class SequenceGenerator(object):
prefix_tokens = prefix_tokens[batch_idxs]
partial_prefix_mask_buf = partial_prefix_mask_buf[batch_idxs]
src_lengths = src_lengths[batch_idxs]
blacklist = blacklist[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
@@ -474,10 +448,12 @@ class SequenceGenerator(object):
else:
batch_idxs = None
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
# Set active_mask so that values > cand_size indicate eos or
# blacklisted hypos and values < cand_size indicate candidate
# active hypos. After this, the min values per row are the top
# candidate active hypos.
active_mask = buffer('active_mask')
eos_mask[:, :beam_size] |= blacklist
torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[:eos_mask.size(1)],
@@ -486,12 +462,16 @@ class SequenceGenerator(object):
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist')
torch.topk(
active_mask, k=beam_size, dim=1, largest=False,
out=(_ignore, active_hypos)
out=(new_blacklist, active_hypos)
)
# update blacklist to ignore any finalized hypos
blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
assert (~blacklist).any(dim=1).all()
active_bbsz_idx = buffer('active_bbsz_idx')
torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos,
@@ -193,7 +193,6 @@ class FairseqTask(object):
max_len_a=getattr(args, 'max_len_a', 0),
max_len_b=getattr(args, 'max_len_b', 200),
min_len=getattr(args, 'min_len', 1),
stop_early=(not getattr(args, 'no_early_stop', False)),
normalize_scores=(not getattr(args, 'unnormalized', False)),
len_penalty=getattr(args, 'lenpen', 1),
unk_penalty=getattr(args, 'unkpen', 0),
@@ -137,23 +137,6 @@ class TestSequenceGenerator(TestSequenceGeneratorBase):
self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
def test_no_stop_early(self):
generator = SequenceGenerator(self.tgt_dict, stop_early=False, beam_size=2)
hypos = generator.generate([self.model], self.sample)
eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0])
# sentence 1, beam 2
self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
# sentence 2, beam 1
self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
# sentence 2, beam 2
self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
class TestDiverseBeamSearch(TestSequenceGeneratorBase):