# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

from contextlib import ExitStack
import math
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from fairseq import utils


class SequenceGenerator(object):
    def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                 stop_early=True, normalize_scores=True, len_penalty=1):
        """Generates translations of a given source sentence.

        Args:
            min/maxlen: The length of the generated output will be bounded by
                minlen and maxlen (not including the end-of-sentence marker).
            stop_early: Stop generation immediately after we finalize beam_size
                hypotheses, even though longer hypotheses might have better
                normalized scores.
            normalize_scores: Normalize scores by the length of the output.
        """
        self.models = models
        self.pad = models[0].dst_dict.pad()
        self.eos = models[0].dst_dict.eos()
        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
        self.vocab_size = len(models[0].dst_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
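        # decoder positions run from pad + 1 to pad + maxlen + 1: one position
        # per output token plus one extra slot for the final EOS step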
        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
        self.decoder_context = models[0].decoder.context_size()
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty

    def cuda(self):
        for model in self.models:
            model.cuda()
        self.positions = self.positions.cuda()
        return self

    def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
                             cuda_device=None, timer=None):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda_device: GPU on which to do generation.
            timer: StopwatchMeter for timing generations.
        """

        def lstrip_pad(tensor):
            return tensor[tensor.eq(self.pad).sum():]

        for sample in data_itr:
            s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
            input = s['net_input']
            srclen = input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            hypos = self.generate(input['src_tokens'], input['src_positions'],
                                  maxlen=int(maxlen_a*srclen + maxlen_b))
            if timer is not None:
                timer.stop(s['ntokens'])
            for i, id in enumerate(s['id']):
                src = input['src_tokens'].data[i, :]
                # remove padding from ref, which appears at the beginning
                ref = lstrip_pad(s['target'].data[i, :])
                yield id, src, ref, hypos[i]

    def generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
        """Generate a batch of translations."""
        with ExitStack() as stack:
            for model in self.models:
                stack.enter_context(model.decoder.incremental_inference())
            return self._generate(src_tokens, src_positions, beam_size, maxlen)

    def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
        bsz = src_tokens.size(0)
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        for model in self.models:
            model.eval()
            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence

            # compute the encoder output and expand to beam size
            encoder_out = model.encoder(src_tokens, src_positions)
            encoder_out = self._expand_encoder_out(encoder_out, beam_size)
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = encoder_outs[0][0].data.new(bsz * beam_size).fill_(0)
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
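        # attn[i, :, t] will hold the (ensemble-averaged) attention over source
        # positions used when output token t is generated for hypothesis i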
        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': float('Inf')} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)
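        # e.g. with bsz=2 and beam_size=3, bbsz_offsets is [[0], [3]]; adding it
        # to per-sentence beam indices (0..beam_size-1) yields flat row indices
        # into the (bsz * beam_size)-row tokens/scores/attn tensors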

        # helper function for allocating buffers on the fly
        buffers = {}
        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                bbsz = sent*beam_size
                best_unfinalized_score = scores[bbsz:bbsz+beam_size].max()
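                # cumulative scores are sums of log-probs and only decrease with
                # length, so (with the default len_penalty) dividing by maxlen is
                # the most optimistic normalization this hypothesis could get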
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                scores: A vector of the same size as bbsz_idx containing scores
                    for each hypothesis
            """
            assert bbsz_idx.numel() == scores.numel()
            norm_scores = scores/math.pow(step+1, self.len_penalty) if self.normalize_scores else scores
            sents_seen = set()
            for idx, score in zip(bbsz_idx.cpu(), norm_scores.cpu()):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
                    hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                    hypo[step] = self.eos
                    attention = attn[idx, :, 1:step+2].clone()
                    _, alignment = attention.max(dim=0)
                    return {
                        'tokens': hypo,
                        'score': score,
                        'attention': attention,
                        'alignment': alignment,
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            # return number of hypotheses finished this step
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent):
                    finished[sent] = True
                    num_finished += 1
            return num_finished

        reorder_state = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model in self.models:
                    model.decoder.reorder_incremental_state(reorder_state)

            probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
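                # probs: (bsz * beam_size, vocab) -> keep every beam_size-th row
                # -> (bsz, vocab), i.e. a single distribution per sentence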
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
            else:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores.view(-1, 1))
            probs[:, self.pad] = -math.inf  # never select pad

            # Record attention scores
            attn[:, :, step+1].copy_(avg_attn_scores)

            # take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            probs.view(bsz, -1).topk(
                min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
                out=(cand_scores, cand_indices))
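            # cand_indices index into the flattened (beam_size * vocab_size)
            # scores per sentence: div by vocab_size recovers the beam each
            # candidate came from, fmod recovers the token id within the vocab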
            torch.div(cand_indices, self.vocab_size, out=cand_beams)
            cand_indices.fmod_(self.vocab_size)

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                eos_bbsz_idx = buffer('eos_bbsz_idx')
                cand_bbsz_idx.masked_select(eos_mask, out=eos_bbsz_idx)
                if eos_bbsz_idx.numel() > 0:
                    eos_scores = buffer('eos_scores', type_of=scores)
                    cand_scores.masked_select(eos_mask, out=eos_scores)
                    num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break

            # set active_mask so that values >= cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After this, the smallest values per row are the top active hypos
            active_mask = buffer('active_mask')
            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets[:eos_mask.size(1)],
                      out=active_mask)
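            # e.g. with beam_size=2 (cand_size=4) and eos_mask=[0, 1, 0, 0] for a
            # sentence, active_mask is [0, 5, 2, 3]; keeping the beam_size
            # smallest entries selects candidates 0 and 2 and drops the eos hypo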

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            active_mask.topk(beam_size, 1, largest=False, out=(_ignore, active_hypos))
            active_bbsz_idx = buffer('active_bbsz_idx')
            cand_bbsz_idx.gather(1, active_hypos, out=active_bbsz_idx)
            active_scores = cand_scores.gather(1, active_hypos,
                                               out=scores.view(bsz, beam_size))

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # finalize all active hypotheses once we hit maxlen
            # finalize_hypos will take care of adding the EOS markers
            if step == maxlen:
                num_remaining_sent -= finalize_hypos(step, active_bbsz_idx, active_scores)
                assert num_remaining_sent == 0
                break

            # copy tokens for active hypotheses
            torch.index_select(tokens[:, :step+1], dim=0, index=active_bbsz_idx,
                               out=tokens_buf[:, :step+1])
            cand_indices.gather(1, active_hypos,
                                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])

            # copy attention for active hypotheses
            torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
                               out=attn_buf[:, :, :step+2])

            # swap buffers
            old_tokens = tokens
            tokens = tokens_buf
            tokens_buf = old_tokens
            old_attn = attn
            attn = attn_buf
            attn_buf = old_attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(bsz):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized

    def _decode(self, tokens, encoder_outs):
        length = tokens.size(1)

        # repeat the first length positions to fill batch
        positions = self.positions[:length].view(1, length)

        # wrap in Variables
        tokens = Variable(tokens, volatile=True)
        positions = Variable(positions, volatile=True)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            decoder_out, attn = model.decoder(tokens, positions, encoder_out)
            probs = F.softmax(decoder_out[:, -1, :]).data
            attn = attn[:, -1, :].data
            if avg_probs is None or avg_attn is None:
                avg_probs = probs
                avg_attn = attn
            else:
                avg_probs.add_(probs)
                avg_attn.add_(attn)
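        # ensemble by averaging the (non-log) probabilities and attentions over
        # models, then move the averaged probabilities to log space for scoring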
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        avg_attn.div_(len(self.models))

        return avg_probs, avg_attn

    def _expand_encoder_out(self, encoder_out, beam_size):
        res = []
        for tensor in encoder_out:
            res.append(
                # repeat beam_size times along second dimension
                tensor.repeat(1, beam_size, *[1 for i in range(tensor.dim()-2)])
                # then collapse into [bsz*beam, ...original dims...]
                .view(-1, *tensor.size()[1:])
            )
        return tuple(res)
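
# Example usage (a minimal sketch; `models`, `data_itr` and the GPU id are
# assumptions supplied by the caller, not defined in this file):
#
#     translator = SequenceGenerator(models, beam_size=5, maxlen=200)
#     translator.cuda()
#     for sample_id, src, ref, hypos in translator.generate_batched_itr(
#             data_itr, maxlen_a=0.0, maxlen_b=200, cuda_device=0):
#         best = hypos[0]  # hypotheses are sorted by score, best first
#         print(sample_id, best['score'], best['tokens'])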