#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import torch

from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter


def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
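    # Note: the sacrebleu scorer compares target/hypothesis strings directly,
    # while the default scorer computes BLEU over token IDs from the target
    # dictionary (see the add_string/add branch below).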
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
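                # Each hypothesis is printed as an H- line (score and string) and a
                # P- line (per-token positional scores), plus optional A- (alignment)
                # and I- (decoding steps) lines when the corresponding flags are set.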
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist(),
                            ))
                        ))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
                            ))

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

    return scorer


def cli_main():
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()