Commit 866b27d5 authored by Dario Pavllo, committed by Sergey Edunov

Add support for prefixes (#221)

* Add prefix

* Fixes

* Keep original scores with prefix

* Improve prefix code

* Replace 'repeat' with 'expand'
parent 0d90e35f
@@ -226,6 +226,8 @@ def add_generation_args(parser):
                        help='only print final scores')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
+    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
+                       help='initialize generation by target prefix of given length')
     return group
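As an aside, a minimal sketch of how the new flag parses in isolation. This is a toy, standalone argparse parser mirroring only the option added above, not fairseq's full generation option group:

```python
import argparse

# Toy parser reproducing just the --prefix-size option (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                    help='initialize generation by target prefix of given length')

args = parser.parse_args(['--prefix-size', '2'])
print(args.prefix_size)  # -> 2; the default of 0 leaves generation unconstrained
```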
@@ -51,7 +51,7 @@ class SequenceGenerator(object):
         return self

     def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-                             cuda=False, timer=None):
+                             cuda=False, timer=None, prefix_size=0):
         """Iterate over a batched dataset and yield individual translations.

         Args:
@@ -75,6 +75,7 @@ class SequenceGenerator(object):
                     input['src_lengths'],
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
+                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                 )
             if timer is not None:
                 timer.stop(s['ntokens'])
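The added keyword argument hands the generator the first `prefix_size` columns of the padded target batch. A small standalone sketch of that slicing, using invented token ids and a toy padded batch rather than a real fairseq sample:

```python
import torch

# Toy padded target batch: 2 sentences, arbitrary token ids, 1 is the pad index here.
target = torch.tensor([
    [14, 25, 36, 47,  2],   # sentence 1
    [18, 29,  2,  1,  1],   # sentence 2 (shorter, right-padded)
])

prefix_size = 2
prefix_tokens = target[:, :prefix_size] if prefix_size > 0 else None
print(prefix_tokens)
# tensor([[14, 25],
#         [18, 29]])
# With prefix_size == 0 the generator receives None and behaves exactly as before.
```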
@@ -84,16 +85,16 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
         with ExitStack() as stack:
             for model in self.models:
                 if isinstance(model.decoder, FairseqIncrementalDecoder):
                     stack.enter_context(model.decoder.incremental_inference())
             with utils.maybe_no_grad():
-                return self._generate(src_tokens, src_lengths, beam_size, maxlen)
+                return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -268,15 +269,25 @@ class SequenceGenerator(object):
             eos_bbsz_idx = buffer('eos_bbsz_idx')
             eos_scores = buffer('eos_scores', type_of=scores)
             if step < maxlen:
-                # take the best 2 x beam_size predictions. We'll choose the first
-                # beam_size of these which don't predict eos to continue with.
-                torch.topk(
-                    probs.view(bsz, -1),
-                    k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-                    out=(cand_scores, cand_indices),
-                )
-                torch.div(cand_indices, self.vocab_size, out=cand_beams)
-                cand_indices.fmod_(self.vocab_size)
+                if prefix_tokens is not None and step < prefix_tokens.size(1):
+                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
+                    cand_scores = probs_slice.gather(
+                        dim=1,
+                        index=prefix_tokens[:, step].view(-1, 1).data
+                    ).expand(-1, cand_size)
+                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
+                    cand_beams.resize_as_(cand_indices).fill_(0)
+                else:
+                    # take the best 2 x beam_size predictions. We'll choose the first
+                    # beam_size of these which don't predict eos to continue with.
+                    torch.topk(
+                        probs.view(bsz, -1),
+                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                        out=(cand_scores, cand_indices),
+                    )
+                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
+                    cand_indices.fmod_(self.vocab_size)
             else:
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
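To make the forcing step concrete, here is a hedged toy reproduction of the new branch for a single decoding step. The names mirror the diff, but the batch, beam, vocabulary, and token ids are invented, and plain tensors stand in for the Variables used in the real code:

```python
import torch

bsz, beam_size, vocab_size = 2, 2, 5
cand_size = 2 * beam_size
step = 0

# Fake per-beam log-probabilities for one step: (bsz * beam_size, vocab_size).
probs = torch.log_softmax(torch.randn(bsz * beam_size, vocab_size), dim=-1)

# Forced prefix: one token per sentence at this step.
prefix_tokens = torch.tensor([[3], [1]])  # shape (bsz, prefix_len)

# Look only at the first beam of each sentence and read off the score the model
# itself assigns to the forced token, so the original scores are preserved.
probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
cand_scores = probs_slice.gather(
    dim=1,
    index=prefix_tokens[:, step].view(-1, 1),
).expand(-1, cand_size)

# Every candidate is the forced token, and every candidate comes from beam 0.
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
cand_beams = torch.zeros_like(cand_indices)

print(cand_indices)  # the forced token repeated cand_size times per sentence
print(cand_scores)   # its model log-probability, repeated to match
```

Using `expand` instead of `repeat` (the last commit message bullet) avoids copying data: the forced token and its score are broadcast views rather than materialized duplicates.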
@@ -90,7 +90,7 @@ def main(args):
     else:
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
-            cuda=use_cuda, timer=gen_timer)
+            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)

     wps_meter = TimeMeter()
     for sample_id, src_tokens, target_tokens, hypos in translations:
         # Process input and ground truth