beam_search_penalty_kernels.cu 16.5 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>

lvhan028's avatar
lvhan028 committed
19
20
#include "src/turbomind/kernels/beam_search_penalty_kernels.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
Li Zhang's avatar
Li Zhang committed
21

lvhan028's avatar
lvhan028 committed
22
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312

/**
 * Adds an optional per-token bias to each logit and scales by 1 / temperature; entries in the padded
 * vocabulary tail [vocab_size, vocab_size_padded) are overwritten with the most negative finite value of T.
 *
 * Launch layout: blockIdx.y selects one (batch, beam) row of length vocab_size_padded; blockIdx.x and
 * threadIdx.x stride over the columns of that row, so any gridDim.x / blockDim.x covers the row.
 * No shared memory is used. `batch_size` and `beam_width` are not read by this kernel.
 *
 * @param logits            row-major logits, modified in place
 * @param bias              optional bias of at least vocab_size entries; nullptr means no bias
 * @param vocab_size        number of real vocabulary entries per row
 * @param vocab_size_padded row stride (and number of columns processed) of `logits`
 * @param temperature       divisor applied after the bias; 1e-6 is added to avoid dividing by zero
 */
template<typename T>
__global__ void add_bias_temperature(T*          logits,
                                     const T*    bias,
                                     const int   batch_size,
                                     const int   beam_width,
                                     const int   vocab_size,
                                     const int   vocab_size_padded,
                                     const float temperature)
{
    const int row_id = blockIdx.y;
    const int stride = blockDim.x * gridDim.x;
    T*        row    = logits + row_id * vocab_size_padded;

    // Lowest finite value of the storage type, used to mask out padded columns.
    const T mask_val = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
    const T scale    = static_cast<T>(1.0f / (temperature + 1e-6f));

    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < vocab_size_padded; col += stride) {
        if (col >= vocab_size) {
            row[col] = mask_val;
            continue;
        }
        const T shift = (bias != nullptr) ? bias[col] : static_cast<T>(0.0f);
        row[col]      = (row[col] + shift) * scale;
    }
}

/**
 * half2 specialization of add_bias_temperature: processes two adjacent half logits per iteration.
 *
 * Preconditions (asserted): vocab_size and vocab_size_padded are both even, so every row of
 * vocab_size_padded halves is exactly vocab_size_padded / 2 half2 values.
 *
 * Launch layout: blockIdx.y selects one (batch, beam) row; blockIdx.x and threadIdx.x stride over the
 * half2 columns of that row. No shared memory. `batch_size` and `beam_width` are not read.
 *
 * Columns below vocab_size / 2 get (logit + bias) * (1 / temperature); the padded tail is overwritten with
 * -HALF_FLT_MAX in both lanes.
 */
template<>
__global__ void add_bias_temperature(half2*       logits,
                                     const half2* bias,
                                     const int    batch_size,
                                     const int    beam_width,
                                     const int    vocab_size,
                                     const int    vocab_size_padded,
                                     const float  temperature)
{
    assert(vocab_size % 2 == 0);
    assert(vocab_size_padded % 2 == 0);

    const int tid  = threadIdx.x;
    const int bid  = blockIdx.x;
    const int bbid = blockIdx.y;

    const half2 mask_val = __float2half2_rn(-HALF_FLT_MAX);
    const half2 inv_temp = __float2half2_rn(1.0f / (temperature + 1e-6f));

    const int half_vocab_size        = vocab_size / 2;
    const int half_vocab_size_padded = vocab_size_padded / 2;

    logits += bbid * half_vocab_size_padded;
    // The loop bound already keeps `index` inside the row, so it is used directly as the column index.
    for (int index = tid + bid * blockDim.x; index < half_vocab_size_padded; index += blockDim.x * gridDim.x) {
        if (index >= half_vocab_size) {
            // Padded tail: no need to read the old value.
            logits[index] = mask_val;
            continue;
        }
        // `logits` is written by this kernel, so it must not be loaded through the read-only (__ldg) cache;
        // `bias` is never written here and may use it.
        half2 logit = logits[index];
        if (bias != nullptr) {
            logit = __hadd2(logit, __ldg(&bias[index]));
        }
        logits[index] = __hmul2(logit, inv_temp);
    }
}

