"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "cd2b7f071cbfb207400e70deaf4b1b2c46c3e47c"
Commit 7f538f54 authored by Myle Ott

Fix --prefix-size

parent 2a681d99
@@ -236,7 +236,7 @@ def add_generation_args(parser):
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
-                       help=('initialize generation by target prefix of given length'))
+                       help='initialize generation by target prefix of given length')
     group.add_argument('--sampling', action='store_true',
                        help='sample hypotheses instead of using beam search')
     return group
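
Note: with --prefix-size PS, generation is seeded with the first PS tokens of each reference target before free decoding begins. A minimal, hypothetical sketch of that semantics (not fairseq's actual wiring):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                        help='initialize generation by target prefix of given length')
    args = parser.parse_args(['--prefix-size', '2'])

    # hypothetical tokenized reference; only the first PS tokens seed the decoder
    target = ['Hello', 'world', '!', '</s>']
    prefix_tokens = target[:args.prefix_size] if args.prefix_size > 0 else None
    print(prefix_tokens)  # ['Hello', 'world']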
@@ -322,7 +322,6 @@ class SequenceGenerator(object):
             else:
                 # take the best 2 x beam_size predictions. We'll choose the first
                 # beam_size of these which don't predict eos to continue with.
                 torch.topk(
                     probs.view(bsz, -1),
                     k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
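
Note: the k=min(cand_size, ... - 1) cap is what the inline comment refers to: probs is flattened to one row per sentence, and keeping k strictly below the row width leaves room to skip the pad entries, whose scores are assumed to be -inf. A runnable sketch with hypothetical sizes:

    import torch

    bsz, beam_size, vocab_size, pad = 2, 3, 5, 0   # hypothetical sizes
    cand_size = 2 * beam_size
    probs = torch.randn(bsz * beam_size, vocab_size)
    probs[:, pad] = -float('inf')   # assumed: pad is never a valid prediction

    flat = probs.view(bsz, -1)      # (bsz, beam_size * vocab_size)
    cand_scores, cand_indices = torch.topk(
        flat, k=min(cand_size, flat.size(1) - 1))  # -1 so we never select pad
    cand_beams = cand_indices // vocab_size   # which beam each candidate extends
    cand_tokens = cand_indices % vocab_size   # which token it predicts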
@@ -375,10 +374,9 @@ class SequenceGenerator(object):
                 assert step < maxlen
                 if len(finalized_sents) > 0:
-                    # construct batch_idxs which holds indices of batches to keep for the next pass
                     new_bsz = bsz - len(finalized_sents)
+                    # construct batch_idxs which holds indices of batches to keep for the next pass
                     batch_mask = torch.ones(bsz).type_as(cand_indices)
                     batch_mask[cand_indices.new(finalized_sents)] = 0
                     batch_idxs = batch_mask.nonzero().squeeze(-1)
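
To see the batch shrinking in isolation, here is a runnable toy version of the three lines above, with made-up values: sentences 0 and 2 finish this step, so only batches 1 and 3 survive.

    import torch

    bsz = 4
    finalized_sents = [0, 2]   # hypothetical sentences that just finished
    cand_indices = torch.zeros(bsz, dtype=torch.long)   # stand-in for the real tensor

    batch_mask = torch.ones(bsz).type_as(cand_indices)
    batch_mask[cand_indices.new(finalized_sents)] = 0   # zero out finished sentences
    batch_idxs = batch_mask.nonzero().squeeze(-1)
    print(batch_idxs)  # tensor([1, 3])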
@@ -390,6 +388,8 @@ class SequenceGenerator(object):
                 cand_scores = cand_scores[batch_idxs]
                 cand_indices = cand_indices[batch_idxs]
+                if prefix_tokens is not None:
+                    prefix_tokens = prefix_tokens[batch_idxs]
                 scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                 scores_buf.resize_as_(scores)
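
The two added lines are the fix itself: when finished sentences are dropped, every per-sentence tensor must be gathered with batch_idxs, and prefix_tokens previously was not, so its rows drifted out of alignment with the surviving sentences. A toy illustration with hypothetical values:

    import torch

    prefix_tokens = torch.tensor([[11, 12],   # sentence 0 (finished, dropped)
                                  [21, 22],   # sentence 1 (kept)
                                  [31, 32],   # sentence 2 (finished, dropped)
                                  [41, 42]])  # sentence 3 (kept)
    batch_idxs = torch.tensor([1, 3])

    # without this gather, row i would no longer refer to the i-th surviving sentence
    if prefix_tokens is not None:
        prefix_tokens = prefix_tokens[batch_idxs]
    print(prefix_tokens)  # tensor([[21, 22], [41, 42]])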