Commit 5fe8ea46 authored by Michael Auli's avatar Michael Auli Committed by Myle Ott
Browse files

Added -unkpen flag to generate.py following logic of Lua/Torch version

parent 6e4b7e22
......@@ -112,6 +112,8 @@ def add_generation_args(parser):
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--unk-replace-dict', default='', type=str,
help='performs unk word replacement')
group.add_argument('--quiet', action='store_true',
......
......@@ -18,7 +18,8 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
stop_early=True, normalize_scores=True, len_penalty=1):
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0):
"""Generates translations of a given source sentence.
Args:
......@@ -31,8 +32,10 @@ class SequenceGenerator(object):
"""
self.models = models
self.pad = models[0].dst_dict.pad()
self.unk = models[0].dst_dict.unk()
self.eos = models[0].dst_dict.eos()
assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
self.vocab_size = len(models[0].dst_dict)
self.beam_size = beam_size
......@@ -41,6 +44,7 @@ class SequenceGenerator(object):
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
def cuda(self):
for model in self.models:
......@@ -230,6 +234,7 @@ class SequenceGenerator(object):
# make probs contain cumulative scores for each hypothesis
probs.add_(scores.view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
# Record attention scores
attn[:, :, step+1].copy_(avg_attn_scores)
......
......@@ -60,7 +60,8 @@ def main():
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen)
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
if use_cuda:
translator.cuda()
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment