#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Translate pre-processed data with a trained model.
"""

import torch

from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


def main(args):
    """Translate a pre-processed dataset split with a (possibly ensembled) model.

    Loads ``args.gen_subset`` through the configured task and the model
    checkpoint(s) listed in ``args.path`` (colon-separated). It then runs the
    task's generator over the split. Unless ``args.quiet`` is set, it prints
    these lines for each sentence:

    * ``S-<id>``: the source sentence (when a source dictionary exists),
    * ``T-<id>``: the reference target (when the batch has targets),
    * ``H-<id>``: each of the top ``args.nbest`` hypotheses with its score,
    * ``P-<id>``: the per-position scores of that hypothesis,
    * ``A-<id>``: its alignment (only with ``args.print_alignment``).

    Finally it prints a throughput summary. When targets are present, it also
    prints the BLEU score of the top hypothesis of every sentence.

    Args:
        args: parsed generation options, e.g. as built by ``cli_main``.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    # Default batching budget when neither limit was given on the command line.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries (some tasks have no source side and raise here)
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary Python taken from --model-overrides.
    # This is only acceptable because the value comes from the local command line;
    # never feed it untrusted input.
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding.
                # NOTE(review): source padding is stripped with the target
                # dictionary's pad index; this assumes both dictionaries share it.
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process the top ``nbest`` predictions of sentence i. ``hypos[i]``
                # is this sentence's hypothesis list. ``len(hypos)`` is the batch
                # size, so it must not bound this slice. A separate index ``j``
                # keeps the sentence index ``i`` intact.
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))
                        ))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(map(lambda x: str(utils.item(x)), alignment))
                            ))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tokenizer.Tokenizer.tokenize(
                                target_str, tgt_dict, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    if gen_timer.n > 0 and gen_timer.sum > 0:
        sentences_per_sec = num_sentences / gen_timer.sum
        tokens_per_sec = 1. / gen_timer.avg
    else:
        # Nothing was generated (e.g. an empty split): report zero rates
        # instead of dividing by zero.
        sentences_per_sec = tokens_per_sec = 0.
    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, sentences_per_sec, tokens_per_sec))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))


def cli_main():
    """Parse the generation command line and run :func:`main` with it."""
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()