#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Translate raw text with a trained model. Batches data on-the-fly.
"""
from collections import namedtuple
import fileinput

import torch

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.data import encoders


# Lightweight records passed between the batching and the output stages.
Batch = namedtuple('Batch', ['ids', 'src_tokens', 'src_lengths'])
Translation = namedtuple('Translation', ['src_str', 'hypos', 'pos_scores', 'alignments'])


def buffered_read(input, buffer_size):
    """Read *input* (a path, or '-' for stdin) and yield lists of at most
    *buffer_size* whitespace-stripped lines, flushing any remainder at EOF."""
    pending = []
    with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as handle:
        for line in handle:
            pending.append(line.strip())
            if len(pending) >= buffer_size:
                yield pending
                pending = []

    # Emit whatever is left over after the final full buffer.
    if pending:
        yield pending


def make_batches(lines, args, task, max_positions, encode_fn):
    """Encode raw input *lines* with the task's source dictionary and yield
    inference-ready ``Batch`` tuples, sized by the generation args."""
    encoded = []
    for src_str in lines:
        token_ids = task.source_dictionary.encode_line(
            encode_fn(src_str), add_if_not_exist=False
        )
        encoded.append(token_ids.long())
    lengths = torch.LongTensor([t.numel() for t in encoded])

    dataset = task.build_dataset_for_inference(encoded, lengths)
    epoch_itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)

    for batch in epoch_itr:
        net_input = batch['net_input']
        yield Batch(
            ids=batch['id'],
            src_tokens=net_input['src_tokens'],
            src_lengths=net_input['src_lengths'],
        )


def main(args):
    """Interactively translate raw text with a trained model ensemble.

    Reads lines from ``args.input`` in buffered chunks, batches them
    on-the-fly, runs inference, and prints source (``S-``), hypothesis
    (``H-``), positional-score (``P-``) and optional alignment (``A-``)
    lines to stdout.
    """
    utils.import_user_module(args)

    # Normalize batching options: buffer at least one sentence, and default
    # to one sentence per batch when neither size limit was given.
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble; args.path is a ':'-separated list of checkpoint paths.
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() on --model-overrides executes arbitrary Python from
    # the command line; acceptable for a local CLI tool, but do not expose it
    # to untrusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        # Halve precision before moving to GPU when --fp16 is set.
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE (either may be absent, see encoders.build_*)
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        # Raw text -> (tokenized) -> (BPE) string fed to the dictionary.
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Inverse of encode_fn: undo BPE first, then detokenize.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Tightest max-positions constraint across the task and all models.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    start_id = 0  # running offset so ids stay unique across buffers
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = task.inference_step(generator, models, sample)
            # batch.ids maps batch rows back to positions within this buffer.
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                # Map back to (detokenized) raw text for display.
                hypo_str = decode_fn(hypo_str)
                print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                if args.print_alignment:
                    print('A-{}\t{}'.format(
                        id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

        # update running id counter
        start_id += len(inputs)


def cli_main():
    """Build the interactive generation argument parser, parse the command
    line, and run :func:`main`."""
    arg_parser = options.get_generation_parser(interactive=True)
    parsed_args = options.parse_args_and_arch(arg_parser)
    main(parsed_args)


if __name__ == '__main__':
    # Script entry point: parse CLI args and start the interactive loop.
    cli_main()