Commit 8ce6499d authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#422)

Summary:
- 04cc608: Add `--match-source-len` option to generate.py for sequence-tagging tasks
- 19f1a40: Add `--no-repeat-ngram-size` option to generate.py for ngram blocking
Pull Request resolved: https://github.com/pytorch/fairseq/pull/422

Differential Revision: D13548445

Pulled By: myleott

fbshipit-source-id: 26d1ae83993e428fcb020dac5ae358b0e36233d9
parent 19c17b74
@@ -315,6 +315,8 @@ def add_generation_args(parser):
                             'where x is the source length'))
     group.add_argument('--min-len', default=1, type=float, metavar='N',
                        help=('minimum generation length'))
+    group.add_argument('--match-source-len', default=False, action='store_true',
+                       help=('generations should match the source length'))
     group.add_argument('--no-early-stop', action='store_true',
                        help=('continue searching even after finalizing k=beam '
                              'hypotheses; this is more correct, but increases '
@@ -333,13 +335,15 @@ def add_generation_args(parser):
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                        help='initialize generation by target prefix of given length')
+    group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
+                       help='ngram blocking such that this size ngram cannot be repeated in the generation')
     group.add_argument('--sampling', action='store_true',
                        help='sample hypotheses instead of using beam search')
     group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                        help='sample from top K likely next words instead of all words')
     group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                        help='temperature for random sampling')
-    group.add_argument('--diverse-beam-groups', default=1, type=int, metavar='N',
+    group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
                        help='number of groups for Diverse Beam Search')
     group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
                        help='strength of diversity penalty for Diverse Beam Search')
...
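For illustration only (not part of the commit): a minimal standalone sketch of how the two new flags parse. It registers the same add_argument calls from the diff above on a bare argparse parser rather than on fairseq's full generation group; the toy parser and example argv are assumptions.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--match-source-len', default=False, action='store_true',
                    help='generations should match the source length')
parser.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
                    help='ngram blocking such that this size ngram cannot be repeated in the generation')

# e.g. generate.py invoked with both new options
args = parser.parse_args(['--match-source-len', '--no-repeat-ngram-size', '3'])
print(args.match_source_len)      # True
print(args.no_repeat_ngram_size)  # 3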
@@ -5,6 +5,8 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+import math
+
 import torch
@@ -47,6 +49,9 @@ class Search(object):
         """
         raise NotImplementedError
 
+    def set_src_lengths(self, src_lengths):
+        self.src_lengths = src_lengths
+
 
 class BeamSearch(Search):
@@ -80,6 +85,25 @@ class BeamSearch(Search):
         return self.scores_buf, self.indices_buf, self.beams_buf
 
 
+class LengthConstrainedBeamSearch(Search):
+
+    def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
+        super().__init__(tgt_dict)
+        self.min_len_a = min_len_a
+        self.min_len_b = min_len_b
+        self.max_len_a = max_len_a
+        self.max_len_b = max_len_b
+        self.beam = BeamSearch(tgt_dict)
+
+    def step(self, step, lprobs, scores):
+        min_lens = self.min_len_a * self.src_lengths + self.min_len_b
+        max_lens = self.max_len_a * self.src_lengths + self.max_len_b
+        lprobs[step < min_lens, :, self.eos] = -math.inf
+        lprobs[step == max_lens, :, self.eos] = 0
+        lprobs[step > max_lens, :, self.eos] = -math.inf
+        return self.beam.step(step, lprobs, scores)
+
+
 class DiverseBeamSearch(Search):
     """Diverse Beam Search.
@@ -104,7 +128,6 @@ class DiverseBeamSearch(Search):
             raise ValueError(
                 'DiverseBeamSearch requires --beam to be divisible by the number of groups'
             )
-        group_size = beam_size // self.num_groups
 
         # initialize diversity penalty
         if self.diversity_buf is None:
...
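To make the length arithmetic concrete, here is a small self-contained sketch (assuming PyTorch is installed; the tensor sizes, vocabulary, seed, and source lengths are made up, not from the commit). With min_len_a = max_len_a = 1 and min_len_b = max_len_b = 0 — the configuration SequenceGenerator uses for --match-source-len below — EOS is blocked while a hypothesis is shorter than its source and becomes the top-scoring choice exactly at the source length.

import math
import torch

torch.manual_seed(0)
bsz, beam, vocab, eos = 2, 1, 6, 5
src_lengths = torch.tensor([3, 4])   # hypothetical source lengths
min_lens = 1 * src_lengths + 0       # min_len_a * src_len + min_len_b
max_lens = 1 * src_lengths + 0       # max_len_a * src_len + max_len_b

for step in range(5):
    # fake next-token log-probabilities for each (sentence, beam) pair
    lprobs = torch.randn(bsz, beam, vocab).log_softmax(dim=-1)
    lprobs[step < min_lens, :, eos] = -math.inf   # too short: EOS is blocked
    lprobs[step == max_lens, :, eos] = 0          # at the bound: EOS beats any real log-prob
    lprobs[step > max_lens, :, eos] = -math.inf   # past the bound: EOS is blocked again
    picks_eos = (lprobs.argmax(dim=-1) == eos).squeeze(1)
    print(f'step {step}: sentence picks EOS -> {picks_eos.tolist()}')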
@@ -16,18 +16,38 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(
         self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
-        normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
-        sampling=False, sampling_topk=-1, sampling_temperature=1,
+        normalize_scores=True, len_penalty=1., unk_penalty=0., retain_dropout=False,
+        sampling=False, sampling_topk=-1, sampling_temperature=1.,
         diverse_beam_groups=-1, diverse_beam_strength=0.5,
+        match_source_len=False, no_repeat_ngram_size=0
     ):
         """Generates translations of a given source sentence.
 
         Args:
-            min/maxlen: The length of the generated output will be bounded by
-            minlen and maxlen (not including the end-of-sentence marker).
-            stop_early: Stop generation immediately after we finalize beam_size
-            hypotheses, even though longer hypotheses might have better
-            normalized scores.
-            normalize_scores: Normalize scores by the length of the output.
+            beam_size (int, optional): beam width (default: 1)
+            min/maxlen (int, optional): the length of the generated output will
+                be bounded by minlen and maxlen (not including end-of-sentence)
+            stop_early (bool, optional): stop generation immediately after we
+                finalize beam_size hypotheses, even though longer hypotheses
+                might have better normalized scores (default: True)
+            normalize_scores (bool, optional): normalize scores by the length
+                of the output (default: True)
+            len_penalty (float, optional): length penalty, where <1.0 favors
+                shorter, >1.0 favors longer sentences (default: 1.0)
+            unk_penalty (float, optional): unknown word penalty, where <0
+                produces more unks, >0 produces fewer (default: 0.0)
+            retain_dropout (bool, optional): use dropout when generating
+                (default: False)
+            sampling (bool, optional): sample outputs instead of beam search
+                (default: False)
+            sampling_topk (int, optional): only sample among the top-k choices
+                at each step (default: -1)
+            sampling_temperature (float, optional): temperature for sampling,
+                where values >1.0 produces more uniform sampling and values
+                <1.0 produces sharper sampling (default: 1.0)
+            diverse_beam_groups/strength (float, optional): parameters for
+                Diverse Beam Search sampling
+            match_source_len (bool, optional): outputs should match the source
+                length (default: False)
         """
         self.models = models
         self.pad = tgt_dict.pad()
@@ -44,6 +64,8 @@ class SequenceGenerator(object):
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
+        self.match_source_len = match_source_len
+        self.no_repeat_ngram_size = no_repeat_ngram_size
 
         assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
@@ -51,6 +73,10 @@ class SequenceGenerator(object):
             self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
         elif diverse_beam_groups > 0:
             self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
+        elif match_source_len:
+            self.search = search.LengthConstrainedBeamSearch(
+                tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
+            )
         else:
             self.search = search.BeamSearch(tgt_dict)
@@ -119,8 +145,11 @@ class SequenceGenerator(object):
     def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
         """See generate"""
         src_tokens = encoder_input['src_tokens']
+        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
+        if self.match_source_len:
+            maxlen = src_lengths.max().item()
 
         # the max beam size is the dictionary size - 1, since we never select pad
         beam_size = beam_size if beam_size is not None else self.beam_size
@@ -196,9 +225,11 @@ class SequenceGenerator(object):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
+
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
+
            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
@@ -241,6 +272,9 @@ class SequenceGenerator(object):
            sents_seen.add((sent, unfin_idx))
 
+            if self.match_source_len and step > src_lengths[unfin_idx]:
+                score = -math.inf
+
            def get_hypo():

                if attn_clone is not None:
@@ -301,6 +335,15 @@ class SequenceGenerator(object):
            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
+            if self.no_repeat_ngram_size > 0:
+                # for each beam and batch sentence, generate a list of previous ngrams
+                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
+                for bbsz_idx in range(bsz * beam_size):
+                    gen_tokens = tokens[bbsz_idx].tolist()
+                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
+                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
+                            gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
+
            # Record attention scores
            if avg_attn_scores is not None:
                if attn is None:
@@ -314,6 +357,23 @@ class SequenceGenerator(object):
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
+                self.search.set_src_lengths(src_lengths)
+
+                if self.no_repeat_ngram_size > 0:
+                    def calculate_banned_tokens(bbsz_idx):
+                        # before decoding the next token, prevent decoding of ngrams that have already appeared
+                        ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
+                        return gen_ngrams[bbsz_idx].get(ngram_index, [])
+
+                    if step + 2 - self.no_repeat_ngram_size >= 0:
+                        # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+                        banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
+                    else:
+                        banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
+
+                    for bbsz_idx in range(bsz * beam_size):
+                        lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = float('-Inf')
+
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
...@@ -339,8 +399,7 @@ class SequenceGenerator(object): ...@@ -339,8 +399,7 @@ class SequenceGenerator(object):
descending=True, descending=True,
out=(eos_scores, eos_bbsz_idx), out=(eos_scores, eos_bbsz_idx),
) )
num_remaining_sent -= len(finalize_hypos( num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0 assert num_remaining_sent == 0
break break
...@@ -366,8 +425,7 @@ class SequenceGenerator(object): ...@@ -366,8 +425,7 @@ class SequenceGenerator(object):
mask=eos_mask[:, :beam_size], mask=eos_mask[:, :beam_size],
out=eos_scores, out=eos_scores,
) )
finalized_sents = finalize_hypos( finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores)
step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents) num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0 assert num_remaining_sent >= 0
@@ -387,11 +445,11 @@ class SequenceGenerator(object):
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
+                src_lengths = src_lengths[batch_idxs]
                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
...
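The ngram-blocking bookkeeping above is easier to follow on a single hypothesis. The helper below is a plain-Python sketch, not fairseq code (banned_next_tokens is a hypothetical name and the token ids are made up): it records every n-gram generated so far and, given the (n-1)-gram currently in flight, returns the tokens that would complete an already-seen n-gram, which the generator then sets to -inf.

def banned_next_tokens(gen_tokens, no_repeat_ngram_size, step):
    # map each (n-1)-gram prefix to the list of tokens that followed it
    gen_ngrams = {}
    for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
        gen_ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    if step + 2 - no_repeat_ngram_size < 0:
        return []  # not enough tokens generated yet to form a full ngram
    prefix = tuple(gen_tokens[step + 2 - no_repeat_ngram_size:step + 1])
    return gen_ngrams.get(prefix, [])

tokens = [7, 8, 9, 7, 8]  # hypothesis so far; the next token is decided at step=4
print(banned_next_tokens(tokens, no_repeat_ngram_size=3, step=4))  # -> [9]

With no_repeat_ngram_size=3 and the hypothesis "7 8 9 7 8", token 9 is banned next, since emitting it would repeat the trigram "7 8 9".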
@@ -82,6 +82,7 @@ def main(args):
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
+        match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
    )
    if use_cuda:
...
@@ -97,6 +97,7 @@ def main(args):
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
+        match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
    )
    if use_cuda:
...