# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

from contextlib import ExitStack
import math
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder


class SequenceGenerator(object):
    def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                 stop_early=True, normalize_scores=True, len_penalty=1,
                 unk_penalty=0):
        """Generates translations of a given source sentence.

        Args:
            min/maxlen: The length of the generated output will be bounded by
                minlen and maxlen (not including the end-of-sentence marker).
            stop_early: Stop generation immediately after we finalize beam_size
                hypotheses, even though longer hypotheses might have better
                normalized scores.
            normalize_scores: Normalize scores by the length of the output.
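            len_penalty: Exponent used for length normalization; when
                normalize_scores is enabled, a hypothesis finalized at
                length l has its score divided by l ** len_penalty.
            unk_penalty: Amount subtracted from the log-probability of the
                unknown-word token at every step.

        Example (a minimal sketch; assumes `models` is a list of trained
        fairseq models that share a single target dictionary):

            generator = SequenceGenerator(models, beam_size=5).cuda()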
        """
        self.models = models
        self.pad = models[0].dst_dict.pad()
        self.unk = models[0].dst_dict.unk()
        self.eos = models[0].dst_dict.eos()
        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
        assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
        self.vocab_size = len(models[0].dst_dict)
        self.beam_size = beam_size
        self.minlen = minlen
        self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty

    def cuda(self):
        for model in self.models:
            model.cuda()
        return self

    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                             cuda_device=None, timer=None):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda_device: GPU on which to do generation.
            timer: StopwatchMeter for timing generations.
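
        Example (illustrative): with maxlen_a=0.5 and maxlen_b=10, a source
        sentence of 20 tokens is decoded for at most int(0.5*20 + 10) = 20
        tokens.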
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.make_variable(sample, volatile=True, cuda_device=cuda_device)
            input = s['net_input']
            srclen = input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
            hypos = self.generate(input['src_tokens'], beam_size=beam_size,
                                  maxlen=int(maxlen_a*srclen + maxlen_b))
            if timer is not None:
                timer.stop(s['ntokens'])
            for i, id in enumerate(s['id'].data):
                src = input['src_tokens'].data[i, :]
                # remove padding from ref
                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                yield id, src, ref, hypos[i]

    def generate(self, src_tokens, beam_size=None, maxlen=None):
        """Generate a batch of translations."""
        with ExitStack() as stack:
            for model in self.models:
                if isinstance(model.decoder, FairseqIncrementalDecoder):
                    stack.enter_context(model.decoder.incremental_inference())
            return self._generate(src_tokens, beam_size, maxlen)

    def _generate(self, src_tokens, beam_size=None, maxlen=None):
        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
        beam_size = beam_size if beam_size is not None else self.beam_size
        beam_size = min(beam_size, self.vocab_size - 1)

        encoder_outs = []
        for model in self.models:
            model.eval()
            if isinstance(model.decoder, FairseqIncrementalDecoder):
                model.decoder.set_beam_size(beam_size)

            # compute the encoder output for each beam
            encoder_out = model.encoder(src_tokens.repeat(1, beam_size).view(-1, srclen))
            encoder_outs.append(encoder_out)

        # initialize buffers
        scores = encoder_outs[0][0].data.new(bsz * beam_size).fill_(0)
        tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos
        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
        attn_buf = attn.clone()
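        # tokens has maxlen + 2 positions per hypothesis: a leading EOS that
        # seeds the decoder, up to maxlen generated tokens, and a closing EOS;
        # attn mirrors this layout along its last dimension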

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        worst_finalized = [{'idx': None, 'score': float('Inf')} for i in range(bsz)]
        num_remaining_sent = bsz

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)
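        # illustrative example: with bsz=2 and beam_size=3 the flattened
        # batch-beam ("bbsz") dimension has 6 rows, bbsz_offsets is
        # [[0], [3]], and beam j of sentence i lives at row i*beam_size + j;
        # cand_offsets is simply [0, 1, ..., cand_size - 1]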

        # helper function for allocating buffers on the fly
        buffers = {}
        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size:
                if self.stop_early:
                    return True
                # stop if the best unfinalized score is worse than the worst
                # finalized one
                bbsz = sent*beam_size
                best_unfinalized_score = scores[bbsz:bbsz+beam_size].max()
                if self.normalize_scores:
                    best_unfinalized_score /= maxlen
                if worst_finalized[sent]['score'] >= best_unfinalized_score:
                    return True
            return False

        def finalize_hypos(step, bbsz_idx, scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                scores: A vector of the same size as bbsz_idx containing scores
                    for each hypothesis
            """
            assert bbsz_idx.numel() == scores.numel()
            norm_scores = scores/math.pow(step+1, self.len_penalty) if self.normalize_scores else scores
            sents_seen = set()
            for idx, score in zip(bbsz_idx.cpu(), norm_scores.cpu()):
                sent = idx // beam_size
                sents_seen.add(sent)

                def get_hypo():
                    hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                    hypo[step] = self.eos
                    attention = attn[idx, :, 1:step+2].clone()
                    _, alignment = attention.max(dim=0)
                    return {
                        'tokens': hypo,
                        'score': score,
                        'attention': attention,
                        'alignment': alignment,
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                elif score > worst_finalized[sent]['score']:
                    # replace worst hypo for this sentence with new/better one
                    worst_idx = worst_finalized[sent]['idx']
                    finalized[sent][worst_idx] = get_hypo()

                    # find new worst finalized hypo for this sentence
                    idx, s = min(enumerate(finalized[sent]), key=lambda r: r[1]['score'])
                    worst_finalized[sent] = {
                        'score': s['score'],
                        'idx': idx,
                    }

            # return number of hypotheses finished this step
            num_finished = 0
            for sent in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent):
                    finished[sent] = True
                    num_finished += 1
            return num_finished

        reorder_state = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model in self.models:
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(reorder_state)

            probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
            else:
                # make probs contain cumulative scores for each hypothesis
                probs.add_(scores.view(-1, 1))
            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            attn[:, :, step+1].copy_(avg_attn_scores)

            # take the best 2 x beam_size predictions. We'll choose the first
            # beam_size of these which don't predict eos to continue with.
            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
            cand_beams = buffer('cand_beams')
            probs.view(bsz, -1).topk(
                min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
                out=(cand_scores, cand_indices))
            torch.div(cand_indices, self.vocab_size, out=cand_beams)
            cand_indices.fmod_(self.vocab_size)
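            # each flattened candidate index k encodes k = beam * vocab_size
            # + token, so the div/fmod above recover the source beam and the
            # token id; e.g. with vocab_size=10000, k=25003 means beam 2,
            # token 5003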

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add_(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                eos_bbsz_idx = buffer('eos_bbsz_idx')
                # only consider eos when it's among the top beam_size indices
                cand_bbsz_idx[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_bbsz_idx)
                if eos_bbsz_idx.numel() > 0:
                    eos_scores = buffer('eos_scores', type_of=scores)
                    cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_scores)
                    num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break

            # set active_mask so that values >= cand_size indicate eos hypos
            # and values < cand_size indicate active candidate hypos. The
            # smallest values per row then mark the top active candidates
            active_mask = buffer('active_mask')
            torch.add(eos_mask.type_as(cand_offsets)*cand_size, cand_offsets[:eos_mask.size(1)],
                      out=active_mask)
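            # illustrative example: with beam_size=2 (cand_size=4) and an
            # eos_mask row of [1, 0, 0, 1], the active_mask row becomes
            # [4, 1, 2, 7]; the beam_size smallest values (1 and 2) select
            # the surviving non-EOS candidates below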

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
            active_mask.topk(beam_size, 1, largest=False, out=(_ignore, active_hypos))
            active_bbsz_idx = buffer('active_bbsz_idx')
            cand_bbsz_idx.gather(1, active_hypos, out=active_bbsz_idx)
            active_scores = cand_scores.gather(1, active_hypos,
                                               out=scores.view(bsz, beam_size))

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # finalize all active hypotheses once we hit maxlen
            # finalize_hypos will take care of adding the EOS markers
            if step == maxlen:
                num_remaining_sent -= finalize_hypos(step, active_bbsz_idx, active_scores)
                assert num_remaining_sent == 0
                break

            # copy tokens for active hypotheses
            torch.index_select(tokens[:, :step+1], dim=0, index=active_bbsz_idx,
                               out=tokens_buf[:, :step+1])
            cand_indices.gather(1, active_hypos,
                                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])

            # copy attention for active hypotheses
            torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
                               out=attn_buf[:, :, :step+2])

            # swap buffers
            old_tokens = tokens
            tokens = tokens_buf
            tokens_buf = old_tokens
            old_attn = attn
            attn = attn_buf
            attn_buf = old_attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(bsz):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized

    def _decode(self, tokens, encoder_outs):
        # wrap in Variable
        tokens = Variable(tokens, volatile=True)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            decoder_out, attn = model.decoder(tokens, encoder_out)
            probs = F.softmax(decoder_out[:, -1, :], dim=1).data
            attn = attn[:, -1, :].data
            if avg_probs is None or avg_attn is None:
                avg_probs = probs
                avg_attn = attn
            else:
                avg_probs.add_(probs)
                avg_attn.add_(attn)
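        # ensemble: average the per-model probabilities (and attention), then
        # return to log-space for cumulative beam scoring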
        avg_probs.div_(len(self.models))
        avg_probs.log_()
        avg_attn.div_(len(self.models))

        return avg_probs, avg_attn
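

# Usage sketch (hypothetical): `load_models` and `make_batch_iterator` are
# placeholders, not part of this module; they stand in for however trained
# models and a batched data iterator are obtained in a given setup.
#
#     models = load_models()  # models sharing a single target dictionary
#     generator = SequenceGenerator(models, beam_size=5).cuda()
#     for id, src, ref, hypos in generator.generate_batched_itr(
#             make_batch_iterator(), maxlen_a=0.5, maxlen_b=10):
#         best = hypos[0]  # hypotheses come sorted best-first by score
#         print(id, best['score'], best['tokens'])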