Commit 9dc9a486 authored by yilinyang7, committed by Facebook Github Bot

when given prefix_tokens, sequence generator would generate (exactly) same finished candidates (#713)

Summary:
With prefix_tokens, every beam is force-decoded through the same prefix, so when free decoding begins all beams hold identical hypotheses and beam search could return the same finished candidate repeated beam_size times. This change masks out all but the first beam at each sentence's first free-decoding step, eliminating the repeats.
https://github.com/pytorch/fairseq/issues/712
Pull Request resolved: https://github.com/pytorch/fairseq/pull/713

Differential Revision: D15242432

Pulled By: myleott

fbshipit-source-id: a230ee48f4bf891c805609c428d7233a0ad21179
parent ee8bcb17
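
For context, a minimal usage sketch (the setup names are illustrative, not part of this commit): prefix_tokens is a (bsz, prefix_len) LongTensor of forced target tokens, right-padded when sentences have prefixes of different lengths, and is passed straight to generate:

    # Hedged sketch, fairseq ~0.6-era API; `models` and `sample` are assumed to come
    # from the usual task/checkpoint setup and are not constructed here.
    import torch

    prefix_tokens = torch.tensor([
        [10, 11, 12],   # sentence 0: force a 3-token prefix
        [10, 11,  1],   # sentence 1: 2-token prefix, right-padded with pad index 1
    ])
    # hypos = generator.generate(models, sample, prefix_tokens=prefix_tokens)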
fairseq/sequence_generator.py

@@ -158,6 +158,8 @@ class SequenceGenerator(object):
         tokens[:, 0] = bos_token or self.eos
         attn, attn_buf = None, None
         nonpad_idxs = None
+        if prefix_tokens is not None:
+            partial_prefix_mask_buf = torch.zeros_like(src_lengths).byte()

         # list of completed sentences
         finalized = [[] for i in range(bsz)]
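
The new buffer is one byte flag per batch sentence, initialised to zero, meaning "this sentence has not yet left force-decoding". A toy illustration (example values are made up):

    import torch

    src_lengths = torch.tensor([7, 5, 9])    # per-sentence source lengths
    partial_prefix_mask_buf = torch.zeros_like(src_lengths).byte()
    print(partial_prefix_mask_buf)           # tensor([0, 0, 0], dtype=torch.uint8)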
@@ -353,6 +355,8 @@ class SequenceGenerator(object):
                         lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf

             if prefix_tokens is not None and step < prefix_tokens.size(1):
+                assert isinstance(self.search, search.BeamSearch), \
+                    "currently only BeamSearch supports decoding with prefix_tokens"
                 probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                 cand_scores = torch.gather(
                     probs_slice, dim=1,
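
The assert fails fast for search strategies the prefix logic does not support. A stand-alone sketch of the guard firing, with stand-in classes (fairseq's real strategies live in fairseq/search.py):

    # Stand-ins for fairseq's search strategies, just to show the guard pattern.
    class BeamSearch:
        pass

    class Sampling:
        pass

    strategy = Sampling()
    assert isinstance(strategy, BeamSearch), \
        "currently only BeamSearch supports decoding with prefix_tokens"   # raises AssertionError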
@@ -364,9 +368,18 @@ class SequenceGenerator(object):
                 cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
                 cand_beams = torch.zeros_like(cand_indices)

-                # handle prefixes of different lengths
-                partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
+            # handle prefixes of different lengths
+            # when step == prefix_tokens.size(1), we'll have new free-decoding batches
+            if prefix_tokens is not None and step <= prefix_tokens.size(1):
+                if step < prefix_tokens.size(1):
+                    partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
+                else:  # all prefixes finished force-decoding
+                    partial_prefix_mask = torch.ones(bsz).to(prefix_tokens).byte()
                 if partial_prefix_mask.any():
+                    # track new free-decoding batches, at whose very first step
+                    # only use the first beam to eliminate repeats
+                    prefix_step0_mask = partial_prefix_mask ^ partial_prefix_mask_buf
+                    lprobs.view(bsz, beam_size, -1)[prefix_step0_mask, 1:] = -math.inf
                     partial_scores, partial_indices, partial_beams = self.search.step(
                         step,
                         lprobs.view(bsz, -1, self.vocab_size),
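
The XOR line is the heart of the fix: a sentence is in partial_prefix_mask this step but was not in partial_prefix_mask_buf last step exactly when it has just exhausted its prefix. Since every beam then holds the identical forced prefix, all but beam 0 are masked to -inf, so the first free-decoding step cannot produce beam_size copies of the same candidate. A toy run (bool masks used here for modern PyTorch; the commit-era code used byte tensors):

    import math
    import torch

    bsz, beam_size, vocab = 2, 3, 5
    lprobs = torch.zeros(bsz * beam_size, vocab)

    partial_prefix_mask_buf = torch.tensor([False, False])  # last step: both sentences mid-prefix
    partial_prefix_mask = torch.tensor([True, False])       # this step: sentence 0 ran out of prefix

    # sentences that switched from force-decoding to free decoding this step
    prefix_step0_mask = partial_prefix_mask ^ partial_prefix_mask_buf

    # keep only beam 0 for those sentences; beams 1.. are knocked out entirely
    lprobs.view(bsz, beam_size, -1)[prefix_step0_mask, 1:] = -math.inf
    print(lprobs.view(bsz, beam_size, -1)[0, 1:])           # all -inf for sentence 0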
@@ -375,6 +388,8 @@ class SequenceGenerator(object):
                     cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
                     cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
                     cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
+                    partial_prefix_mask_buf = partial_prefix_mask
+
             else:
                 cand_scores, cand_indices, cand_beams = self.search.step(
                     step,
@@ -442,6 +457,7 @@ class SequenceGenerator(object):
                 cand_indices = cand_indices[batch_idxs]
                 if prefix_tokens is not None:
                     prefix_tokens = prefix_tokens[batch_idxs]
+                    partial_prefix_mask_buf = partial_prefix_mask_buf[batch_idxs]
                 src_lengths = src_lengths[batch_idxs]
                 scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
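
When finished sentences are dropped from the batch, every per-sentence tensor must be filtered with the surviving indices, and the new mask buffer is no exception. Toy illustration:

    import torch

    partial_prefix_mask_buf = torch.tensor([True, False, True])
    batch_idxs = torch.tensor([0, 2])        # sentences 0 and 2 are still being decoded
    partial_prefix_mask_buf = partial_prefix_mask_buf[batch_idxs]
    print(partial_prefix_mask_buf)           # tensor([True, True])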
fairseq/tasks/language_modeling.py

@@ -225,7 +225,7 @@ class LanguageModelingTask(FairseqTask):
     def inference_step(self, generator, models, sample, prefix_tokens=None):
         with torch.no_grad():
-            if prefix_tokens is None:
+            if prefix_tokens is None and sample['net_input']['src_tokens'].nelement():
                 # note: EOS has already been removed in build_dataset_for_inference
                 prefix_tokens = sample['net_input']['src_tokens']
             return generator.generate(models, sample, prefix_tokens=prefix_tokens)
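
The extra nelement() check keeps an empty src_tokens tensor (e.g. unconditional generation, where nothing precedes the first token) from being used as a prefix, since zero elements makes the condition falsy:

    import torch

    src_tokens = torch.empty(1, 0, dtype=torch.long)  # batch of 1, zero-length input
    print(src_tokens.nelement())                      # 0 -> condition is False, prefix_tokens stays None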