#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Translate pre-processed data with a trained model.
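
Example (illustrative; the data directory and checkpoint path are placeholders
for your own binarized dataset and trained model):

    python generate.py data-bin/wmt14.en-fr --path checkpoints/checkpoint_best.pt --beam 5 --remove-bpe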
"""

import torch

from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


def main(args):
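    # Sanity-check generation options that depend on each other.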
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    # Import user-supplied extensions (e.g. custom tasks or models) from --user-dir, if any.
    import_user_module(args)

    # Default to a token-based batch size if neither --max-tokens nor --max-sentences is given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
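    # --path may list several checkpoints separated by ':'; they are decoded as an ensemble.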
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
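    # With --score-reference, score the reference translations instead of decoding new hypotheses.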
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
            match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
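    # Output line prefixes: S- source, T- reference target, H- hypothesis and score,
    # P- per-token positional scores, A- alignment (with --print-alignment).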
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
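            # --nbest hypotheses are printed per sentence; only the best (i == 0) is scored below.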
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            # Throughput ('wps') is measured in source tokens per second.
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


if __name__ == '__main__':
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)