generate.py
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
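
"""Translate a pre-processed dataset with one or more trained fairseq models
(optionally an ensemble) and report BLEU on the generated hypotheses."""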

import torch

from fairseq import bleu, data, options, progress_bar, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer


def main(args):
    assert args.path is not None, '--path required for generation!'
    print(args)
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
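    # Unknown-word replacement (--replace-unk) needs access to the original raw
    # text, so in that case the raw-text version of the dataset is loaded instead
    # of the binarized one.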
    if args.replace_unk is None:
        dataset = data.load_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    else:
        dataset = data.load_raw_text_dataset(
            args.data,
            [args.gen_subset],
            args.source_lang,
            args.target_lang,
        )
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
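    # All checkpoints passed via --path are loaded together and decoded jointly
    # as an ensemble by the sequence generator/scorer below.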
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
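    # make_generation_fast_ applies inference-only optimizations; the beam size is
    # passed (or None with --no-beamable-mm) so the models can specialize their
    # matrix multiplications for a fixed beam width.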
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        )

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
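    # Batches are built using the smallest maximum encoder length supported by any
    # model in the ensemble.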
    max_positions = min(model.max_encoder_positions() for model in models)
    itr = dataset.eval_dataloader(
        args.gen_subset,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
    )
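    # Optionally split the evaluation data into --num-shards pieces and process
    # only the shard selected by --shard-id (e.g. to parallelize across jobs).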
    if args.num_shards > 1:
        if args.shard_id < 0 or args.shard_id >= args.num_shards:
            raise ValueError('--shard-id must be between 0 and num_shards')
        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)

    # Initialize generator
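    # With --score-reference the reference translations are scored under the
    # models instead of running beam search (or sampling) for new hypotheses.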
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models)
    else:
        translator = SequenceGenerator(
            models, beam_size=args.beam, stop_early=(not args.no_early_stop),
            normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
            unk_penalty=args.unkpen, sampling=args.sampling)
    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
    num_sentences = 0
    has_target = True
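    # Unless --quiet is set, each sentence is printed as several tab-separated
    # lines: S (source), T (reference, if any), H (hypothesis and its score),
    # P (per-token positional scores) and A (alignment).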
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)
        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None
            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
                target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
            else:
                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
                target_str = dataset.dst_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True) if has_target else ''

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
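            # post_process_prediction converts hypothesis tokens back to a string,
            # optionally replacing <unk> via the attention alignment and the
            # --replace-unk dictionary, and stripping BPE markers (--remove-bpe).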
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu(),
                    align_dict=align_dict,
                    dst_dict=dataset.dst_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, dataset.dst_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


if __name__ == '__main__':
    parser = options.get_generation_parser()
    args = parser.parse_args()
    main(args)
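
# Example invocation (illustrative only: the data directory and checkpoint path
# below are placeholders; see options.get_generation_parser() for the full flag
# list):
#
#   python generate.py data-bin/wmt14.en-fr \
#       --path checkpoints/checkpoint_best.pt --beam 5 --remove-bpe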