Commit 12258e57 authored by Myle Ott and committed by Facebook Github Bot

Fix generating with a fixed prefix

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/801

Differential Revision: D16628318

Pulled By: myleott

fbshipit-source-id: 50e93bb9108afd2ba90f1edd4f34306a7c9964a4
parent 9012e87d
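For context, this is roughly how prefix-constrained generation is driven from the task side. A minimal sketch, not part of the commit: the `task`, `models`, `args`, and `sample` objects are assumed to already exist (e.g. from a fairseq translation setup), and the token ids are toy values.

    import torch

    # assumed: a configured fairseq task, a list of models, and a batched sample
    generator = task.build_generator(args)

    pad = task.target_dictionary.pad()
    # one row of target-side token ids per sentence; shorter prefixes are
    # right-padded, which is the "different lengths" case this commit handles
    prefix_tokens = torch.tensor([
        [4, 15, 22],
        [4, 15, pad],
    ])

    hypos = task.inference_step(generator, models, sample, prefix_tokens)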
@@ -110,6 +110,7 @@ class GeneratorHubInterface(nn.Module):
         # build generator using current args as well as any kwargs
         gen_args = copy.copy(self.args)
+        gen_args.beam = beam
         for k, v in kwargs.items():
             setattr(gen_args, k, v)
         generator = self.task.build_generator(gen_args)
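The added line makes the `beam` argument passed into `generate()` effective again; without it, `gen_args.beam` silently kept the value from `self.args`. A standalone toy of the same pattern (the function and Namespace here are illustrative, not fairseq code):

    import copy
    from argparse import Namespace

    def build_gen_args(args, beam=5, **kwargs):
        gen_args = copy.copy(args)
        gen_args.beam = beam            # the added line: honor the beam argument
        for k, v in kwargs.items():     # explicit kwargs still override everything
            setattr(gen_args, k, v)
        return gen_args

    args = Namespace(beam=5, lenpen=1.0)
    assert build_gen_args(args, beam=1).beam == 1             # previously stayed 5
    assert build_gen_args(args, beam=1, lenpen=0.6).lenpen == 0.6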
@@ -156,8 +156,6 @@ class SequenceGenerator(object):
         tokens[:, 0] = self.eos if bos_token is None else bos_token
         attn, attn_buf = None, None
         nonpad_idxs = None
-        if prefix_tokens is not None:
-            partial_prefix_mask_buf = torch.zeros_like(src_lengths).byte()
         # The blacklist indicates candidates that should be ignored.
         # For example, suppose we're sampling and have already finalized 2/5
@@ -304,6 +302,35 @@ class SequenceGenerator(object):
             elif step < self.min_len:
                 lprobs[:, self.eos] = -math.inf
+            # handle prefix tokens (possibly with different lengths)
+            if prefix_tokens is not None and step < prefix_tokens.size(1):
+                prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
+                prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
+                prefix_mask = prefix_toks.ne(self.pad)
+                lprobs[prefix_mask] = -math.inf
+                lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
+                    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
+                )
+                # if prefix includes eos, then we should make sure tokens and
+                # scores are the same across all beams
+                eos_mask = prefix_toks.eq(self.eos)
+                if eos_mask.any():
+                    # validate that the first beam matches the prefix
+                    first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
+                    eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
+                    target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
+                    assert (first_beam == target_prefix).all()
+
+                    def replicate_first_beam(tensor, mask):
+                        tensor = tensor.view(-1, beam_size, tensor.size(-1))
+                        tensor[mask] = tensor[mask][:, :1, :]
+                        return tensor.view(-1, tensor.size(-1))
+
+                    # copy tokens, scores and lprobs from the first beam to all beams
+                    tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
+                    scores = replicate_first_beam(scores, eos_mask_batch_dim)
+                    lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)
+
             if self.no_repeat_ngram_size > 0:
                 # for each beam and batch sentence, generate a list of previous ngrams
                 gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
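The gather/mask/scatter dance in the added block can be seen in isolation: at a prefix step, every non-pad row of `lprobs` is forced to -inf everywhere except at the prefix token, which keeps its original score, so the search has no choice but to select it. A runnable toy demonstration (sizes and the pad id are made up):

    import math
    import torch

    bsz, beam_size, vocab, pad = 2, 2, 6, 1
    lprobs = torch.log_softmax(torch.randn(bsz * beam_size, vocab), dim=-1)

    # token forced at this step for each sentence; pad means "no constraint"
    step_tokens = torch.tensor([3, pad])
    prefix_toks = step_tokens.unsqueeze(-1).repeat(1, beam_size).view(-1)

    prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
    prefix_mask = prefix_toks.ne(pad)
    lprobs[prefix_mask] = -math.inf
    lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
        -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
    )

    assert lprobs[0].argmax().item() == 3   # constrained rows can only pick token 3
    assert torch.isfinite(lprobs[2]).all()  # padded (unconstrained) rows are untouched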
@@ -343,43 +370,6 @@ class SequenceGenerator(object):
                 for bbsz_idx in range(bsz * beam_size):
                     lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
-            if prefix_tokens is not None and step < prefix_tokens.size(1):
-                assert isinstance(self.search, search.BeamSearch) or bsz == 1, \
-                    "currently only BeamSearch supports decoding with prefix_tokens"
-                probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
-                cand_scores = torch.gather(
-                    probs_slice, dim=1,
-                    index=prefix_tokens[:, step].view(-1, 1)
-                ).view(-1, 1).repeat(1, cand_size)
-                if step > 0:
-                    # save cumulative scores for each hypothesis
-                    cand_scores.add_(scores[:, step - 1].view(bsz, beam_size).repeat(1, 2))
-                cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
-                cand_beams = torch.zeros_like(cand_indices)
-            # handle prefixes of different lengths
-            # when step == prefix_tokens.size(1), we'll have new free-decoding batches
-            if prefix_tokens is not None and step <= prefix_tokens.size(1):
-                if step < prefix_tokens.size(1):
-                    partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
-                else:  # all prefixes finished force-decoding
-                    partial_prefix_mask = torch.ones(bsz).to(prefix_tokens).byte()
-                if partial_prefix_mask.any():
-                    # track new free-decoding batches, at whose very first step
-                    # only use the first beam to eliminate repeats
-                    prefix_step0_mask = partial_prefix_mask ^ partial_prefix_mask_buf
-                    lprobs.view(bsz, beam_size, -1)[prefix_step0_mask, 1:] = -math.inf
-                    partial_scores, partial_indices, partial_beams = self.search.step(
-                        step,
-                        lprobs.view(bsz, -1, self.vocab_size),
-                        scores.view(bsz, beam_size, -1)[:, :, :step],
-                    )
-                    cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
-                    cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
-                    cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
-                partial_prefix_mask_buf = partial_prefix_mask
-            else:
             cand_scores, cand_indices, cand_beams = self.search.step(
                 step,
                 lprobs.view(bsz, -1, self.vocab_size),
@@ -433,7 +423,6 @@ class SequenceGenerator(object):
                 cand_indices = cand_indices[batch_idxs]
                 if prefix_tokens is not None:
                     prefix_tokens = prefix_tokens[batch_idxs]
-                    partial_prefix_mask_buf = partial_prefix_mask_buf[batch_idxs]
                 src_lengths = src_lengths[batch_idxs]
                 blacklist = blacklist[batch_idxs]
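To make the eos branch concrete, here is `replicate_first_beam` (verbatim from the added hunk) run on toy tensors; the beam size and token values are illustrative. For each batch whose prefix ended in eos, every beam's history is overwritten with beam 0's, which is what keeps tokens and scores identical across beams:

    import torch

    beam_size = 2

    def replicate_first_beam(tensor, mask):
        tensor = tensor.view(-1, beam_size, tensor.size(-1))
        tensor[mask] = tensor[mask][:, :1, :]
        return tensor.view(-1, tensor.size(-1))

    tokens = torch.tensor([
        [2, 4, 5],   # batch 0, beam 0
        [2, 9, 9],   # batch 0, beam 1 (diverged)
        [2, 7, 8],   # batch 1, beam 0
        [2, 7, 6],   # batch 1, beam 1
    ])
    mask = torch.tensor([True, False])  # only batch 0 hit eos in its prefix
    print(replicate_first_beam(tokens, mask))
    # tensor([[2, 4, 5],
    #         [2, 4, 5],   <- beam 1 of batch 0 replaced by beam 0
    #         [2, 7, 8],
    #         [2, 7, 6]])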