import torch

from seq2seq.data.config import BOS
from seq2seq.data.config import EOS


class SequenceGenerator:
    """
    Generator for autoregressive inference with greedy and beam search
    decoding.
    """
    def __init__(self, model, beam_size=5, max_seq_len=100, cuda=False,
                 len_norm_factor=0.6, len_norm_const=5,
                 cov_penalty_factor=0.1):
        """
        Constructor for the SequenceGenerator.

        Beam search decoding supports coverage penalty and length
        normalization. For details, refer to Section 7 of the GNMT paper
        (https://arxiv.org/pdf/1609.08144.pdf).

        :param model: model which implements the generate method
        :param beam_size: decoder beam size
        :param max_seq_len: maximum decoder sequence length
        :param cuda: whether to use cuda
        :param len_norm_factor: length normalization factor
        :param len_norm_const: length normalization constant
        :param cov_penalty_factor: coverage penalty factor
        """

        self.model = model
        self.cuda = cuda
        self.beam_size = beam_size
        self.max_seq_len = max_seq_len
        self.len_norm_factor = len_norm_factor
        self.len_norm_const = len_norm_const
        self.cov_penalty_factor = cov_penalty_factor

        self.batch_first = self.model.batch_first
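
        # Finished hypotheses are rescored following Section 7 of the GNMT
        # paper: s(Y, X) = log P(Y | X) / lp(Y) + cp(X; Y), where
        #   lp(Y) = ((c + |Y|) / (c + 1)) ** a
        # with c = len_norm_const and a = len_norm_factor, and
        #   cp(X; Y) = cov_penalty_factor * sum_i log(min(sum_j p_ij, 1.0))
        # with p_ij the attention weight of the j-th target word on the
        # i-th source word; see the normalization block in beam_search.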

    def greedy_search(self, batch_size, initial_input, initial_context=None):
        """
        Greedy decoder.

        :param batch_size: decoder batch size
        :param initial_input: initial input, usually tensor of BOS tokens
        :param initial_context: initial context, usually [encoder_context,
            src_seq_lengths, None]

        returns: (translation, lengths, counter)
            translation: (batch_size, max_seq_len) - indices of target tokens
            lengths: (batch_size) - lengths of generated translations
            counter: number of iterations of the decoding loop
        """
        max_seq_len = self.max_seq_len

        translation = torch.zeros(batch_size, max_seq_len, dtype=torch.int64)
        lengths = torch.ones(batch_size, dtype=torch.int64)
        active = torch.arange(0, batch_size, dtype=torch.int64)
        base_mask = torch.arange(0, batch_size, dtype=torch.int64)

        if self.cuda:
            translation = translation.cuda()
            lengths = lengths.cuda()
            active = active.cuda()
            base_mask = base_mask.cuda()

        translation[:, 0] = BOS
        words, context = initial_input, initial_context

        if self.batch_first:
            word_view = (-1, 1)
            ctx_batch_dim = 0
        else:
            word_view = (1, -1)
            ctx_batch_dim = 1

        counter = 0
        for idx in range(1, max_seq_len):
            if not len(active):
                break
            counter += 1

            words = words.view(word_view)
            output = self.model.generate(words, context, 1)
            words, logprobs, attn, context = output
            words = words.view(-1)

            translation[active, idx] = words
            lengths[active] += 1

            terminating = (words == EOS)

            if terminating.any():
                not_terminating = ~terminating

                mask = base_mask[:len(active)]
                mask = mask.masked_select(not_terminating)
                active = active.masked_select(not_terminating)

                words = words[mask]
                context[0] = context[0].index_select(ctx_batch_dim, mask)
                context[1] = context[1].index_select(0, mask)
                context[2] = context[2].index_select(1, mask)

        return translation, lengths, counter

    def beam_search(self, batch_size, initial_input, initial_context=None):
        """
        Beam search decoder.

        :param batch_size: decoder batch size
        :param initial_input: initial input, usually tensor of BOS tokens
        :param initial_context: initial context, usually [encoder_context,
            src_seq_lengths, None]

        returns: (translation, lengths, counter)
            translation: (batch_size, max_seq_len) - indices of target tokens
            lengths: (batch_size) - lengths of generated translations
            counter: number of iterations of the decoding loop
        """
        beam_size = self.beam_size
        norm_const = self.len_norm_const
        norm_factor = self.len_norm_factor
        max_seq_len = self.max_seq_len
        cov_penalty_factor = self.cov_penalty_factor

        translation = torch.zeros(batch_size * beam_size, max_seq_len,
                                  dtype=torch.int64)
        lengths = torch.ones(batch_size * beam_size, dtype=torch.int64)
        scores = torch.zeros(batch_size * beam_size, dtype=torch.float32)

        active = torch.arange(0, batch_size * beam_size, dtype=torch.int64)
        base_mask = torch.arange(0, batch_size * beam_size, dtype=torch.int64)
        global_offset = torch.arange(0, batch_size * beam_size, beam_size,
                                     dtype=torch.int64)

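        # a hypothesis that already emitted EOS is extended with exactly one
        # candidate of additional logprob 0 (itself); its remaining
        # beam_size - 1 candidates are suppressed with -inf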
        eos_beam_fill = torch.tensor([0] + (beam_size - 1) * [float('-inf')])

        if self.cuda:
            translation = translation.cuda()
            lengths = lengths.cuda()
            active = active.cuda()
            base_mask = base_mask.cuda()
            scores = scores.cuda()
            global_offset = global_offset.cuda()
            eos_beam_fill = eos_beam_fill.cuda()

        translation[:, 0] = BOS

        words, context = initial_input, initial_context

        if self.batch_first:
            word_view = (-1, 1)
            ctx_batch_dim = 0
            attn_query_dim = 1
        else:
            word_view = (1, -1)
            ctx_batch_dim = 1
            attn_query_dim = 0

        # replicate the encoder context so that every source sentence
        # appears beam_size times, one row per beam hypothesis
        if self.batch_first:
            # context[0] (encoder state): (batch, seq, feature)
            _, seq, feature = context[0].shape
            context[0].unsqueeze_(1)
            context[0] = context[0].expand(-1, beam_size, -1, -1)
            context[0] = context[0].contiguous().view(batch_size * beam_size,
                                                      seq, feature)
            # context[0]: (batch * beam, seq, feature)
        else:
            # context[0] (encoder state): (seq, batch, feature)
            seq, _, feature = context[0].shape
            context[0].unsqueeze_(2)
            context[0] = context[0].expand(-1, -1, beam_size, -1)
            context[0] = context[0].contiguous().view(seq, batch_size *
                                                      beam_size, feature)
            # context[0]: (seq, batch * beam, feature)

        # context[1] (encoder seq length): (batch)
        context[1].unsqueeze_(1)
        context[1] = context[1].expand(-1, beam_size)
        context[1] = context[1].contiguous().view(batch_size * beam_size)
        # context[1]: (batch * beam)

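        # running sum of the attention weights per source position,
        # consumed by the coverage penalty when a beam terminates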
        accu_attn_scores = torch.zeros(batch_size * beam_size, seq)
        if self.cuda:
            accu_attn_scores = accu_attn_scores.cuda()

        counter = 0
        for idx in range(1, self.max_seq_len):
            if not len(active):
                break
            counter += 1

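            # eos_mask marks hypotheses that already emitted EOS; a beam
            # terminates only once all of its beam_size hypotheses have
            # finished, and only unfinished hypotheses grow in length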
            eos_mask = (words == EOS)
            eos_mask = eos_mask.view(-1, beam_size)

            terminating, _ = eos_mask.min(dim=1)

            lengths[active[~eos_mask.view(-1)]] += 1

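            # for every active hypothesis the model returns its beam_size
            # best next words, their log-probabilities, the attention
            # weights over the source and the updated decoder context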
            output = self.model.generate(words, context, beam_size)
            words, logprobs, attn, context = output

            attn = attn.float().squeeze(attn_query_dim)
            attn = attn.masked_fill(eos_mask.view(-1).unsqueeze(1), 0)
            accu_attn_scores[active] += attn

            # words: (batch, beam, k)
            words = words.view(-1, beam_size, beam_size)
            words = words.masked_fill(eos_mask.unsqueeze(2), EOS)

            # logprobs: (batch, beam, k)
            logprobs = logprobs.float().view(-1, beam_size, beam_size)

            if eos_mask.any():
                logprobs[eos_mask] = eos_beam_fill

            active_scores = scores[active].view(-1, beam_size)
            # new_scores: (batch, beam, k)
            new_scores = active_scores.unsqueeze(2) + logprobs

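            # at the first step all beams hold the same BOS hypothesis,
            # so keep only beam 0 per sentence to avoid duplicates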
            if idx == 1:
                new_scores[:, 1:, :].fill_(float('-inf'))

            new_scores = new_scores.view(-1, beam_size * beam_size)
            # index: (batch, beam) - positions among the beam * beam
            # candidates; index // beam_size recovers the source beam
            _, index = new_scores.topk(beam_size, dim=1)
            source_beam = index // beam_size

            best_scores = torch.gather(new_scores, 1, index)
            scores[active] = best_scores.view(-1)

            words = words.view(-1, beam_size * beam_size)
            words = torch.gather(words, 1, index)

            # words: (batch * beam, 1) if batch_first, else (1, batch * beam)
            words = words.view(word_view)

            offset = global_offset[:source_beam.shape[0]]
            source_beam += offset.unsqueeze(1)

            translation[active, :] = translation[active[source_beam.view(-1)], :]
            translation[active, idx] = words.view(-1)

            lengths[active] = lengths[active[source_beam.view(-1)]]

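            # reorder the decoder hidden state so every surviving hypothesis
            # carries the state of the beam it was expanded from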
            context[2] = context[2].index_select(1, source_beam.view(-1))

            if terminating.any():
                not_terminating = ~terminating
                not_terminating = not_terminating.unsqueeze(1)
                not_terminating = not_terminating.expand(-1, beam_size).contiguous()

                normalization_mask = active.view(-1, beam_size)[terminating]

                # length normalization
                norm = lengths[normalization_mask].float()
                norm = (norm_const + norm) / (norm_const + 1.0)
                norm = norm ** norm_factor

                scores[normalization_mask] /= norm

                # coverage penalty
                penalty = accu_attn_scores[normalization_mask]
                penalty = penalty.clamp(0, 1)
                penalty = penalty.log()
                penalty[penalty == float('-inf')] = 0
                penalty = penalty.sum(dim=-1)

                scores[normalization_mask] += cov_penalty_factor * penalty

                mask = base_mask[:len(active)]
                mask = mask.masked_select(not_terminating.view(-1))

                words = words.index_select(ctx_batch_dim, mask)
                context[0] = context[0].index_select(ctx_batch_dim, mask)
                context[1] = context[1].index_select(0, mask)
                context[2] = context[2].index_select(1, mask)

                active = active.masked_select(not_terminating.view(-1))

        scores = scores.view(batch_size, beam_size)
        _, idx = scores.max(dim=1)

        translation = translation[idx + global_offset, :]
        lengths = lengths[idx + global_offset]

        return translation, lengths, counter
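

if __name__ == '__main__':
    # Minimal smoke test with a stub model. The stub, its vocabulary size
    # and all tensor shapes below are illustrative assumptions, not part of
    # the seq2seq API: a real model only needs a batch_first attribute and
    # a generate(words, context, k) method with the signature used above.
    class StubModel:
        batch_first = True

        def __init__(self, vocab_size=32, hidden_size=8):
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size

        def generate(self, words, context, k):
            # for each active hypothesis return k random candidate words,
            # their log-probabilities, uniform attention over the source and
            # a fresh (layers, batch, hidden) decoder state; the random ids
            # may occasionally hit EOS, exercising the termination paths
            batch = words.shape[0]
            seq = context[0].shape[1]
            logprobs = torch.randn(batch, k).log_softmax(dim=-1)
            next_words = torch.randint(0, self.vocab_size, (batch, k))
            attn = torch.full((batch, 1, seq), 1.0 / seq)
            hidden = torch.zeros(1, batch, self.hidden_size)
            return next_words, logprobs, attn, [context[0], context[1], hidden]

    torch.manual_seed(0)
    batch_size, beam_size, src_len, feature = 2, 3, 7, 4
    generator = SequenceGenerator(StubModel(), beam_size=beam_size,
                                  max_seq_len=10)

    # greedy decoding: one row per source sentence
    context = [torch.randn(batch_size, src_len, feature),
               torch.full((batch_size,), src_len, dtype=torch.int64),
               None]
    bos = torch.full((batch_size, 1), BOS, dtype=torch.int64)
    out, lens, steps = generator.greedy_search(batch_size, bos, context)
    print('greedy:', out.shape, lens.tolist(), steps)

    # beam search: the initial input already has batch * beam rows, while
    # the encoder context is replicated inside beam_search itself
    context = [torch.randn(batch_size, src_len, feature),
               torch.full((batch_size,), src_len, dtype=torch.int64),
               None]
    bos = torch.full((batch_size * beam_size, 1), BOS, dtype=torch.int64)
    out, lens, steps = generator.beam_search(batch_size, bos, context)
    print('beam:', out.shape, lens.tolist(), steps)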