/**
 * Applies a repetition penalty to every token already emitted along one beam's ancestry.
 *
 * Launch layout: one block per (batch, beam) row (blockIdx.x), any blockDim.x. Dynamic shared memory must
 * hold `step` values of T (rounded up to 32 bytes) followed by `step` ints.
 *
 * Thread 0 walks the ancestry backwards:
 * - position step-1 comes from current_ids;
 * - earlier positions follow parent_ids and read previous_ids.
 * Positions in [input_length, max_input_length) are treated as padding and skipped.
 *
 * Every penalized value is computed from the logits before any write (the writes happen only after
 * __syncthreads), so a token that occurs several times is written with the same value and is penalized once.
 *
 * IS_ADDITIVE == true : logit - penalty
 * IS_ADDITIVE == false: logit / penalty if logit > 0, otherwise logit * penalty
 *
 * `vocab_size` is not read by this kernel.
 */
template<typename T, bool IS_ADDITIVE>
__global__ void apply_repetition_penalty(T*          logits,
                                         const int   batch_size,
                                         const int   beam_width,
                                         const int   vocab_size,
                                         const int   vocab_size_padded,
                                         const int   step,
                                         const int*  current_ids,
                                         const int*  previous_ids,
                                         const int*  parent_ids,
                                         const int*  input_lengths,
                                         const int   max_input_length,
                                         const float repetition_penalty)
{
    assert(step > 0);

    const int beam_slot   = blockIdx.x;                          // flattened (batch, beam) row
    const int batch_base  = (beam_slot / beam_width) * beam_width;  // first row of this batch entry
    const int step_stride = batch_size * beam_width;              // distance between consecutive steps

    T* row = logits + beam_slot * vocab_size_padded;

    extern __shared__ char sbuf[];
    T* penalized = reinterpret_cast<T*>(sbuf);
    // Round the T[] region up to 32 bytes so the int[] region stays aligned when sizeof(T) == 2.
    int* token_ids = reinterpret_cast<int*>(sbuf + (sizeof(T) * step + 31) / 32 * 32);

    const int seq_input_len = (input_lengths == nullptr) ? max_input_length : input_lengths[beam_slot];
    auto is_padding = [&](int pos) { return pos >= seq_input_len && pos < max_input_length; };

    if (threadIdx.x == 0) {
        const T penalty = static_cast<T>(repetition_penalty);

        // Remember which token sits at `pos` and what its logit becomes after the penalty.
        auto record = [&](int pos, int token) {
            const T logit  = row[token];
            token_ids[pos] = token;
            if (IS_ADDITIVE) {
                penalized[pos] = logit - penalty;
            }
            else {
                penalized[pos] = logit > T(0) ? logit / penalty : logit * penalty;
            }
        };

        record(step - 1, current_ids[beam_slot]);

        int beam = beam_slot % beam_width;
        for (int pos = step - 2; pos >= 0; --pos) {
            if (is_padding(pos)) {
                continue;
            }
            beam = parent_ids[pos * step_stride + batch_base + beam];
            record(pos, previous_ids[pos * step_stride + batch_base + beam]);
        }
    }
    __syncthreads();

    for (int pos = threadIdx.x; pos < step; pos += blockDim.x) {
        if (!is_padding(pos)) {
            row[token_ids[pos]] = penalized[pos];
        }
    }
}

/**
 * Masks the end id of every (batch, beam) row whose generated length is still below `min_length`.
 *
 * Launch layout: 1-D, one thread per (batch, beam) row. The grid may be rounded up; threads whose index is
 * >= batch_beam_size exit immediately. No shared memory.
 *
 * @param logits           row-major [batch_beam_size, vocab_size_padded], modified in place
 * @param end_ids          one end id per batch entry (indexed by row / beam_width)
 * @param sequence_lengths one entry per (batch, beam) row
 * @param batch_beam_size  number of valid rows; also the bound for the thread index
 */
