#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""

from collections import namedtuple
import fileinput

import torch

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.data import transforms


Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def buffered_read(input, buffer_size):
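    """Yield lists of at most `buffer_size` stripped lines from `input`.

    `input` follows fileinput semantics ('-' reads from stdin); any final
    partial buffer is flushed once the stream is exhausted.
    """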
    buffer = []
    with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
        for src_str in h:
            buffer.append(src_str.strip())
            if len(buffer) >= buffer_size:
                yield buffer
                buffer = []

    if len(buffer) > 0:
        yield buffer


def make_batches(lines, args, task, max_positions, encode_fn):
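    """Encode `lines` with `encode_fn`, binarize them with the task's source
    dictionary, and yield `Batch` tuples grouped by the task's batch iterator.
    """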
    tokens = [
        task.source_dictionary.encode_line(
            encode_fn(src_str), add_if_not_exist=False
        ).long()
        for src_str in lines
    ]
    lengths = torch.LongTensor([t.numel() for t in tokens])
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch['id'],
            src_tokens=batch['net_input']['src_tokens'],
            src_lengths=batch['net_input']['src_lengths'],
        )


def main(args):
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
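    # encode_fn applies the optional tokenizer and then BPE; decode_fn
    # inverts the pipeline in reverse order, so decode_fn(encode_fn(x)) ~ x.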
    tokenizer = transforms.build_tokenizer(args)
    bpe = transforms.build_bpe(args)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

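    # Use the most restrictive position limits supported by the task and by
    # every model in the ensemble.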
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

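    # --input defaults to '-' (stdin), so the prompt below covers interactive
    # use; regular files are read through the same buffered path.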
    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
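            # Generate hypotheses for this batch with the model ensemble
            # (beam search by default; --sampling switches the generator).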
            translations = task.inference_step(generator, models, sample)
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                hypo_str = decode_fn(hypo_str)
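                # H: hypothesis with its total score; P: per-token positional
                # scores; A: source-target alignment (with --print-alignment)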
                print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                if args.print_alignment:
                    print('A-{}\t{}'.format(
                        id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

        # update running id counter
        start_id += len(inputs)


def cli_main():
    parser = options.get_generation_parser(interactive=True)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()