Commit 7f538f54 authored by Myle Ott

Fix --prefix-size

parent 2a681d99
@@ -236,7 +236,7 @@ def add_generation_args(parser):
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
-                       help=('initialize generation by target prefix of given length'))
+                       help='initialize generation by target prefix of given length')
     group.add_argument('--sampling', action='store_true',
                        help='sample hypotheses instead of using beam search')
     return group
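For context, a minimal sketch (outside fairseq; the argument group name and the parsed test value are illustrative assumptions) of how the `--prefix-size` flag defined above is registered and consumed:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('Generation')  # illustrative group name
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                   help='initialize generation by target prefix of given length')

args = parser.parse_args(['--prefix-size', '2'])  # assumed example value
print(args.prefix_size)  # 2: force the first 2 target tokens of each hypothesis
```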
@@ -322,7 +322,6 @@ class SequenceGenerator(object):
             else:
                 # take the best 2 x beam_size predictions. We'll choose the first
                 # beam_size of these which don't predict eos to continue with.
                 torch.topk(
                     probs.view(bsz, -1),
                     k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
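The `torch.topk` call above selects candidates over the flattened (beam, vocab) axis. A self-contained sketch of that selection step, with assumed toy shapes (`bsz`, `beam_size`, and `vocab` are illustrative; the beam/token decoding of the flat indices mirrors, but is not copied from, the surrounding code):

```python
import torch

bsz, beam_size, vocab = 2, 3, 10          # assumed toy dimensions
cand_size = 2 * beam_size                 # 2x beam so eos candidates can be skipped later
probs = torch.rand(bsz * beam_size, vocab)

flat = probs.view(bsz, -1)                # (bsz, beam_size * vocab)
cand_scores, cand_indices = torch.topk(
    flat,
    k=min(cand_size, flat.size(1) - 1),   # -1 so we never select pad
)
cand_beams = cand_indices // vocab        # which beam each candidate extends
cand_tokens = cand_indices % vocab        # which token it predicts
print(cand_scores.shape)                  # torch.Size([2, 6])
```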
@@ -375,10 +374,9 @@ class SequenceGenerator(object):
             assert step < maxlen

             if len(finalized_sents) > 0:
-                # construct batch_idxs which holds indices of batches to keep for the next pass
                 new_bsz = bsz - len(finalized_sents)

                 # construct batch_idxs which holds indices of batches to keep for the next pass
                 batch_mask = torch.ones(bsz).type_as(cand_indices)
                 batch_mask[cand_indices.new(finalized_sents)] = 0
                 batch_idxs = batch_mask.nonzero().squeeze(-1)
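The mask-and-nonzero pattern above shrinks the batch once some sentences finish. A runnable sketch with assumed values (`bsz` and `finalized_sents` are illustrative):

```python
import torch

bsz = 4
finalized_sents = [1, 3]                      # assumed: sentences that just finished
new_bsz = bsz - len(finalized_sents)

batch_mask = torch.ones(bsz, dtype=torch.long)
batch_mask[torch.tensor(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1)
print(batch_idxs)                             # tensor([0, 2]): batches kept for the next pass
assert batch_idxs.numel() == new_bsz
```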
@@ -390,6 +388,8 @@ class SequenceGenerator(object):
                 cand_scores = cand_scores[batch_idxs]
                 cand_indices = cand_indices[batch_idxs]
+                if prefix_tokens is not None:
+                    prefix_tokens = prefix_tokens[batch_idxs]
                 scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                 scores_buf.resize_as_(scores)
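The two added lines are the actual fix: `prefix_tokens` is a per-sentence tensor, so when finished sentences are dropped it must be gathered with `batch_idxs` just like `cand_scores` and `cand_indices`, or the surviving sentences would be forced with another sentence's prefix. A sketch with assumed tensor contents:

```python
import torch

prefix_tokens = torch.tensor([[5, 6],         # assumed prefixes, one row per sentence
                              [7, 8],
                              [9, 10],
                              [11, 12]])
batch_idxs = torch.tensor([0, 2])             # assumed survivors after finalization

if prefix_tokens is not None:
    prefix_tokens = prefix_tokens[batch_idxs]
print(prefix_tokens)                          # tensor([[ 5,  6], [ 9, 10]])
```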