Commit 7ae79c12 authored by Louis Martin, committed by Myle Ott

Refactor generation

* Split generate.py into generate.py and interactive.py and refactor the code

The main motivation behind these changes is to decouple the two use
cases, in order to make it easier to implement future improvements such
as replacing <unk> tokens with the original source string during
evaluation on the test set, and writing predictions to an output file.
The previous implementation worked well, but I found it difficult to
integrate these improvements into it.

* Add --replace-unk arg that can be used without an align dict

Replacing <unk> tokens can be beneficial even without an alignment
dictionary.
parent 8df95dcc
fairseq/options.py
@@ -116,8 +116,8 @@ def add_generation_args(parser):
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
-    group.add_argument('--unk-replace-dict', default='', type=str,
-                       help='performs unk word replacement')
+    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
+                       help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--quiet', action='store_true',
                       help='Only print final scores')
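
Aside (not part of the diff): with nargs='?' and const=True, the new flag has three distinct states, which load_align_dict below relies on. A minimal self-contained sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--replace-unk', nargs='?', const=True, default=None)

print(parser.parse_args([]).replace_unk)                 # None: no unk replacement
print(parser.parse_args(['--replace-unk']).replace_unk)  # True: copy the aligned source word
print(parser.parse_args(['--replace-unk', 'align.dict']).replace_unk)  # 'align.dict': use the dictionary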
fairseq/utils.py
@@ -14,7 +14,7 @@ import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location

-from fairseq import criterions, models
+from fairseq import criterions, models, tokenizer

def parse_args_and_arch(parser):
@@ -162,3 +162,43 @@ def prepare_sample(sample, volatile=False, cuda_device=None):
            for key in ['src_tokens', 'input_tokens']
        },
    }
+
+
+def load_align_dict(replace_unk):
+    if replace_unk is None:
+        align_dict = None
+    elif isinstance(replace_unk, str):
+        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
+        align_dict = {}
+        with open(replace_unk, 'r') as f:
+            for line in f:
+                l = line.split()
+                align_dict[l[0]] = l[1]
+    else:
+        # No alignment dictionary provided but we still want to perform unknown word
+        # replacement by copying the original source word.
+        align_dict = {}
+    return align_dict
+
+
+def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
+    # Tokens are strings here
+    hypo_tokens = tokenizer.tokenize_line(hypo_str)
+    src_tokens = tokenizer.tokenize_line(src_str)
+    for i, ht in enumerate(hypo_tokens):
+        if ht == unk:
+            src_token = src_tokens[alignment[i]]
+            # Either take the corresponding value in the aligned dictionary or just copy the original value.
+            hypo_tokens[i] = align_dict.get(src_token, src_token)
+    return ' '.join(hypo_tokens)
+
+
+def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
+    hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
+    if align_dict is not None:
+        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
+    if align_dict is not None or remove_bpe is not None:
+        # Convert back to tokens for evaluating with unk replacement or without BPE.
+        # Note that the dictionary can be modified inside the method.
+        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
+    return hypo_tokens, hypo_str, alignment
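
To make the intended behavior concrete, a small usage sketch (not part of the diff; the words, alignment, and file contents are made up, and callers access these helpers as utils.load_align_dict etc.):

# align.dict would contain one "src_word tgt_word" pair per line, e.g. "Haus house".
align_dict = load_align_dict('align.dict')  # {'Haus': 'house'}: replace via dictionary
align_dict = load_align_dict(True)          # {}: copy the aligned source word verbatim
align_dict = load_align_dict(None)          # None: unk replacement disabled

# Each '<unk>' in the hypothesis is resolved through the alignment to a source
# word, which is then mapped through align_dict (falling back to a plain copy):
print(replace_unk(
    hypo_str='the <unk> is big',
    src_str='das Haus ist groß',
    alignment=[0, 1, 2, 3],  # hypothesis position i is aligned to source position alignment[i]
    align_dict={'Haus': 'house'},
    unk='<unk>'))
# -> 'the house is big'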
generate.py
@@ -7,9 +7,7 @@
# can be found in the PATENTS file in the same directory.
#
-import sys
import torch
-from torch.autograd import Variable

from fairseq import bleu, data, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
@@ -22,8 +20,6 @@ def main():
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
-    dataset_args.add_argument('-i', '--interactive', action='store_true',
-                              help='generate translations in interactive mode')
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
@@ -49,7 +45,6 @@ def main():
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    if not args.interactive:
-        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
+    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
@@ -66,63 +61,8 @@ def main():
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
-    align_dict = {}
-    if args.unk_replace_dict != '':
-        assert args.interactive, \
-            'Unknown word replacement requires access to original source and is only supported in interactive mode'
-        with open(args.unk_replace_dict, 'r') as f:
-            for line in f:
-                l = line.split()
-                align_dict[l[0]] = l[1]
-
-    def replace_unk(hypo_str, align_str, src, unk):
-        hypo_tokens = hypo_str.split()
-        src_tokens = tokenizer.tokenize_line(src)
-        align_idx = [int(i) for i in align_str.split()]
-        for i, ht in enumerate(hypo_tokens):
-            if ht == unk:
-                src_token = src_tokens[align_idx[i]]
-                if src_token in align_dict:
-                    hypo_tokens[i] = align_dict[src_token]
-                else:
-                    hypo_tokens[i] = src_token
-        return ' '.join(hypo_tokens)
-
-    def display_hypotheses(id, src, orig, ref, hypos):
-        if args.quiet:
-            return
-        id_str = '' if id is None else '-{}'.format(id)
-        src_str = dataset.src_dict.string(src, args.remove_bpe)
-        print('S{}\t{}'.format(id_str, src_str))
-        if orig is not None:
-            print('O{}\t{}'.format(id_str, orig.strip()))
-        if ref is not None:
-            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, args.remove_bpe, escape_unk=True)))
-        for hypo in hypos:
-            hypo_str = dataset.dst_dict.string(hypo['tokens'], args.remove_bpe)
-            align_str = ' '.join(map(str, hypo['alignment']))
-            if args.unk_replace_dict != '':
-                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
-            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
-            print('A{}\t{}'.format(id_str, align_str))
-
-    if args.interactive:
-        for line in sys.stdin:
-            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
-            if use_cuda:
-                tokens = tokens.cuda()
-            translations = translator.generate(Variable(tokens.view(1, -1)))
-            hypos = translations[0]
-            display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
-    else:
-        def maybe_remove_bpe(tokens, escape_unk=False):
-            """Helper for removing BPE symbols from a hypothesis."""
-            if args.remove_bpe is None:
-                return tokens
-            assert (tokens == dataset.dst_dict.pad()).sum() == 0
-            hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
-            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
+    # (None if no unknown word replacement, empty if no path to align dictionary)
+    align_dict = utils.load_align_dict(args.replace_unk)
    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
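
For context (assuming the bleu.Scorer API of this fairseq version), the scorer accumulates integer token tensors pairwise and reports corpus BLEU after the loop:

# Inside the loop below: scorer.add(target_tokens, hypo_tokens) with IntTensors.
# After the loop the corpus result can be printed, e.g.:
#     print(scorer.result_string())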
@@ -137,13 +77,39 @@ def main():
        translations = translator.generate_batched_itr(
            t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda_device=0 if use_cuda else None, timer=gen_timer)
-        for id, src, ref, hypos in translations:
-            ref = ref.int().cpu()
-            top_hypo = hypos[0]['tokens'].int().cpu()
-            scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
-            display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
-            wps_meter.update(src.size(0))
+        for sample_id, src_tokens, target_tokens, hypos in translations:
+            # Process input and ground truth
+            target_tokens = target_tokens.int().cpu()
+            src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
+            target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
+            if not args.quiet:
+                print('S-{}\t{}'.format(sample_id, src_str))
+                print('T-{}\t{}'.format(sample_id, target_str))
+
+            # Process top predictions
+            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
+                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
+                    hypo_tokens=hypo['tokens'].int().cpu(),
+                    src_str=src_str,
+                    alignment=hypo['alignment'].int().cpu(),
+                    align_dict=align_dict,
+                    dst_dict=dataset.dst_dict,
+                    remove_bpe=args.remove_bpe)
+                if not args.quiet:
+                    print('A-{}\t{}'.format(sample_id, ' '.join(map(str, alignment))))
+                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
+                # Score only the top hypothesis
+                if i == 0:
+                    if args.remove_bpe is not None:
+                        # Convert the string without BPE back to tokens for evaluation
+                        target_tokens = tokenizer.Tokenizer.tokenize(
+                            target_str, dataset.dst_dict, add_if_not_exist=True)
+                    scorer.add(target_tokens, hypo_tokens)
+
+            wps_meter.update(src_tokens.size(0))
            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
            num_sentences += 1
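
For reference, the refactored loop prints per-sentence output in the same tab-separated format as before, now with explicit S-/T-/A-/H- prefixes; an illustrative (made-up) example for one sentence with --nbest 1:

S-0	das Haus ist groß
T-0	the house is big
A-0	0 1 2 3
H-0	-0.2271	the house is big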
interactive.py (new file)
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import sys

import torch
from torch.autograd import Variable

from fairseq import data, options, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator


def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    options.add_dataset_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    # TODO: load only dictionaries
    dataset = data.load_with_check(args.data, ['test'], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, dataset.src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        translations = translator.generate(Variable(src_tokens.view(1, -1)))
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                dst_dict=dataset.dst_dict,
                remove_bpe=args.remove_bpe)
            print('A\t{}'.format(' '.join(map(str, alignment))))
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))


if __name__ == '__main__':
    main()
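
A sketch of how the new script is used (the dataset path, checkpoint path, and outputs below are illustrative; bare --replace-unk copies the aligned source word for each <unk>):

$ python interactive.py data-bin/iwslt14.de-en --path checkpoints/model.pt --beam 5 --replace-unk
Type the input sentence and press return:
das Haus ist groß
O	das Haus ist groß
A	0 1 2 3
H	-0.2271	the house is big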