#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import sys
import torch
from torch.autograd import Variable

from fairseq import bleu, data, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator


def main():
    """Command-line entry point: translate a dataset split or lines from stdin.

    Parses the generation options, loads the dataset and the model ensemble
    given by ``--path``, and builds a ``SequenceGenerator``. With
    ``--interactive`` each line read from stdin is translated and printed.
    Otherwise the ``--gen-subset`` split is translated in batches, the top
    hypothesis of every sentence is added to a BLEU scorer, and a summary is
    printed at the end.

    Raises:
        ValueError: if ``--unk-replace-dict`` is given without
            ``--interactive``.
    """
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('-i', '--interactive', action='store_true',
                              help='generate translations in interactive mode')
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                              help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    if not args.interactive:
        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    align_dict = {}
    if args.unk_replace_dict != '':
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not args.interactive:
            raise ValueError(
                'Unknown word replacement requires access to original source and is only supported in interactive mode')
        with open(args.unk_replace_dict, 'r') as f:
            for line in f:
                fields = line.split()
                # Skip blank or malformed lines instead of failing with IndexError.
                if len(fields) < 2:
                    continue
                align_dict[fields[0]] = fields[1]

    def replace_unk(hypo_str, align_str, src, unk):
        """Return hypo_str with every token equal to `unk` replaced.

        The replacement is the source token at the aligned position, mapped
        through `align_dict` when that source token has an entry there.
        """
        hypo_tokens = hypo_str.split()
        src_tokens = tokenizer.tokenize_line(src)
        align_idx = [int(i) for i in align_str.split()]
        for i, ht in enumerate(hypo_tokens):
            if ht == unk:
                src_token = src_tokens[align_idx[i]]
                hypo_tokens[i] = align_dict.get(src_token, src_token)
        return ' '.join(hypo_tokens)

    def display_hypotheses(sample_id, src, orig, ref, hypos):
        """Print source (S), original (O), target (T), hypothesis (H) and alignment (A) lines.

        Nothing is printed when --quiet is set. `orig` and `ref` lines are
        skipped when the corresponding argument is None.
        """
        if args.quiet:
            return
        id_str = '' if sample_id is None else '-{}'.format(sample_id)
        src_str = dataset.src_dict.string(src, args.remove_bpe)
        print('S{}\t{}'.format(id_str, src_str))
        if orig is not None:
            print('O{}\t{}'.format(id_str, orig.strip()))
        if ref is not None:
            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, args.remove_bpe, escape_unk=True)))
        for hypo in hypos:
            hypo_str = dataset.dst_dict.string(hypo['tokens'], args.remove_bpe)
            align_str = ' '.join(map(str, hypo['alignment']))
            if args.unk_replace_dict != '':
                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
            print('A{}\t{}'.format(id_str, align_str))

    if args.interactive:
        for line in sys.stdin:
            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
            if use_cuda:
                tokens = tokens.cuda()
            # The generator is called with a batch containing this single sentence.
            translations = translator.generate(Variable(tokens.view(1, -1)))
            hypos = translations[0]
            # Slicing already clamps to len(hypos), so no explicit min() is needed.
            display_hypotheses(None, tokens, line, None, hypos[:args.nbest])

    else:
        def maybe_remove_bpe(tokens, escape_unk=False):
            """Helper for removing BPE symbols from a hypothesis."""
            if args.remove_bpe is None:
                return tokens
            assert (tokens == dataset.dst_dict.pad()).sum() == 0
            hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)

        # Generate and compute BLEU score
        scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
        max_positions = min(model.max_encoder_positions() for model in models)
        itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
                                 max_positions=max_positions,
                                 skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
        num_sentences = 0
        with progress_bar(itr, smoothing=0, leave=False) as t:
            wps_meter = TimeMeter()
            gen_timer = StopwatchMeter()
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda_device=0 if use_cuda else None, timer=gen_timer)
            for sample_id, src, ref, hypos in translations:
                ref = ref.int().cpu()
                top_hypo = hypos[0]['tokens'].int().cpu()
                # Only the best hypothesis contributes to the BLEU statistics.
                scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
                display_hypotheses(sample_id, src, None, ref, hypos[:args.nbest])

                wps_meter.update(src.size(0))
                t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
                num_sentences += 1

        # Guard the throughput figure: if nothing was generated the timer
        # average is undefined (or zero) and the division would fail.
        try:
            tokens_per_sec = 1. / gen_timer.avg
        except ZeroDivisionError:
            tokens_per_sec = 0.
        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
            num_sentences, gen_timer.n, gen_timer.sum, tokens_per_sec))
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

# Script entry point: run generation only when executed directly, not on import.
if __name__ == "__main__":
    main()