#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

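"""Interactive wrapper around a pre-trained fairseq translation model.

Loads a model archive, tokenizes and BPE-encodes raw input, and runs
beam-search generation one sentence at a time.
"""
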
from collections import namedtuple

import html
import os

import torch
from sacremoses import MosesTokenizer
from subword_nmt import apply_bpe

from fairseq import checkpoint_utils, options, tasks, utils, file_utils

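# One mini-batch of encoded source sentences: example ids, token tensor, and lengths.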
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')


class Generator(object):
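    """Interactive sentence translator.

    Applies Moses tokenization and optional BPE to raw input, runs
    beam-search generation, and post-processes the output.
    """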

    def __init__(self, task, models, args, src_bpe=None, bpe_symbol='@@ '):
        self.task = task
        self.models = models
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary
        self.src_bpe = src_bpe
        self.use_cuda = torch.cuda.is_available() and not args.cpu
        self.args = args

        self.args.remove_bpe = bpe_symbol

        # optimize model for generation
        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam,
                need_attn=args.print_alignment,
            )
            if args.fp16:
                model.half()
            if self.use_cuda:
                model.cuda()

        self.generator = self.task.build_generator(args)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(args.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(),
            *[model.max_positions() for model in models]
        )

        if hasattr(args, 'source_lang'):
            self.tokenizer = MosesTokenizer(lang=args.source_lang)
        else:
            self.tokenizer = MosesTokenizer()

        if src_bpe is not None:
            bpe_parser = apply_bpe.create_parser()
            bpe_args = bpe_parser.parse_args(['--codes', self.src_bpe])
            self.bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
        else:
            self.bpe = None

    def generate(self, src_str, verbose=False):
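        """Translate one raw source string and return the post-processed hypothesis."""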

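        # Apply Moses tokenization and, if BPE codes were provided, subword segmentation.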
        src_str = self.tokenizer.tokenize(src_str, return_str=True)
        if self.bpe:
            src_str = self.bpe.process_line(src_str)

        for batch in self.make_batches([src_str], self.args, self.task, self.max_positions):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

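            # Assemble the input sample in the format the task's inference step expects.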
            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = self.task.inference_step(self.generator, self.models, sample)
            src_tokens = utils.strip_pad(src_tokens, self.tgt_dict.pad())

        if self.src_dict is not None:
            src_str = self.src_dict.string(src_tokens, self.args.remove_bpe)
            if verbose:
                print('S\t{}'.format(src_str))

        # Process top predictions
        for hypo in translations[0][:min(len(translations[0]), self.args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=self.align_dict,
                tgt_dict=self.tgt_dict,
                remove_bpe=self.args.remove_bpe,
            )
            if verbose:
                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
                print('P\t{}'.format(
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                if self.args.print_alignment:
                    print('A\t{}'.format(
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))

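        # Moses tokenization escapes special characters as HTML entities; undo that.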
        return html.unescape(hypo_str)

    @classmethod
    def from_pretrained(cls, parser, *, model_name_or_path, data_name_or_path, **kwargs):
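        """Construct a Generator from pre-packaged model and data archives.

        The model archive must contain `model.pt` and may contain `bpecodes`;
        any extra kwargs override the parsed generation args.
        """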
        model_path = file_utils.load_archive_file(model_name_or_path)
        data_path = file_utils.load_archive_file(data_name_or_path)
        checkpoint_path = os.path.join(model_path, 'model.pt')

        task_name = kwargs.get('task', 'translation')

        # set data and parse
        model_args = options.parse_args_and_arch(parser, input_args=[data_path, '--task', task_name])

        # override parsed args with any explicitly passed kwargs
        for arg_name, arg_val in kwargs.items():
            setattr(model_args, arg_name, arg_val)

        utils.import_user_module(model_args)

        if model_args.buffer_size < 1:
            model_args.buffer_size = 1
        if model_args.max_tokens is None and model_args.max_sentences is None:
            model_args.max_sentences = 1

        assert not model_args.sampling or model_args.nbest == model_args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not model_args.max_sentences or model_args.max_sentences <= model_args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        print(model_args)

        task = tasks.setup_task(model_args)
        print("loading model checkpoint from {}".format(checkpoint_path))

        models, _model_args = checkpoint_utils.load_model_ensemble([checkpoint_path], task=task)
        src_bpe = os.path.join(model_path, 'bpecodes')
        if not os.path.exists(src_bpe):
            src_bpe = None

        return cls(task, models, model_args, src_bpe, kwargs.get('remove_bpe', '@@ '))

    def make_batches(self, lines, args, task, max_positions):
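        """Encode raw source lines with the source dictionary and yield Batch tuples."""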
        tokens = [
            task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
            for src_str in lines
        ]
        lengths = torch.LongTensor([t.numel() for t in tokens])
        itr = task.get_batch_iterator(
            dataset=task.build_dataset_for_inference(tokens, lengths),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions,
        ).next_epoch_itr(shuffle=False)
        for batch in itr:
            yield Batch(
                ids=batch['id'],
                src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
            )
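
# Example usage (a minimal sketch; the archive and data names below are
# hypothetical placeholders, not shipped with this module):
#
#   from fairseq import options
#
#   parser = options.get_generation_parser(interactive=True)
#   generator = Generator.from_pretrained(
#       parser,
#       model_name_or_path='wmt19.en-de',      # hypothetical model archive
#       data_name_or_path='wmt19.en-de.data',  # hypothetical data archive
#   )
#   print(generator.generate('Machine translation is useful.', verbose=True))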