interactive.py 6.1 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3 -u
Louis Martin's avatar
Louis Martin committed
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
Myle Ott's avatar
Myle Ott committed
8

Myle Ott's avatar
Myle Ott committed
9
from collections import namedtuple
10
import numpy as np
Louis Martin's avatar
Louis Martin committed
11
import sys
Myle Ott's avatar
Myle Ott committed
12

Louis Martin's avatar
Louis Martin committed
13
14
import torch

Myle Ott's avatar
Myle Ott committed
15
from fairseq import data, options, tasks, tokenizer, utils
Louis Martin's avatar
Louis Martin committed
16
17
from fairseq.sequence_generator import SequenceGenerator

Myle Ott's avatar
Myle Ott committed
18

19
# One inference batch: the raw source strings plus the tensors taken from the
# batch iterator's ``net_input`` (source tokens and source lengths).
Batch = namedtuple(typename='Batch', field_names='srcs tokens lengths')

# Printable output for one source sentence: the echoed source line and three
# parallel lists (hypotheses, positional scores, alignments), one entry per
# hypothesis kept.
Translation = namedtuple(
    typename='Translation',
    field_names='src_str hypos pos_scores alignments',
)
21
22
23
24
25
26
27
28
29
30
31
32
33
34


def buffered_read(buffer_size):
    """Group lines read from standard input into lists.

    Each line is stripped of surrounding whitespace. A group is emitted as
    soon as it holds at least ``buffer_size`` lines; whatever is left when
    stdin is exhausted is emitted as a final, shorter group.

    Args:
        buffer_size: number of lines to accumulate before yielding.

    Yields:
        Lists of stripped input lines, in input order.
    """
    pending = []
    for raw_line in sys.stdin:
        pending.append(raw_line.strip())
        if len(pending) < buffer_size:
            continue
        yield pending
        # Start a fresh list so the one just handed out is never mutated.
        pending = []

    # Flush the trailing partial group, if any.
    if pending:
        yield pending


35
def make_batches(lines, args, task, max_positions):
    """Tokenize source sentences and yield them as inference batches.

    Args:
        lines: indexable sequence of source sentences (strings).
        args: parsed options; ``max_tokens`` and ``max_sentences`` are read
            and forwarded to the batch iterator.
        task: task object supplying ``source_dictionary`` and
            ``get_batch_iterator``.
        max_positions: forwarded unchanged to ``task.get_batch_iterator``.

    Yields:
        ``(batch, ids)`` pairs, where ``batch`` is a ``Batch`` and ``ids`` is
        the iterator's ``batch['id']``; each id indexes into ``lines``.
    """
    src_dict = task.source_dictionary

    # Tokenize without growing the dictionary (``add_if_not_exist=False``).
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])

    itr = task.get_batch_iterator(
        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        # Return the ids alongside the batch so the caller can map results
        # back to positions in ``lines``.
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']
53

Louis Martin's avatar
Louis Martin committed
54

Myle Ott's avatar
Myle Ott committed
55
def main(args):
    """Translate sentences read from stdin and print the results.

    Sets up the task, loads the model ensemble named by ``args.path``
    (colon-separated), builds a ``SequenceGenerator`` and then loops over
    buffered stdin input. For every source sentence it prints an ``O`` line
    (the source), followed per hypothesis by an ``H`` line (score and text),
    a ``P`` line (positional scores) and, when ``args.print_alignment`` is
    set, an ``A`` line (alignment).

    Args:
        args: parsed generation options. ``buffer_size`` and
            ``max_sentences`` may be adjusted in place to usable defaults.

    Raises:
        AssertionError: if ``--sampling`` is used with ``nbest != beam``, or
            if ``max_sentences`` exceeds ``buffer_size``.
    """
    # Normalize options that would otherwise make batching impossible.
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    model_paths = args.path.split(':')
    # NOTE(security): eval() executes arbitrary Python taken from
    # ``args.model_overrides``. This is only tolerable if that string is
    # supplied by a trusted operator (presumably the command line -- confirm
    # for any other caller of main()). Consider ast.literal_eval if the
    # overrides are always a plain literal.
    models, _ = utils.load_ensemble_for_inference(
        model_paths, task, model_arg_overrides=eval(args.model_overrides),
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator
    translator = SequenceGenerator(
        models, tgt_dict, beam_size=args.beam, minlen=args.min_len,
        stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen, unk_penalty=args.unkpen,
        sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
    )

    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
        """Format the top ``args.nbest`` hypotheses of one sentence for printing."""
        result = Translation(
            src_str='O\t{}'.format(src_str),
            hypos=[],
            pos_scores=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:args.nbest]:
            raw_alignment = hypo['alignment']
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=raw_alignment.int().cpu() if raw_alignment is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
            result.pos_scores.append('P\t{}'.format(
                ' '.join(map('{:.4f}'.format, hypo['positional_scores'].tolist()))
            ))
            # Keep the three lists aligned: store None when alignments are
            # not requested so the printing loop can skip them.
            result.alignments.append(
                'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
                if args.print_alignment else None
            )
        return result

    def process_batch(batch):
        """Run the generator on one ``Batch`` and return one result per sentence."""
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        translations = translator.generate(
            tokens,
            lengths,
            # Output length cap is linear in the second dimension of ``tokens``.
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )

        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
            indices.extend(batch_indices)
            results.extend(process_batch(batch))

        # Print in ascending order of the ids reported by the batches, i.e.
        # in the order the sentences appear in ``inputs``.
        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str)
            for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
                print(hypo)
                print(pos_scores)
                if align is not None:
                    print(align)
Louis Martin's avatar
Louis Martin committed
176

Myle Ott's avatar
Myle Ott committed
177

Louis Martin's avatar
Louis Martin committed
178
if __name__ == '__main__':
    # Script entry point: build the generation parser in interactive mode,
    # parse the command line, and start the translation loop.
    parser = options.get_generation_parser(interactive=True)
    args = options.parse_args_and_arch(parser)
    main(args)