# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import torch

from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder
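
# Typical usage (a minimal sketch; `models` is assumed to be an already-loaded
# list of ensemble members that share a target dictionary, and `itr` a batched
# dataset iterator; neither is defined in this module):
#
#   generator = SequenceGenerator(models, beam_size=5)
#   generator.cuda()
#   for sample_id, src, ref, hypos in generator.generate_batched_itr(itr, cuda=True):
#       best = hypos[0]  # hypotheses are sorted by score, best first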


class SequenceGenerator(object):
    def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
                 stop_early=True, normalize_scores=True, len_penalty=1,
                 unk_penalty=0, retain_dropout=False, sampling=False):
        """Generates translations of a given source sentence.

        Args:
            min/maxlen: The length of the generated output will be bounded by
                minlen and maxlen (not including the end-of-sentence marker).
            stop_early: Stop generation immediately after we finalize beam_size
                hypotheses, even though longer hypotheses might have better
                normalized scores.
            normalize_scores: Normalize scores by the length of the output.
        """
        self.models = models
        self.pad = models[0].dst_dict.pad()
        self.unk = models[0].dst_dict.unk()
        self.eos = models[0].dst_dict.eos()
        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
        assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
        self.vocab_size = len(models[0].dst_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        max_decoder_len = min(m.max_decoder_positions() for m in self.models)
        max_decoder_len -= 1  # we define maxlen not including the EOS marker
        self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.sampling = sampling

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                             cuda=False, timer=None, prefix_size=0):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
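                For example, maxlen_a=0.5 and maxlen_b=10 bound a 20-token
                source at int(0.5 * 20 + 10) = 20 output tokens.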
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
            prefix_size: force the first prefix_size tokens of every
                generated hypothesis to match the corresponding reference
                (target) tokens.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.make_variable(sample, volatile=True, cuda=cuda)
            net_input = s['net_input']
            srclen = net_input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            with utils.maybe_no_grad():
                hypos = self.generate(
                    net_input['src_tokens'],
                    net_input['src_lengths'],
                    beam_size=beam_size,
                    maxlen=int(maxlen_a*srclen + maxlen_b),
                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                )
            if timer is not None:
                timer.stop(sum(len(h[0]['tokens']) for h in hypos))
            for i, sample_id in enumerate(s['id'].data):
                # remove padding
                src = utils.strip_pad(net_input['src_tokens'].data[i, :], self.pad)
                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                yield sample_id, src, ref, hypos[i]

    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        """Generate a batch of translations."""
        with utils.maybe_no_grad():
            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        incremental_states = {}
        for model in self.models:
            if not self.retain_dropout:
                model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                incremental_states[model] = {}
            else:
                incremental_states[model] = None

            # compute the encoder output for each beam
            encoder_out = model.encoder(
                src_tokens.repeat(1, beam_size).view(-1, srclen),
                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
            )
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = src_tokens.data.new(bsz * beam_size, maxlen + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
        attn_buf = attn.clone()

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': -math.inf} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)
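        # e.g., with bsz=2 and beam_size=3: bbsz_offsets = [[0], [3]] and
        # cand_offsets = [0, 1, ..., cand_size - 1], so a flattened
        # (sentence, beam) index is sent * beam_size + beam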

        # helper function for allocating buffers on the fly
        buffers = {}
        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfinalized_scores=None):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early or step == maxlen or unfinalized_scores is None:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                best_unfinalized_score = unfinalized_scores[sent].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
                unfinalized_scores: A vector containing scores for all
                    unfinalized hypotheses
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step+2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step+1)**self.len_penalty
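                # e.g., with len_penalty=1 this is the average log-probability
                # per generated token (including the final EOS)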

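            # cum_unfin[j] counts how many sentences that come before the j-th
            # *unfinished* sentence have already finished; it maps indices in
            # the shrunken batch back to positions in the original batch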
            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                def get_hypo():
                    _, alignment = attn_clone[i].max(dim=0)
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': attn_clone[i],  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif not self.stop_early and score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    if worst_idx is not None:
                        finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                        encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
                tokens[:, :step+1], encoder_outs, incremental_states)
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
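                # probs: (bsz * beam_size, vocab) -> (bsz, vocab), keeping only
                # the first beam of each sentence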
                scores = scores.type_as(probs)
                scores_buf = scores_buf.type_as(probs)
            elif not self.sampling:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores[:, step-1].view(-1, 1))

            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            attn[:, :, step+1].copy_(avg_attn_scores)

            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)
            if step < maxlen:
                if prefix_tokens is not None and step < prefix_tokens.size(1):
                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
                    cand_scores = torch.gather(
                        probs_slice, dim=1,
                        index=prefix_tokens[:, step].view(-1, 1).data
                    ).expand(-1, cand_size)
                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                    cand_beams.resize_as_(cand_indices).fill_(0)
                elif self.sampling:
                    assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
                    exp_probs = probs.exp_().view(-1, self.vocab_size)
                    if step == 0:
                        # we exclude the first two vocab items, one of which is pad
                        torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
                        cand_indices.add_(2)
                    else:
                        torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
                        cand_indices.add_(2)
                    torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                    cand_scores.log_()
                    cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                    cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
                    if step == 0:
                        cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
                    else:
                        cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
                        # make scores cumulative
                        cand_scores.add_(
                            torch.gather(
                                scores[:, step-1].view(bsz, beam_size), dim=1,
                                index=cand_beams,
                            )
                        )
                else:
                    # take the best 2 x beam_size predictions. We'll choose the first
                    # beam_size of these which don't predict eos to continue with.
                    torch.topk(
                        probs.view(bsz, -1),
                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
                        out=(cand_scores, cand_indices),
                    )
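                    # candidates index into the flattened (beam, vocab) axis:
                    # beam = idx // vocab_size, token = idx % vocab_size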
                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
                    cand_indices.fmod_(self.vocab_size)
            else:
                # finalize all active hypotheses once we hit maxlen
                # pick the hypothesis with the highest prob of EOS right now
                torch.sort(
                    probs[:, self.eos],
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
                num_remaining_sent -= len(finalize_hypos(
                    step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)

            finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
                    cand_bbsz_idx[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_bbsz_idx,
                )
                if eos_bbsz_idx.numel() > 0:
                    torch.masked_select(
                        cand_scores[:, :beam_size],
                        mask=eos_mask[:, :beam_size],
                        out=eos_scores,
                    )
                    finalized_sents = finalize_hypos(
                        step, eos_bbsz_idx, eos_scores, cand_scores)
                    num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

            if len(finalized_sents) > 0:
                # construct batch_idxs which holds indices of sentences to keep for the next pass

                new_bsz = bsz - len(finalized_sents)

                batch_mask = torch.ones(bsz).type_as(cand_indices)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)

                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
                eos_mask.type_as(cand_offsets)*cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )
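            # e.g., with beam_size=2 (cand_size=4) and an eos_mask row of
            # [1, 0, 0, 1], active_mask becomes [4, 1, 2, 7]; taking the
            # beam_size smallest values selects candidates 1 and 2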

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(_ignore, active_hypos)
            )
            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step+1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step+1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            torch.index_select(
                attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
                out=attn_buf[:, :, :step+2],
            )

            # swap buffers
            old_tokens = tokens
            tokens = tokens_buf
            tokens_buf = old_tokens
            old_scores = scores
            scores = scores_buf
            scores_buf = old_scores
            old_attn = attn
            attn = attn_buf
            attn_buf = old_attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized

    def _decode(self, tokens, encoder_outs, incremental_states):
        # wrap in Variable
        tokens = utils.volatile_variable(tokens)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            with utils.maybe_no_grad():
                if incremental_states[model] is not None:
                    decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
                else:
                    decoder_out = list(model.decoder(tokens, encoder_out))
                decoder_out[0] = decoder_out[0][:, -1, :]
                attn = decoder_out[1]
            probs = model.get_normalized_probs(decoder_out, log_probs=False).data
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                attn = attn[:, -1, :].data
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
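        # ensemble: take the arithmetic mean of the per-model probabilities,
        # then move back to log space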
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))

        return avg_probs, avg_attn