#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Translate raw text with a trained model. Batches data on-the-fly.
"""

from collections import namedtuple
import sys

import numpy as np
import torch

from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator

# One tokenized mini-batch plus the raw source strings it was built from.
Batch = namedtuple('Batch', 'srcs tokens lengths')
# Formatted, printable output lines for a single translated sentence.
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(buffer_size, input_stream=None):
    """Yield lists of up to ``buffer_size`` stripped lines.

    Args:
        buffer_size: maximum number of lines per yielded list.
        input_stream: iterable of text lines; defaults to ``sys.stdin``.
            (Parameter added for testability; omitting it preserves the
            original stdin-reading behavior.)

    Yields:
        Lists of whitespace-stripped input lines; the final list may be
        shorter than ``buffer_size``. Yields nothing for empty input.
    """
    if input_stream is None:
        input_stream = sys.stdin
    buffer = []
    for src_str in input_stream:
        buffer.append(src_str.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []

    # Flush any leftover partial buffer.
    if buffer:
        yield buffer
def make_batches(lines, args, task, max_positions):
    """Tokenize raw input lines and yield batches for generation.

    Args:
        lines: list of raw source sentences (strings).
        args: parsed command-line namespace (uses ``max_tokens`` and
            ``max_sentences`` as batching limits).
        task: fairseq task providing the source dictionary and batch iterator.
        max_positions: maximum allowed source length(s).

    Yields:
        ``(Batch, ids)`` pairs, where ``ids`` are the indices of the batch's
        sentences within ``lines`` (batching may reorder sentences, so the
        caller needs them to restore input order).
    """
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    itr = task.get_batch_iterator(
        dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']
def main(args):
    """Interactive translation loop.

    Reads raw sentences from stdin (optionally buffered via
    ``--buffer-size``), translates them with the model ensemble given by
    ``args.path``, and prints per sentence: the source (``O``), each
    hypothesis with its score (``H``), positional scores (``P``), and —
    when ``--print-alignment`` is set — alignments (``A``).
    """
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        # Default to one sentence per batch when no batching limit is given.
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() on --model-overrides executes arbitrary Python from
    # the command line. Acceptable for a local CLI tool, but never expose this
    # to untrusted input.
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator
    translator = SequenceGenerator(
        models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        match_source_len=args.match_source_len, no_repeat_ngram_size=args.no_repeat_ngram_size,
    )

    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
        # Format the top-nbest hypotheses for one sentence into printable lines.
        result = Translation(
            src_str='O\t{}'.format(src_str),
            hypos=[],
            pos_scores=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
            result.pos_scores.append('P\t{}'.format(
                ' '.join(map(
                    lambda x: '{:.4f}'.format(x),
                    hypo['positional_scores'].tolist(),
                ))
            ))
            # Alignment lines are only produced when requested; None entries
            # keep the three per-hypothesis lists aligned for zip() below.
            result.alignments.append(
                'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
                if args.print_alignment else None
            )
        return result

    def process_batch(batch):
        # Run the generator on one batch and format every sentence's output.
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
        translations = translator.generate(
            encoder_input,
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )

        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

    # Cap source length by the most restrictive of the task and model limits.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
            indices.extend(batch_indices)
            results += process_batch(batch)

        # Batching may reorder sentences; argsort restores input order.
        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str)
            for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
                print(hypo)
                print(pos_scores)
                if align is not None:
                    print(align)
if __name__ == '__main__':
    # Build the interactive-mode generation argument parser, resolve the
    # model architecture from the parsed args, and run the translation loop.
    parser = options.get_generation_parser(interactive=True)
    args = options.parse_args_and_arch(parser)
    main(args)