#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import namedtuple
import numpy as np
import sys

import torch
from torch.autograd import Variable

from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator

# A group of source sentences ready for generation: the raw strings plus the
# padded token tensor and per-sentence lengths taken from the batch iterator.
Batch = namedtuple('Batch', ['srcs', 'tokens', 'lengths'])
# Printable output for one source sentence: the "O" line plus parallel lists
# of "H" (hypothesis) and "A" (alignment) lines.
Translation = namedtuple('Translation', ['src_str', 'hypos', 'alignments'])


def buffered_read(buffer_size, input_stream=None):
    """Yield lists of up to ``buffer_size`` stripped lines from a text stream.

    Args:
        buffer_size: maximum number of lines collected before a list is yielded.
        input_stream: iterable of lines to read from. Defaults to
            ``sys.stdin``, looked up when iteration starts so that a
            reassigned ``sys.stdin`` is honoured.

    Yields:
        Non-empty lists of lines with surrounding whitespace stripped. The
        final list may hold fewer than ``buffer_size`` lines.
    """
    if input_stream is None:
        input_stream = sys.stdin
    buffer = []
    for src_str in input_stream:
        buffer.append(src_str.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []

    # Flush the trailing partial buffer, but never yield an empty list.
    if buffer:
        yield buffer


def make_batches(lines, args, src_dict, max_positions):
    """Tokenize raw input lines and group them into generation batches.

    Each yielded item is a pair ``(Batch, ids)``. ``ids`` is the batch's
    ``'id'`` entry from the iterator, whose elements index back into
    ``lines``. The caller can use them to restore the original input order.
    """
    def encode(line):
        # Unknown words map to <unk>; the dictionary is never extended here.
        return tokenizer.Tokenizer.tokenize(line, src_dict, add_if_not_exist=False).long()

    src_tokens = list(map(encode, lines))
    src_lengths = np.array([toks.numel() for toks in src_tokens])
    dataset = data.LanguagePairDataset(src_tokens, src_lengths, src_dict)
    batch_iterator = data.EpochBatchIterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    )
    for sample in batch_iterator.next_epoch_itr(shuffle=False):
        ids = sample['id']
        net_input = sample['net_input']
        yield Batch(
            srcs=[lines[idx] for idx in ids],
            tokens=net_input['src_tokens'],
            lengths=net_input['src_lengths'],
        ), ids


def main(args):
    """Translate sentences read interactively and print the results.

    Input lines are read in chunks of ``args.buffer_size`` and batched.
    They are translated with the loaded model ensemble, and the results are
    printed in the original input order. Each sentence produces an ``O``
    line, then an ``H`` line (score and hypothesis) and an ``A`` line
    (alignment) for each of the top ``args.nbest`` hypotheses.
    """
    # Normalize the buffering/batching options before validating them.
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble (multiple checkpoint paths are separated by ':')
    print('| loading model(s) from {}'.format(args.path))
    model_paths = args.path.split(':')
    models, _ = utils.load_ensemble_for_inference(model_paths, task)

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        )

    # Initialize generator
    translator = SequenceGenerator(
        models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
        minlen=args.min_len,
    )

    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
        """Format the top ``args.nbest`` hypotheses for one source sentence."""
        result = Translation(
            src_str='O\t{}'.format(src_str),
            hypos=[],
            alignments=[],
        )

        # Process top predictions; slicing already clamps to len(hypos).
        for hypo in hypos[:args.nbest]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
            result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
        return result

    def process_batch(batch):
        """Run generation on one Batch and return a Translation per sentence."""
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        translations = translator.generate(
            Variable(tokens),
            Variable(lengths),
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )

        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

    # Loop-invariant: the ensemble's position limits do not change per chunk.
    max_positions = models[0].max_positions()

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, src_dict, max_positions):
            # Store ids as plain Python ints so np.argsort below sorts a simple
            # integer list, whatever container type the batch ids come in.
            indices.extend(int(i) for i in batch_indices)
            results.extend(process_batch(batch))

        # The batch iterator need not preserve input order; sorting by the
        # original ids prints results in the order the lines were read.
        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str)
            for hypo, align in zip(result.hypos, result.alignments):
                print(hypo)
                print(align)


if __name__ == '__main__':
    # Build the interactive generation CLI, parse it (including
    # architecture-specific options), then run the translation loop.
    generation_parser = options.get_generation_parser(interactive=True)
    main(options.parse_args_and_arch(generation_parser))