interactive.py 6.38 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3 -u
Louis Martin's avatar
Louis Martin committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8
9
10
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""
Myle Ott's avatar
Myle Ott committed
11

Myle Ott's avatar
Myle Ott committed
12
from collections import namedtuple
13
import numpy as np
Louis Martin's avatar
Louis Martin committed
14
import sys
Myle Ott's avatar
Myle Ott committed
15

Louis Martin's avatar
Louis Martin committed
16
17
import torch

Myle Ott's avatar
Myle Ott committed
18
from fairseq import data, options, tasks, tokenizer, utils
Louis Martin's avatar
Louis Martin committed
19
from fairseq.sequence_generator import SequenceGenerator
20
from fairseq.utils import import_user_module
Myle Ott's avatar
Myle Ott committed
21

22
Batch = namedtuple('Batch', 'srcs tokens lengths')
23
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
24
25
26
27
28
29
30
31
32
33
34
35
36
37


def buffered_read(buffer_size, input_stream=None):
    """Yield batches of stripped input lines, up to *buffer_size* per batch.

    Args:
        buffer_size (int): maximum number of lines per yielded batch.
        input_stream: optional iterable of lines to read. Defaults to
            ``sys.stdin`` (resolved at call time, so rebinding ``sys.stdin``
            in tests still works). Adding this parameter generalizes the
            original stdin-only implementation without changing callers.

    Yields:
        list[str]: consecutive batches of stripped lines; the final batch
        may be shorter than *buffer_size*.
    """
    if input_stream is None:
        # Resolve lazily rather than in the signature so the default always
        # reflects the current sys.stdin.
        input_stream = sys.stdin

    buffer = []
    for src_str in input_stream:
        buffer.append(src_str.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []

    # Flush any remaining lines as a final, possibly partial, batch.
    if buffer:
        yield buffer


38
def make_batches(lines, args, task, max_positions):
    """Tokenize raw source *lines* and yield (Batch, ids) pairs.

    Each yielded Batch holds the original strings, the padded token tensor
    and the per-sentence lengths; *ids* are the positions of those sentences
    in *lines*, so the caller can restore input order after batching.
    """
    # Tokenize every input line against the source dictionary; unknown
    # words are mapped to <unk> rather than added (add_if_not_exist=False).
    token_tensors = [
        tokenizer.Tokenizer.tokenize(
            line, task.source_dictionary, add_if_not_exist=False,
        ).long()
        for line in lines
    ]
    token_lengths = np.array([t.numel() for t in token_tensors])

    # Let the task build size-capped mini-batches over an on-the-fly dataset.
    dataset = data.LanguagePairDataset(
        token_tensors, token_lengths, task.source_dictionary,
    )
    epoch_itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)

    for sample in epoch_itr:
        ids = sample['id']
        net_input = sample['net_input']
        yield Batch(
            srcs=[lines[i] for i in ids],
            tokens=net_input['src_tokens'],
            lengths=net_input['src_lengths'],
        ), ids
56

Louis Martin's avatar
Louis Martin committed
57

Myle Ott's avatar
Myle Ott committed
58
def main(args):
    """Run an interactive translation loop over stdin.

    Reads raw sentences (optionally buffered via --buffer-size), batches them
    on the fly, translates with a loaded model ensemble, and prints for each
    sentence an 'O' source line followed by 'H'/'P' (and optionally 'A')
    lines per hypothesis, in the original input order.
    """
    # Register any user-supplied extension modules before touching tasks/models.
    import_user_module(args)

    # Normalize batching knobs: buffer at least one sentence, and default to
    # one sentence per batch if neither token nor sentence cap is given.
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    # Every batch must fit inside one read buffer, or ordering logic below breaks.
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble; --path may list several checkpoints separated by ':'.
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() on --model-overrides executes arbitrary code from
    # the command line — acceptable for a local CLI, never for untrusted input.
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            # Attention weights are only needed when alignments are printed.
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator (beam search / sampling over the ensemble).
    translator = SequenceGenerator(
        models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
    )

    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
        """Format the top --nbest hypotheses for one source sentence into a Translation."""
        result = Translation(
            src_str='O\t{}'.format(src_str),
            hypos=[],
            pos_scores=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            # Detokenize, apply unk replacement / BPE removal as configured.
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
            result.pos_scores.append('P\t{}'.format(
                ' '.join(map(
                    lambda x: '{:.4f}'.format(x),
                    hypo['positional_scores'].tolist(),
                ))
            ))
            # Alignments list stays parallel to hypos; None entries are skipped at print time.
            result.alignments.append(
                'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
                if args.print_alignment else None
            )
        return result

    def process_batch(batch):
        """Translate one Batch and return a Translation per sentence (batch order)."""
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
        translations = translator.generate(
            encoder_input,
            # Output length cap is linear in the (padded) source length.
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )

        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

    # Tightest position limit across the task and every model in the ensemble.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
            indices.extend(batch_indices)
            results += process_batch(batch)

        # Batching may reorder sentences by length; argsort on the original
        # ids restores input order before printing.
        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str)
            for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
                print(hypo)
                print(pos_scores)
                if align is not None:
                    print(align)
Louis Martin's avatar
Louis Martin committed
183

Myle Ott's avatar
Myle Ott committed
184

Louis Martin's avatar
Louis Martin committed
185
def cli_main():
    """Command-line entry point: parse interactive-generation args and run main().

    Extracted from the __main__ guard so the script can also be exposed as a
    setuptools console-script entry point; behavior when executed directly
    is unchanged.
    """
    parser = options.get_generation_parser(interactive=True)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()