# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import sys
import torch
from torch.autograd import Variable

from fairseq import bleu, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator


def main():
    """Command-line entry point: translate with a trained fairseq model ensemble.

    Two modes of operation:
      * batch mode (default): iterate over the binarized ``--gen-subset`` split,
        print sources/references/hypotheses and accumulate a corpus BLEU score;
      * ``--interactive``: read raw sentences from stdin, one per line, and
        print the translations (no BLEU, since there is no reference).
    """
    # Build the argument parser: generation-specific flags on top of the
    # shared dataset/generation option groups provided by fairseq.options.
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('-i', '--interactive', action='store_true',
                              help='generate translations in interactive mode')
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                              help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load model and dataset.  `--path` may be given multiple times; all
    # checkpoints are loaded together as an ensemble.
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    if not args.interactive:
        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Max positions is the model property but it is needed in data reader to be able to
    # ignore too long sentences.  Clamp to the most restrictive decoder in the ensemble.
    args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))

    # Optimize model for generation (e.g. fuse/replace modules for inference speed).
    for model in models:
        model.make_generation_fast_(not args.no_beamable_mm)

    # Initialize the beam-search generator over the whole ensemble.
    translator = SequenceGenerator(
        models, dataset.dst_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
    )
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (copies source words into <unk> positions, optionally mapping them
    # through a user-provided bilingual dictionary).
    align_dict = {}
    if args.unk_replace_dict != '':
        assert args.interactive, \
            'Unknown word replacement requires access to original source and is only supported in interactive mode'
        with open(args.unk_replace_dict, 'r') as f:
            for line in f:
                # Each line: the first two whitespace-separated fields are
                # <source word> <replacement>; extra fields are ignored.
                l = line.split()
                align_dict[l[0]] = l[1]

    def replace_unk(hypo_str, align_str, src, unk):
        # Replace each <unk> token in the hypothesis with either the
        # dictionary translation of its aligned source word, or (fallback)
        # the aligned source word itself.
        hypo_tokens = hypo_str.split()
        src_tokens = tokenizer.tokenize_line(src)
        # align_str holds one source index per hypothesis position.
        align_idx = [int(i) for i in align_str.split()]
        for i, ht in enumerate(hypo_tokens):
            if ht == unk:
                src_token = src_tokens[align_idx[i]]
                if src_token in align_dict:
                    hypo_tokens[i] = align_dict[src_token]
                else:
                    hypo_tokens[i] = src_token
        return ' '.join(hypo_tokens)

    # '@@ ' is the Moses/subword-nmt BPE continuation marker; passing it to
    # Dictionary.string() strips BPE when --remove-bpe is set.
    bpe_symbol = '@@ ' if args.remove_bpe else None
    def display_hypotheses(id, src, orig, ref, hypos):
        """Print one example: S=source, O=original raw line, T=target/reference,
        then H=hypothesis and A=alignment for each n-best entry."""
        if args.quiet:
            return
        id_str = '' if id is None else '-{}'.format(id)
        src_str = dataset.src_dict.string(src, bpe_symbol)
        print('S{}\t{}'.format(id_str, src_str))
        if orig is not None:
            print('O{}\t{}'.format(id_str, orig.strip()))
        if ref is not None:
            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, bpe_symbol, escape_unk=True)))
        for hypo in hypos:
            hypo_str = dataset.dst_dict.string(hypo['tokens'], bpe_symbol)
            align_str = ' '.join(map(str, hypo['alignment']))
            if args.unk_replace_dict != '':
                # Only reachable in interactive mode (asserted above), where
                # `orig` carries the raw source line.
                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
            print('A{}\t{}'.format(id_str, align_str))

    if args.interactive:
        # Translate raw text from stdin, one sentence per line, batch size 1.
        for line in sys.stdin:
            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
            # Position ids start just past the padding index.
            start = dataset.src_dict.pad() + 1
            positions = torch.arange(start, start + len(tokens)).type_as(tokens)
            if use_cuda:
                positions = positions.cuda()
                tokens = tokens.cuda()
            # Variable() is the legacy (pre-0.4) PyTorch autograd wrapper.
            translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
            hypos = translations[0]
            display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])

    else:
        def maybe_remove_bpe(tokens):
            """Helper for removing BPE symbols from a hypothesis."""
            if not args.remove_bpe:
                return tokens
            # BLEU scoring assumes no padding inside the sequence.
            assert (tokens == dataset.dst_dict.pad()).sum() == 0
            # Round-trip through a string to merge BPE pieces back into words.
            hypo_minus_bpe = dataset.dst_dict.string(tokens, bpe_symbol)
            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)

        # Generate and compute BLEU score
        scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
        itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
                                 max_positions=args.max_positions,
                                 skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
        num_sentences = 0
        with progress_bar(itr, smoothing=0, leave=False) as t:
            wps_meter = TimeMeter()
            gen_timer = StopwatchMeter()
            # generate_batched_itr yields one (id, src, ref, hypos) tuple per
            # sentence; gen_timer accumulates pure generation time.
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda_device=0 if use_cuda else None, timer=gen_timer)
            for id, src, ref, hypos in translations:
                ref = ref.int().cpu()
                top_hypo = hypos[0]['tokens'].int().cpu()
                # Only the 1-best hypothesis contributes to BLEU.
                scorer.add(maybe_remove_bpe(ref), maybe_remove_bpe(top_hypo))
                display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])

                # Feed source length into the words/sec meter
                # (assumes src is a 1-D token tensor — NOTE(review): confirm).
                wps_meter.update(src.size(0))
                t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
                num_sentences += 1

        # Summary: throughput from the generation timer, then corpus BLEU.
        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
            num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


# Script entry point: run generation when executed directly (not on import).
if __name__ == '__main__':
    main()