sequence_generator.py 23.4 KB
Newer Older
Sergey Edunov's avatar
Sergey Edunov committed
1
2
3
4
5
6
7
8
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
Myle Ott's avatar
Myle Ott committed
9

Sergey Edunov's avatar
Sergey Edunov committed
10
11
12
import torch

from fairseq import utils
Myle Ott's avatar
Myle Ott committed
13
from fairseq.models import FairseqIncrementalDecoder
Sergey Edunov's avatar
Sergey Edunov committed
14
15
16


class SequenceGenerator(object):
    def __init__(
        self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
        normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
        sampling=False, sampling_topk=-1, sampling_temperature=1,
    ):
        """Generates translations of a given source sentence.

        Args:
            min/maxlen: The length of the generated output will be bounded by
                minlen and maxlen (not including the end-of-sentence marker).
            stop_early: Stop generation immediately after we finalize beam_size
                hypotheses, even though longer hypotheses might have better
                normalized scores.
            normalize_scores: Normalize scores by the length of the output.
        """
        self.models = models
        # special-symbol indices and vocabulary size from the target dictionary
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        # the ensemble can only decode as far as its most limited member;
        # maxlen excludes the trailing EOS marker, hence the -1
        hard_limit = min(m.max_decoder_positions() for m in models) - 1
        self.maxlen = hard_limit if maxlen is None else min(maxlen, hard_limit)
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.sampling = sampling
        self.sampling_topk = sampling_topk
        self.sampling_temperature = sampling_temperature

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

Myle Ott's avatar
Nits  
Myle Ott committed
55
56
57
58
    def generate_batched_itr(
        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
        cuda=False, timer=None, prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length
                ``maxlen_a * x + maxlen_b``, where x is the source length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            batch = utils.move_to_cuda(sample) if cuda else sample
            if 'net_input' not in batch:
                continue
            net_input = batch['net_input']
            src_len = net_input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with torch.no_grad():
                # optionally force-decode the first prefix_size target tokens
                prefix = batch['target'][:, :prefix_size] if prefix_size > 0 else None
                hypos = self.generate(
                    net_input['src_tokens'],
                    net_input['src_lengths'],
                    beam_size=beam_size,
                    maxlen=int(maxlen_a * src_len + maxlen_b),
                    prefix_tokens=prefix,
                )
            if timer is not None:
                # charge the timer with the total number of generated tokens
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, sample_id in enumerate(batch['id'].data):
                # strip padding from source and (optional) reference
                src = utils.strip_pad(net_input['src_tokens'].data[i, :], self.pad)
                ref = None
                if batch['target'] is not None:
                    ref = utils.strip_pad(batch['target'].data[i, :], self.pad)
                yield sample_id, src, ref, hypos[i]

Dario Pavllo's avatar
Dario Pavllo committed
93
    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
Sergey Edunov's avatar
Sergey Edunov committed
94
        """Generate a batch of translations."""
Myle Ott's avatar
Myle Ott committed
95
        with torch.no_grad():
96
            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
Sergey Edunov's avatar
Sergey Edunov committed
97

Dario Pavllo's avatar
Dario Pavllo committed
98
    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        """Beam-search (or sampling) decoding for one batch.

        Returns a list of length bsz; element i is a list of up to beam_size
        hypothesis dicts for sentence i, sorted best-first, each with keys
        'tokens', 'score', 'attention', 'alignment', 'positional_scores'.
        """
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            # incremental decoding is only available for incremental decoders;
            # a None entry means "recompute from scratch each step"
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(
                src_tokens.repeat(1, beam_size).view(-1, srclen),
                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
            )
            encoder_outs.append(encoder_out)

        # initialize buffers (the *_buf tensors are double-buffers swapped each step)
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen ** self.len_penalty
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            # cum_unfin[i] maps an index into the (shrunken) active batch back
            # to an offset accounting for already-finished sentences
            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():

                    # remove padding tokens from attn scores
                    nonpad_idxs = src_tokens[sent].ne(self.pad)
                    hypo_attn = attn_clone[i][nonpad_idxs]
                    _, alignment = hypo_attn.max(dim=0)

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
                scores = scores.type_as(probs)
                scores_buf = scores_buf.type_as(probs)
            elif not self.sampling:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores[:, step - 1].view(-1, 1))

            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            attn[:, :, step + 1].copy_(avg_attn_scores)

            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    # force-decode the supplied prefix: every beam gets the
                    # prefix token at this step
                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams.resize_as_(cand_indices).fill_(0)
                elif self.sampling:
                    assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'

                    if self.sampling_topk > 0:
                        # sample only among the top-k most likely tokens
                        values, indices = probs[:, 2:].topk(self.sampling_topk)
                        exp_probs = values.div_(self.sampling_temperature).exp()
                        if step == 0:
                            torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
                        else:
                            torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                        torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
                        cand_indices.add_(2)
                    else:
                        exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)

                        if step == 0:
                            # we exclude the first two vocab items, one of which is pad
                            torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
                        else:
                            torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)

                        cand_indices.add_(2)
                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)

                    cand_scores.log_()
                    cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                    cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
                    if step == 0:
                        cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
                    else:
                        cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
                        # make scores cumulative
                        cand_scores.add_(
                            torch.gather(
                                scores[:, step - 1].view(bsz, beam_size), dim=1,
                                index=cand_beams,
                            )
                        )
                else:
                    # take the best 2 x beam_size predictions. We'll choose the first
                    # beam_size of these which don't predict eos to continue with.
                    torch.topk(
                        probs.view(bsz, -1),
                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
                        out=(cand_scores, cand_indices),
                    )
                    # flat index -> (beam, vocab) pair
                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
                    cand_indices.fmod_(self.vocab_size)
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    probs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(
                    step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                # some sentences finished: shrink the batch and reindex all
                # per-sentence tensors to cover only the survivors
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(bsz).type_as(cand_indices)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)

                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(_ignore, active_hypos)
            )
            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                out=attn_buf[:, :, :step + 2],
            )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized

495
    def _decode(self, tokens, encoder_outs, incremental_states):
496
497
498
        if len(self.models) == 1:
            return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)

Sergey Edunov's avatar
Sergey Edunov committed
499
500
501
        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
502
            probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
503
            if avg_probs is None:
Sergey Edunov's avatar
Sergey Edunov committed
504
505
506
                avg_probs = probs
            else:
                avg_probs.add_(probs)
507
508
509
510
511
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
Sergey Edunov's avatar
Sergey Edunov committed
512
513
        avg_probs.div_(len(self.models))
        avg_probs.log_()
514
515
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
Sergey Edunov's avatar
Sergey Edunov committed
516
        return avg_probs, avg_attn
517
518
519
520
521
522
523
524
525
526
527
528
529

    def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
        with torch.no_grad():
            if incremental_states[model] is not None:
                decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
            else:
                decoder_out = list(model.decoder(tokens, encoder_out))
            decoder_out[0] = decoder_out[0][:, -1, :]
            attn = decoder_out[1]
            if attn is not None:
                attn = attn[:, -1, :]
        probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
        return probs, attn