interactive.py 6.38 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3 -u
Louis Martin's avatar
Louis Martin committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""
Myle Ott's avatar
Myle Ott committed
11

Myle Ott's avatar
Myle Ott committed
12
from collections import namedtuple
13
import fileinput
Myle Ott's avatar
Myle Ott committed
14

Louis Martin's avatar
Louis Martin committed
15
16
import torch

Myle Ott's avatar
Myle Ott committed
17
from fairseq import checkpoint_utils, options, tasks, utils
Myle Ott's avatar
Myle Ott committed
18

Myle Ott's avatar
Myle Ott committed
19

Myle Ott's avatar
Myle Ott committed
20
# Lightweight record types passed between make_batches() and main().
Batch = namedtuple('Batch', ['ids', 'src_tokens', 'src_lengths'])
Translation = namedtuple('Translation', ['src_str', 'hypos', 'pos_scores', 'alignments'])


24
def buffered_read(input, buffer_size):
    """Read *input* line by line and yield lists of stripped lines.

    Each yielded list holds at most *buffer_size* lines; a final,
    possibly smaller list is yielded for any leftover lines.
    ``input`` may be a filename or ``'-'`` for stdin (fileinput semantics).
    NOTE(review): the parameter name shadows the ``input`` builtin but is
    kept for interface compatibility.
    """
    pending = []
    with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as handle:
        for line in handle:
            pending.append(line.strip())
            if len(pending) < buffer_size:
                continue
            yield pending
            pending = []

    # Flush whatever is left after the file is exhausted.
    if pending:
        yield pending


37
def make_batches(lines, args, task, max_positions, encode_fn):
    """Tokenize raw *lines* and yield :class:`Batch` tuples for inference.

    Each input string is passed through *encode_fn* (e.g. BPE encoding)
    and the task's source dictionary, then grouped into batches bounded
    by ``args.max_tokens`` / ``args.max_sentences`` / *max_positions*.
    Batches are yielded in iterator order (not shuffled).
    """
    encoded = []
    for src_str in lines:
        token_ids = task.source_dictionary.encode_line(
            encode_fn(src_str), add_if_not_exist=False
        )
        encoded.append(token_ids.long())
    lengths = torch.LongTensor([t.numel() for t in encoded])
    dataset = task.build_dataset_for_inference(encoded, lengths)
    iterator = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for sample in iterator:
        net_input = sample['net_input']
        yield Batch(
            ids=sample['id'],
            src_tokens=net_input['src_tokens'],
            src_lengths=net_input['src_lengths'],
        )
56

Louis Martin's avatar
Louis Martin committed
57

Myle Ott's avatar
Myle Ott committed
58
def main(args):
    """Interactively translate raw text with a trained model ensemble.

    Reads sentences from ``args.input`` (a file, or stdin when '-'),
    batches them on-the-fly, runs inference, and prints source (S-),
    hypothesis (H-), positional-score (P-) and optional alignment (A-)
    lines to stdout, sorted back into input order.
    """
    utils.import_user_module(args)

    # Normalize degenerate settings: buffer at least one sentence, and
    # default to one sentence per batch when no batch limit was given.
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble; ':'-separated paths load multiple checkpoints.
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() on a CLI string is an arbitrary-code-execution
    # vector if args.model_overrides comes from an untrusted source.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Hack to support GPT-2 BPE: encode input to BPE-id strings and keep a
    # decoder around to turn hypothesis id-strings back into text.
    if args.remove_bpe == 'gpt2':
        from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
        # presumably these relative paths are resolved from the repo root —
        # TODO confirm working-directory assumption
        decoder = get_encoder(
            'fairseq/gpt2_bpe/encoder.json',
            'fairseq/gpt2_bpe/vocab.bpe',
        )
        encode_fn = lambda x: ' '.join(map(str, decoder.encode(x)))
    else:
        decoder = None
        encode_fn = lambda x: x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Clamp generation length to what the task and every model support.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    # Running offset so ids stay unique across successive buffers.
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = task.inference_step(generator, models, sample)
            # NOTE(review): 'id' shadows the builtin of the same name.
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                # Strip padding before printing the source back to the user.
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            # NOTE(review): src_str is only assigned when src_dict is not
            # None, yet it is used unconditionally below — verify src_dict
            # can never be None for interactive tasks.
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )
                # GPT-2 BPE path: hypo_str holds space-separated BPE ids.
                if decoder is not None:
                    hypo_str = decoder.decode(map(int, hypo_str.strip().split()))
                print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    id,
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                if args.print_alignment:
                    print('A-{}\t{}'.format(
                        id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

        # update running id counter
        start_id += len(inputs)
Louis Martin's avatar
Louis Martin committed
180

Myle Ott's avatar
Myle Ott committed
181

Myle Ott's avatar
Myle Ott committed
182
def cli_main():
    """Parse interactive-generation CLI arguments and run :func:`main`."""
    arg_parser = options.get_generation_parser(interactive=True)
    parsed_args = options.parse_args_and_arch(arg_parser)
    main(parsed_args)
Myle Ott's avatar
Myle Ott committed
186
187
188
189


# Standard script entry point: only launch the CLI when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    cli_main()