Commit 8ce6499d authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#422)

Summary:
- 04cc608: Add `--match-source-len` option to generate.py for sequence-tagging tasks
- 19f1a40: Add `--no-repeat-ngram-size` option to generate.py for ngram blocking
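
On the command line these correspond to `--match-source-len` (e.g. for sequence tagging, where the output must contain exactly as many tokens as the input) and `--no-repeat-ngram-size N` (e.g. N=3 bans any repeated trigram within a hypothesis). A minimal sketch of how the new keyword arguments are passed to SequenceGenerator, mirroring the generate.py/interactive.py changes below; `models` (list of loaded models) and `tgt_dict` (target dictionary) are assumed to come from the usual checkpoint-loading setup and are not defined here:

    from fairseq.sequence_generator import SequenceGenerator

    # `models` and `tgt_dict` are assumed to exist already (see generate.py)
    generator = SequenceGenerator(
        models, tgt_dict,
        beam_size=5,
        match_source_len=True,    # force |output| == |source|
        no_repeat_ngram_size=3,   # never repeat a trigram within a hypothesis
    )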
Pull Request resolved: https://github.com/pytorch/fairseq/pull/422

Differential Revision: D13548445

Pulled By: myleott

fbshipit-source-id: 26d1ae83993e428fcb020dac5ae358b0e36233d9
parent 19c17b74
@@ -315,6 +315,8 @@ def add_generation_args(parser):
'where x is the source length'))
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--match-source-len', default=False, action='store_true',
help=('generations should match the source length'))
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
@@ -333,13 +335,15 @@ def add_generation_args(parser):
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help='initialize generation by target prefix of given length')
group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
help='ngram blocking such that this size ngram cannot be repeated in the generation')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
-group.add_argument('--diverse-beam-groups', default=1, type=int, metavar='N',
+group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
help='number of groups for Diverse Beam Search')
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
help='strength of diversity penalty for Diverse Beam Search')
......
@@ -5,6 +5,8 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
@@ -47,6 +49,9 @@ class Search(object):
"""
raise NotImplementedError
def set_src_lengths(self, src_lengths):
self.src_lengths = src_lengths
class BeamSearch(Search):
@@ -80,6 +85,25 @@ class BeamSearch(Search):
return self.scores_buf, self.indices_buf, self.beams_buf
class LengthConstrainedBeamSearch(Search):
def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b):
super().__init__(tgt_dict)
self.min_len_a = min_len_a
self.min_len_b = min_len_b
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.beam = BeamSearch(tgt_dict)
def step(self, step, lprobs, scores):
min_lens = self.min_len_a * self.src_lengths + self.min_len_b
max_lens = self.max_len_a * self.src_lengths + self.max_len_b
lprobs[step < min_lens, :, self.eos] = -math.inf
lprobs[step == max_lens, :, self.eos] = 0
lprobs[step > max_lens, :, self.eos] = -math.inf
return self.beam.step(step, lprobs, scores)
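
To make the masking in step() concrete, here is a small self-contained sketch with toy values (hypothetical shapes and vocabulary; index 2 plays the role of eos, and min_len_a=max_len_a=1, min_len_b=max_len_b=0 as used by --match-source-len):

    import math
    import torch

    eos = 2
    src_lengths = torch.tensor([3, 5])    # two sentences, beam size 1
    lprobs = torch.full((2, 1, 5), -1.0)  # (bsz, beam, vocab) log-probs

    step = 3
    min_lens = 1 * src_lengths + 0        # min_len_a * src_len + min_len_b
    max_lens = 1 * src_lengths + 0        # max_len_a * src_len + max_len_b
    lprobs[step < min_lens, :, eos] = -math.inf   # too short: eos banned
    lprobs[step == max_lens, :, eos] = 0          # at the bound: eos wins
    lprobs[step > max_lens, :, eos] = -math.inf   # past the bound: eos banned

    # sentence 0 (length 3) must emit eos at step 3 (log-prob 0 beats -1.0),
    # while sentence 1 (length 5) is still forbidden from stopping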
class DiverseBeamSearch(Search):
"""Diverse Beam Search.
@@ -104,7 +128,6 @@ class DiverseBeamSearch(Search):
raise ValueError(
'DiverseBeamSearch requires --beam to be divisible by the number of groups'
)
group_size = beam_size // self.num_groups
# initialize diversity penalty
if self.diversity_buf is None:
......
@@ -16,18 +16,38 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
-normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
-sampling=False, sampling_topk=-1, sampling_temperature=1,
+normalize_scores=True, len_penalty=1., unk_penalty=0., retain_dropout=False,
+sampling=False, sampling_topk=-1, sampling_temperature=1.,
diverse_beam_groups=-1, diverse_beam_strength=0.5,
+match_source_len=False, no_repeat_ngram_size=0
):
"""Generates translations of a given source sentence.
Args:
-min/maxlen: The length of the generated output will be bounded by
-    minlen and maxlen (not including the end-of-sentence marker).
-stop_early: Stop generation immediately after we finalize beam_size
-    hypotheses, even though longer hypotheses might have better
-    normalized scores.
-normalize_scores: Normalize scores by the length of the output.
+beam_size (int, optional): beam width (default: 1)
+min/maxlen (int, optional): the length of the generated output will
+    be bounded by minlen and maxlen (not including end-of-sentence)
+stop_early (bool, optional): stop generation immediately after we
+    finalize beam_size hypotheses, even though longer hypotheses
+    might have better normalized scores (default: True)
+normalize_scores (bool, optional): normalize scores by the length
+    of the output (default: True)
+len_penalty (float, optional): length penalty, where <1.0 favors
+    shorter, >1.0 favors longer sentences (default: 1.0)
+unk_penalty (float, optional): unknown word penalty, where <0
+    produces more unks, >0 produces fewer (default: 0.0)
+retain_dropout (bool, optional): use dropout when generating
+    (default: False)
+sampling (bool, optional): sample outputs instead of beam search
+    (default: False)
+sampling_topk (int, optional): only sample among the top-k choices
+    at each step (default: -1)
+sampling_temperature (float, optional): temperature for sampling,
+    where values >1.0 produce more uniform sampling and values
+    <1.0 produce sharper sampling (default: 1.0)
+diverse_beam_groups/strength (float, optional): parameters for
+    Diverse Beam Search sampling
+match_source_len (bool, optional): outputs should match the source
+    length (default: False)
"""
self.models = models
self.pad = tgt_dict.pad()
@@ -44,6 +64,8 @@ class SequenceGenerator(object):
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
@@ -51,6 +73,10 @@ class SequenceGenerator(object):
self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
elif diverse_beam_groups > 0:
self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
elif match_source_len:
self.search = search.LengthConstrainedBeamSearch(
tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
)
else:
self.search = search.BeamSearch(tgt_dict)
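
Note that min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0 makes both bounds in LengthConstrainedBeamSearch collapse to min_len_a * src_len + min_len_b = max_len_a * src_len + max_len_b = src_len, so with --match-source-len every hypothesis is forced to end at exactly the source length.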
@@ -119,8 +145,11 @@ class SequenceGenerator(object):
def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
"""See generate"""
src_tokens = encoder_input['src_tokens']
src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
bsz, srclen = src_tokens.size()
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
if self.match_source_len:
maxlen = src_lengths.max().item()
# the max beam size is the dictionary size - 1, since we never select pad
beam_size = beam_size if beam_size is not None else self.beam_size
@@ -196,9 +225,11 @@ class SequenceGenerator(object):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
@@ -241,6 +272,9 @@ class SequenceGenerator(object):
sents_seen.add((sent, unfin_idx))
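# with --match-source-len, a hypothesis that runs past the source
# length is finalized with a score of -inf, so it can never outrank
# a hypothesis whose length matches the source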
if self.match_source_len and step > src_lengths[unfin_idx]:
score = -math.inf
def get_hypo():
if attn_clone is not None:
@@ -301,6 +335,15 @@ class SequenceGenerator(object):
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
if self.no_repeat_ngram_size > 0:
# for each beam and batch sentence, generate a list of previous ngrams
gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
gen_tokens = tokens[bbsz_idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \
gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]
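
A toy illustration of this bookkeeping, with hypothetical token ids and no_repeat_ngram_size=2:

    no_repeat_ngram_size = 2
    gen_tokens = [5, 7, 5]  # tokens generated so far for one beam
    gen_ngrams = {}
    for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
        # map each (n-1)-gram prefix to the tokens that have followed it
        gen_ngrams[ngram[:-1]] = gen_ngrams.get(ngram[:-1], []) + [ngram[-1]]
    print(gen_ngrams)  # {(5,): [7], (7,): [5]}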
# Record attention scores
if avg_attn_scores is not None:
if attn is None:
@@ -314,6 +357,23 @@ class SequenceGenerator(object):
eos_bbsz_idx = buffer('eos_bbsz_idx')
eos_scores = buffer('eos_scores', type_of=scores)
if step < maxlen:
self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0:
def calculate_banned_tokens(bbsz_idx):
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
return gen_ngrams[bbsz_idx].get(ngram_index, [])
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
else:
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = float('-Inf')
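
Continuing the toy example from above: with [5, 7, 5] already generated, the current (n-1)-gram context is (5,), so token 7 is banned, since emitting it would repeat the bigram (5, 7):

    no_repeat_ngram_size = 2
    tokens_so_far = [5, 7, 5]            # hypothetical beam prefix
    gen_ngrams = {(5,): [7], (7,): [5]}  # built as in the sketch above
    step = len(tokens_so_far) - 1        # index of the last generated token
    ngram_index = tuple(tokens_so_far[step + 2 - no_repeat_ngram_size:step + 1])
    banned = gen_ngrams.get(ngram_index, [])  # ngram_index == (5,), banned == [7]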
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
@@ -339,8 +399,7 @@ class SequenceGenerator(object):
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
-num_remaining_sent -= len(finalize_hypos(
-    step, eos_bbsz_idx, eos_scores))
+num_remaining_sent -= len(finalize_hypos(step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0
break
@@ -366,8 +425,7 @@ class SequenceGenerator(object):
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
-finalized_sents = finalize_hypos(
-    step, eos_bbsz_idx, eos_scores, cand_scores)
+finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
@@ -387,11 +445,11 @@ class SequenceGenerator(object):
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths = src_lengths[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
......
@@ -82,6 +82,7 @@ def main(args):
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if use_cuda:
......
@@ -97,6 +97,7 @@ def main(args):
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
)
if use_cuda:
......