template<typename T>
__global__ void apply_min_length_penalty(T*         logits,
                                         const int  min_length,
                                         const int* end_ids,
                                         const int* sequence_lengths,
                                         const int  max_input_length,
                                         const int  beam_width,
                                         const int  vocab_size_padded,
                                         const int  batch_beam_size)
{
    const int bbid = threadIdx.x + blockIdx.x * blockDim.x;  // batch-beam index
    if (bbid >= batch_beam_size) {
        return;  // tail thread of a rounded-up grid
    }
    const int bid = bbid / beam_width;  // batch index
    // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
    // which is equal to the length of k/v caches.
    if (sequence_lengths[bbid] + 1 - max_input_length < min_length) {
        const T mask_val = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
        logits[bbid * vocab_size_padded + end_ids[bid]] = mask_val;
    }
}

/**
 * Enqueues, on `stream`, the logit transformations applied before beam search for the local batch:
 *   1. bias addition and temperature scaling. This stage also runs when the vocabulary is padded, so that
 *      the padded tail gets masked.
 *   2. repetition penalty (multiplicative or additive) over the tokens generated so far. It runs only when
 *      step > 0 and the penalty differs from the type's default value.
 *   3. min-length penalty, which masks the end id while a sequence is shorter than `min_length`.
 * The function does not synchronize the stream.
 *
 * `logits` is treated as [local_batch_size * beam_width, vocab_size_padded] and modified in place.
 * `ite` offsets `parent_ids` by ite * beam_width * local_batch_size (see the TODO below).
 * An empty local batch is a no-op.
 */
