#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""

from collections import namedtuple
import fileinput
import sys

import numpy as np
import torch

from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
from fairseq.utils import import_user_module

Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def buffered_read(input, buffer_size):
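    """Yield lists of up to `buffer_size` stripped input lines.

    `input` may be a file path or '-' to read from stdin (both handled by
    fileinput); any trailing partial buffer is flushed at the end.
    """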
    buffer = []
    for src_str in fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")):
        buffer.append(src_str.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []

    if len(buffer) > 0:
        yield buffer


def make_batches(lines, args, task, max_positions):
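    """Tokenize raw `lines` with the source dictionary and yield (Batch, ids).

    The batch iterator may reorder sentences for efficient batching; the ids
    returned with each Batch map its rows back to positions in `lines`.
    """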
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    itr = task.get_batch_iterator(
        dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']


def main(args):
    import_user_module(args)

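    # Fall back to interactive-friendly defaults: a buffer of at least one
    # sentence, and one sentence per batch unless a batching limit is given.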
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator
    translator = SequenceGenerator(
        models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
    )

    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
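        """Format a source sentence and its hypotheses as printable lines.

        'O' echoes the original source, 'H' carries a hypothesis and its score,
        'P' the per-token positional scores, and 'A' an optional alignment
        (only populated when --print-alignment is set).
        """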
        result = Translation(
            src_str='O\t{}'.format(src_str),
            hypos=[],
            pos_scores=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
            result.pos_scores.append('P\t{}'.format(
                ' '.join(map(
                    lambda x: '{:.4f}'.format(x),
                    hypo['positional_scores'].tolist(),
                ))
            ))
            result.alignments.append(
                'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
                if args.print_alignment else None
            )
        return result

    def process_batch(batch):
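        """Translate one Batch and return a Translation result per sentence.

        Tensors are moved to the GPU when CUDA is used, and the maximum output
        length is args.max_len_a * source_length + args.max_len_b.
        """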
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
        translations = translator.generate(
            encoder_input,
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )

        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

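    # The effective maximum source length is the most restrictive of the task's
    # limit and the limit of every model in the ensemble.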
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.input, args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
            indices.extend(batch_indices)
            results.extend(process_batch(batch))

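        # Print results in the original input order (make_batches may have
        # grouped and reordered the sentences).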
        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str)
            for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
                print(hypo)
                print(pos_scores)
                if align is not None:
                    print(align)


def cli_main():
    parser = options.get_generation_parser(interactive=True)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()