"vscode:/vscode.git/clone" did not exist on "8de78001df95a641bf6ef942bee9553921d44490"
Commit 866b27d5 authored by Dario Pavllo, committed by Sergey Edunov

Add support for prefixes (#221)

* Add prefix

* Fixes

* Keep original scores with prefix

* Improve prefix code

* Replace 'repeat' with 'expand'
parent 0d90e35f
@@ -226,6 +226,8 @@ def add_generation_args(parser):
                        help='only print final scores')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
+    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
+                       help=('initialize generation by target prefix of given length'))
     return group
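
For reference, the new flag can be exercised on its own with argparse. A minimal sketch, assuming only the add_argument call introduced above (the surrounding parser setup is illustrative, not part of this diff):

# Minimal parsing sketch; everything except the add_argument call itself
# is an assumption for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                    help='initialize generation by target prefix of given length')

args = parser.parse_args(['--prefix-size', '2'])
print(args.prefix_size)  # 2; the default of 0 disables prefix forcing

In practice this means passing --prefix-size N to the generation script (generate.py, patched in the last hunk below) to seed each hypothesis with the first N reference tokens.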
@@ -51,7 +51,7 @@ class SequenceGenerator(object):
         return self

     def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-                             cuda=False, timer=None):
+                             cuda=False, timer=None, prefix_size=0):
         """Iterate over a batched dataset and yield individual translations.

         Args:
@@ -75,6 +75,7 @@ class SequenceGenerator(object):
                 input['src_lengths'],
                 beam_size=beam_size,
                 maxlen=int(maxlen_a*srclen + maxlen_b),
+                prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
             )
             if timer is not None:
                 timer.stop(s['ntokens'])
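
The slice passed above is just the first prefix_size columns of the padded target batch. A minimal sketch of what it produces (toy token ids; the pad index of 1 is an assumption):

import torch

# Toy padded batch of gold targets, shape (bsz=2, tgt_len=4); values assumed.
target = torch.LongTensor([[12,  7,  9,  2],
                           [ 5, 31,  2,  1]])
prefix_size = 2
prefix_tokens = target[:, :prefix_size]  # shape (bsz, prefix_size)
print(prefix_tokens)  # tensor([[12,  7],
                      #         [ 5, 31]])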
@@ -84,16 +85,16 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
         with ExitStack() as stack:
             for model in self.models:
                 if isinstance(model.decoder, FairseqIncrementalDecoder):
                     stack.enter_context(model.decoder.incremental_inference())
             with utils.maybe_no_grad():
-                return self._generate(src_tokens, src_lengths, beam_size, maxlen)
+                return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
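
With prefix_tokens threaded from generate() into _generate(), callers can also seed the beam directly at the API level. A minimal call sketch (the variable names and surrounding setup are assumptions, not from this diff):

# Hypothetical API-level call: force the first 2 gold tokens per sentence.
hypos = translator.generate(
    src_tokens,                   # LongTensor, shape (bsz, srclen)
    src_lengths,                  # LongTensor, shape (bsz,)
    beam_size=5,
    prefix_tokens=target[:, :2],  # first 2 columns of the padded targets
)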
@@ -268,15 +269,25 @@ class SequenceGenerator(object):
             eos_bbsz_idx = buffer('eos_bbsz_idx')
             eos_scores = buffer('eos_scores', type_of=scores)
             if step < maxlen:
-                # take the best 2 x beam_size predictions. We'll choose the first
-                # beam_size of these which don't predict eos to continue with.
-                torch.topk(
-                    probs.view(bsz, -1),
-                    k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-                    out=(cand_scores, cand_indices),
-                )
-                torch.div(cand_indices, self.vocab_size, out=cand_beams)
-                cand_indices.fmod_(self.vocab_size)
+                if prefix_tokens is not None and step < prefix_tokens.size(1):
+                    # within the forced prefix: pin candidates to the prefix token, keeping the model's own score
+                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
+                    cand_scores = probs_slice.gather(
+                        dim=1,
+                        index=prefix_tokens[:, step].view(-1, 1).data
+                    ).expand(-1, cand_size)
+                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
+                    cand_beams.resize_as_(cand_indices).fill_(0)
+                else:
+                    # take the best 2 x beam_size predictions. We'll choose the first
+                    # beam_size of these which don't predict eos to continue with.
+                    torch.topk(
+                        probs.view(bsz, -1),
+                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                        out=(cand_scores, cand_indices),
+                    )
+                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
+                    cand_indices.fmod_(self.vocab_size)
             else:
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
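To make the forcing branch concrete, here is a self-contained sketch with toy sizes (the tensor names mirror the diff; the shapes and values are assumptions). While step is still inside the prefix, every candidate is the forced token, its score is the model's own log-probability gathered from beam 0 (hence "Keep original scores with prefix" in the commit message), and all candidates point back at beam 0:

import torch

bsz, beam_size, vocab_size = 2, 3, 10
cand_size = 2 * beam_size
# Stand-in for the model's log-probs at this step: (bsz * beam_size, vocab).
probs = torch.randn(bsz * beam_size, vocab_size)
prefix_tokens = torch.LongTensor([[4, 7], [1, 3]])  # (bsz, prefix length)
step = 0

# All beams are still identical while forced, so beam 0 suffices.
probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
# Score of the forced token, broadcast across all candidate slots.
cand_scores = probs_slice.gather(
    dim=1, index=prefix_tokens[:, step].view(-1, 1),
).expand(-1, cand_size)
# Every candidate is the forced token, and every candidate comes from beam 0.
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
cand_beams = torch.zeros(bsz, cand_size, dtype=torch.long)

print(cand_indices)  # each row repeats that sentence's forced token

The "Replace 'repeat' with 'expand'" bullet in the commit message refers to this broadcast: expand() returns a view over the singleton dimension without allocating new memory, whereas repeat() would copy the data. The candidates are read-only at this point, so the view is sufficient.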
@@ -90,7 +90,7 @@ def main(args):
     else:
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
-            cuda=use_cuda, timer=gen_timer)
+            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)
         wps_meter = TimeMeter()
         for sample_id, src_tokens, target_tokens, hypos in translations:
             # Process input and ground truth