template<typename T>
void invokeAddBiasApplyPenalties(int                         step,
                                 T*                          logits,
                                 const int*                  current_ids,
                                 const int*                  previous_ids,
                                 const int*                  parent_ids,
                                 const int*                  input_lengths,
                                 const int*                  sequence_lengths,
                                 const T*                    bias,
                                 const int                   ite,
                                 const int                   max_input_length,
                                 const int                   local_batch_size,
                                 const int                   batch_size,
                                 const int                   beam_width,
                                 const int                   vocab_size,
                                 const int                   vocab_size_padded,
                                 const int*                  end_ids,
                                 const float                 temperature,
                                 const float                 repetition_penalty,
                                 const RepetitionPenaltyType repetition_penalty_type,
                                 const int                   min_length,
                                 cudaStream_t                stream)
{
    const int batch_beam_size = local_batch_size * beam_width;
    if (batch_beam_size <= 0) {
        // Nothing to update. Returning here also avoids zero-sized kernel launches and a host-side
        // division by zero when sizing the min-length grid below.
        return;
    }

    // 1. Bias + temperature (also needed to mask the padded vocabulary tail).
    if (bias != nullptr || temperature != 1.0f || vocab_size != vocab_size_padded) {
        dim3 block(512);
        if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padded % 2 == 0) {
            // Vectorized path: two half values per thread.
            dim3 grid((vocab_size_padded / 2 + block.x - 1) / block.x, batch_beam_size);
            add_bias_temperature<<<grid, block, 0, stream>>>(reinterpret_cast<half2*>(logits),
                                                             reinterpret_cast<const half2*>(bias),
                                                             batch_size,
                                                             beam_width,
                                                             vocab_size,
                                                             vocab_size_padded,
                                                             temperature);
        }
        else {
            dim3 grid((vocab_size_padded + block.x - 1) / block.x, batch_beam_size);
            add_bias_temperature<<<grid, block, 0, stream>>>(
                logits, bias, batch_size, beam_width, vocab_size, vocab_size_padded, temperature);
        }
    }

    // 2. Repetition penalty.
    if (repetition_penalty_type != RepetitionPenaltyType::None && step > 0) {
        if (repetition_penalty != getDefaultPenaltyValue(repetition_penalty_type)) {
            // Penalized values (T[], padded to 32 bytes) followed by token ids (int[]), one slot per step.
            // NOTE(review): very long sequences can push this past the default 48 KB dynamic shared-memory
            // limit, which needs a cudaFuncSetAttribute opt-in to launch successfully.
            size_t smem_size = (sizeof(T) * step + 31) / 32 * 32 + sizeof(int) * step;
            dim3   block(256);
            dim3   grid(batch_beam_size);
            if (repetition_penalty_type == RepetitionPenaltyType::Multiplicative) {
                apply_repetition_penalty<T, false>
                    <<<grid, block, smem_size, stream>>>(logits,
                                                         batch_size,
                                                         beam_width,
                                                         vocab_size,
                                                         vocab_size_padded,
                                                         step,
                                                         current_ids,
                                                         previous_ids,
                                                         // TODO(jaedeokk):
                                                         //   Remove (+ite ...) by getting parent_ids with offset
                                                         //   and then remove 'ite' argument from the function.
                                                         parent_ids + ite * batch_beam_size,
                                                         input_lengths,
                                                         max_input_length,
                                                         repetition_penalty);
            }
            else if (repetition_penalty_type == RepetitionPenaltyType::Additive) {
                apply_repetition_penalty<T, true>
                    <<<grid, block, smem_size, stream>>>(logits,
                                                         batch_size,
                                                         beam_width,
                                                         vocab_size,
                                                         vocab_size_padded,
                                                         step,
                                                         current_ids,
                                                         previous_ids,
                                                         parent_ids + ite * batch_beam_size,
                                                         input_lengths,
                                                         max_input_length,
                                                         repetition_penalty);
            }
        }
    }

    // 3. Min-length penalty.
    if (step - max_input_length < min_length) {
        FT_CHECK_WITH_INFO(sequence_lengths != nullptr, "Need sequence_lengths to apply min length penalty");
        FT_CHECK_WITH_INFO(end_ids != nullptr, "Need end_id to apply min length penalty");

        // One thread per row. The grid is rounded up to whole blocks, so the kernel bounds-checks
        // against batch_beam_size.
        const int block_size = min(batch_beam_size, 1024);
        const int grid_size  = (batch_beam_size + block_size - 1) / block_size;
        apply_min_length_penalty<<<grid_size, block_size, 0, stream>>>(logits,
                                                                       min_length,
                                                                       end_ids,
                                                                       sequence_lengths,
                                                                       max_input_length,
                                                                       beam_width,
                                                                       vocab_size_padded,
                                                                       batch_beam_size);
    }
}

// Explicit instantiations of invokeAddBiasApplyPenalties for float and half logits.
#define INSTANTIATE_INVOKE_ADD_BIAS_APPLY_PENALTIES(T)                                                                 \
    template void invokeAddBiasApplyPenalties<T>(int                         step,                                    \
                                                 T*                          logits,                                  \
                                                 const int*                  current_ids,                             \
                                                 const int*                  previous_ids,                            \
                                                 const int*                  parent_ids,                              \
                                                 const int*                  input_lengths,                           \
                                                 const int*                  sequence_lengths,                        \
                                                 const T*                    bias,                                    \
                                                 const int                   ite,                                     \
                                                 const int                   max_input_length,                        \
                                                 const int                   local_batch_size,                        \
                                                 const int                   batch_size,                              \
                                                 const int                   beam_width,                              \
                                                 const int                   vocab_size,                              \
                                                 const int                   vocab_size_padded,                       \
                                                 const int*                  end_ids,                                 \
                                                 const float                 temperature,                             \
                                                 const float                 repetition_penalty,                      \
                                                 const RepetitionPenaltyType repetition_penalty_type,                 \
                                                 const int                   min_length,                              \
                                                 cudaStream_t                stream)

INSTANTIATE_INVOKE_ADD_BIAS_APPLY_PENALTIES(float);
INSTANTIATE_INVOKE_ADD_BIAS_APPLY_PENALTIES(half);

#undef INSTANTIATE_INVOKE_ADD_BIAS_APPLY_PENALTIES

lvhan028's avatar
lvhan028 committed
313
}  // namespace turbomind