Commit 1b937bb2 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add exception for bsz=1 with prefix generation (#796)

Summary:
This is a temporary workaround to support sampling after https://github.com/pytorch/fairseq/issues/713. We'll need to revisit this to support sampling and beam more generally.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/796

Differential Revision: D15760808

Pulled By: myleott

fbshipit-source-id: ecaf4f161b0c30de037f32007e4610a559a49230
parent a8f28ecb
......@@ -355,7 +355,7 @@ class SequenceGenerator(object):
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
if prefix_tokens is not None and step < prefix_tokens.size(1):
assert isinstance(self.search, search.BeamSearch), \
assert isinstance(self.search, search.BeamSearch) or bsz == 1, \
"currently only BeamSearch supports decoding with prefix_tokens"
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment