Commit 7ae79c12 authored by Louis Martin, committed by Myle Ott

Refactor generation

* Split generate.py into generate.py and interactive.py and refactor the code

The main motivation behind these changes is to decouple the two use
cases, so that future improvements become easier to implement, such as
replacing <unk> tokens with the original source string during test-set
evaluation and writing predictions to an output file.
The previous implementation worked well, but I found it difficult to
integrate these improvements into it.

* Add --replace-unk argument that works without an alignment dictionary

Replacing <unk> tokens can be beneficial even without an alignment
dictionary.
parent 8df95dcc
fairseq/options.py
@@ -116,8 +116,8 @@ def add_generation_args(parser):
                        help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
     group.add_argument('--unkpen', default=0, type=float,
                        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
-    group.add_argument('--unk-replace-dict', default='', type=str,
-                       help='performs unk word replacement')
+    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
+                       help='perform unknown replacement (optionally with alignment dictionary)')
     group.add_argument('--quiet', action='store_true',
                        help='Only print final scores')
...
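The new --replace-unk flag relies on argparse's nargs='?' behavior, so one option covers three cases: flag absent, bare flag, or flag with a dictionary path. A minimal standalone sketch of how the value resolves (align.dict is a hypothetical file name):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--replace-unk', nargs='?', const=True, default=None,
                    help='perform unknown replacement (optionally with alignment dictionary)')

print(parser.parse_args([]).replace_unk)                               # None -> no replacement
print(parser.parse_args(['--replace-unk']).replace_unk)                # True -> copy aligned source word
print(parser.parse_args(['--replace-unk', 'align.dict']).replace_unk)  # 'align.dict' -> load the dictionary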
fairseq/utils.py
@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location

-from fairseq import criterions, models
+from fairseq import criterions, models, tokenizer


 def parse_args_and_arch(parser):
...
@@ -162,3 +162,43 @@ def prepare_sample(sample, volatile=False, cuda_device=None):
             for key in ['src_tokens', 'input_tokens']
         },
     }
+
+
+def load_align_dict(replace_unk):
+    if replace_unk is None:
+        align_dict = None
+    elif isinstance(replace_unk, str):
+        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
+        align_dict = {}
+        with open(replace_unk, 'r') as f:
+            for line in f:
+                l = line.split()
+                align_dict[l[0]] = l[1]
+    else:
+        # No alignment dictionary provided but we still want to perform unknown word replacement
+        # by copying the original source word.
+        align_dict = {}
+    return align_dict
+
+
+def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
+    # Tokens are strings here
+    hypo_tokens = tokenizer.tokenize_line(hypo_str)
+    src_tokens = tokenizer.tokenize_line(src_str)
+    for i, ht in enumerate(hypo_tokens):
+        if ht == unk:
+            src_token = src_tokens[alignment[i]]
+            # Either take the corresponding value in the align dictionary or just copy the original value.
+            hypo_tokens[i] = align_dict.get(src_token, src_token)
+    return ' '.join(hypo_tokens)
+
+
+def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
+    hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
+    if align_dict is not None:
+        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
+    if align_dict is not None or remove_bpe is not None:
+        # Convert back to tokens for evaluating with unk replacement or without BPE.
+        # Note that the dictionary can be modified inside the method.
+        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
+    return hypo_tokens, hypo_str, alignment
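To make the replacement flow concrete, here is a self-contained sketch with tokenizer.tokenize_line swapped for str.split() so it runs outside fairseq; the sentences, dictionary entry, and alignment below are invented for illustration:

def replace_unk_sketch(hypo_str, src_str, alignment, align_dict, unk):
    # Same logic as utils.replace_unk above, minus the fairseq tokenizer.
    hypo_tokens = hypo_str.split()
    src_tokens = src_str.split()
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            # Hypothesis position i is aligned to source position alignment[i].
            src_token = src_tokens[alignment[i]]
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return ' '.join(hypo_tokens)

# An alignment dictionary file has one "source-word target-word" pair per line,
# e.g. "Mundharmonika harmonica"; load_align_dict parses it into a dict like this one.
align_dict = {'Mundharmonika': 'harmonica'}

print(replace_unk_sketch('he plays the <unk>', 'er spielt Mundharmonika',
                         alignment=[0, 1, 2, 2], align_dict=align_dict, unk='<unk>'))
# -> 'he plays the harmonica'

With a bare --replace-unk the dictionary is empty, so the same call falls back to copying the source word and prints 'he plays the Mundharmonika'.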
generate.py
@@ -7,9 +7,7 @@
 # can be found in the PATENTS file in the same directory.
 #

-import sys
 import torch
-from torch.autograd import Variable

 from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
...
@@ -22,8 +20,6 @@ def main():
     parser.add_argument('--path', metavar='FILE', required=True, action='append',
                         help='path(s) to model file(s)')
     dataset_args = options.add_dataset_args(parser)
-    dataset_args.add_argument('-i', '--interactive', action='store_true',
-                              help='generate translations in interactive mode')
     dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                               help='batch size')
     dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
...
@@ -49,8 +45,7 @@ def main():
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    if not args.interactive:
-        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
+    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

     # Optimize ensemble for generation
     for model in models:
...
@@ -66,90 +61,61 @@ def main():
         translator.cuda()

     # Load alignment dictionary for unknown word replacement
-    align_dict = {}
-    if args.unk_replace_dict != '':
-        assert args.interactive, \
-            'Unknown word replacement requires access to original source and is only supported in interactive mode'
-        with open(args.unk_replace_dict, 'r') as f:
-            for line in f:
-                l = line.split()
-                align_dict[l[0]] = l[1]
-
-    def replace_unk(hypo_str, align_str, src, unk):
-        hypo_tokens = hypo_str.split()
-        src_tokens = tokenizer.tokenize_line(src)
-        align_idx = [int(i) for i in align_str.split()]
-        for i, ht in enumerate(hypo_tokens):
-            if ht == unk:
-                src_token = src_tokens[align_idx[i]]
-                if src_token in align_dict:
-                    hypo_tokens[i] = align_dict[src_token]
-                else:
-                    hypo_tokens[i] = src_token
-        return ' '.join(hypo_tokens)
-
-    def display_hypotheses(id, src, orig, ref, hypos):
-        if args.quiet:
-            return
-        id_str = '' if id is None else '-{}'.format(id)
-        src_str = dataset.src_dict.string(src, args.remove_bpe)
-        print('S{}\t{}'.format(id_str, src_str))
-        if orig is not None:
-            print('O{}\t{}'.format(id_str, orig.strip()))
-        if ref is not None:
-            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, args.remove_bpe, escape_unk=True)))
-        for hypo in hypos:
-            hypo_str = dataset.dst_dict.string(hypo['tokens'], args.remove_bpe)
-            align_str = ' '.join(map(str, hypo['alignment']))
-            if args.unk_replace_dict != '':
-                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
-            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
-            print('A{}\t{}'.format(id_str, align_str))
-
-    if args.interactive:
-        for line in sys.stdin:
-            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
-            if use_cuda:
-                tokens = tokens.cuda()
-            translations = translator.generate(Variable(tokens.view(1, -1)))
-            hypos = translations[0]
-            display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
-
-    else:
-        def maybe_remove_bpe(tokens, escape_unk=False):
-            """Helper for removing BPE symbols from a hypothesis."""
-            if args.remove_bpe is None:
-                return tokens
-            assert (tokens == dataset.dst_dict.pad()).sum() == 0
-            hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
-            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
-
-        # Generate and compute BLEU score
-        scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
-        max_positions = min(model.max_encoder_positions() for model in models)
-        itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
-                                 max_positions=max_positions,
-                                 skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
-        num_sentences = 0
-        with progress_bar(itr, smoothing=0, leave=False) as t:
-            wps_meter = TimeMeter()
-            gen_timer = StopwatchMeter()
-            translations = translator.generate_batched_itr(
-                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
-                cuda_device=0 if use_cuda else None, timer=gen_timer)
-            for id, src, ref, hypos in translations:
-                ref = ref.int().cpu()
-                top_hypo = hypos[0]['tokens'].int().cpu()
-                scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
-                display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
-
-                wps_meter.update(src.size(0))
-                t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
-                num_sentences += 1
-
-        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
-            num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
-        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
+    # (None if no unknown word replacement, empty if no path to align dictionary)
+    align_dict = utils.load_align_dict(args.replace_unk)
+
+    # Generate and compute BLEU score
+    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
+    max_positions = min(model.max_encoder_positions() for model in models)
+    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
+                             max_positions=max_positions,
+                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    num_sentences = 0
+    with progress_bar(itr, smoothing=0, leave=False) as t:
+        wps_meter = TimeMeter()
+        gen_timer = StopwatchMeter()
+        translations = translator.generate_batched_itr(
+            t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
+            cuda_device=0 if use_cuda else None, timer=gen_timer)
+        for sample_id, src_tokens, target_tokens, hypos in translations:
+            # Process input and ground truth
+            target_tokens = target_tokens.int().cpu()
+            src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
+            target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
+
+            if not args.quiet:
+                print('S-{}\t{}'.format(sample_id, src_str))
+                print('T-{}\t{}'.format(sample_id, target_str))
+
+            # Process top predictions
+            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
+                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
+                    hypo_tokens=hypo['tokens'].int().cpu(),
+                    src_str=src_str,
+                    alignment=hypo['alignment'].int().cpu(),
+                    align_dict=align_dict,
+                    dst_dict=dataset.dst_dict,
+                    remove_bpe=args.remove_bpe)
+
+                if not args.quiet:
+                    print('A-{}\t{}'.format(sample_id, ' '.join(map(str, alignment))))
+                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
+
+                # Score only the top hypothesis
+                if i == 0:
+                    if args.remove_bpe is not None:
+                        # Convert the string without BPE back to tokens for evaluation
+                        target_tokens = tokenizer.Tokenizer.tokenize(target_str,
+                                                                     dataset.dst_dict,
+                                                                     add_if_not_exist=True)
+                    scorer.add(target_tokens, hypo_tokens)
+
+            wps_meter.update(src_tokens.size(0))
+            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
+            num_sentences += 1
+
+    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
+        num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
+    print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


 if __name__ == '__main__':
...
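For reference, the rewritten loop tags every output line with the sample id: S (source), T (target), and per hypothesis A (alignment) and H (score and hypothesis). The contents below are invented purely to illustrate the format produced by the print statements above:

S-0	er spielt Mundharmonika
T-0	he plays the harmonica
A-0	0 1 2 2
H-0	-0.4567	he plays the harmonica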
interactive.py (new file)
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import sys

import torch
from torch.autograd import Variable

from fairseq import data, options, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator


def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    options.add_dataset_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    # TODO: load only dictionaries
    dataset = data.load_with_check(args.data, ['test'], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, dataset.src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        translations = translator.generate(Variable(src_tokens.view(1, -1)))
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                dst_dict=dataset.dst_dict,
                remove_bpe=args.remove_bpe)
            print('A\t{}'.format(' '.join(map(str, alignment))))
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))


if __name__ == '__main__':
    main()
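The new script reads one sentence per line from stdin. Assuming a binarized data directory and a trained checkpoint (both paths below are hypothetical), a typical invocation would look like:

python interactive.py data-bin/iwslt14.de-en --path checkpoints/checkpoint_best.pt --beam 5 --replace-unk

Each input line is echoed back as an O line, followed by A and H lines for the n-best hypotheses, mirroring the per-sentence format of generate.py.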