sampling_penalty_kernels.cu 27.3 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>
#include <float.h>

#include "src/fastertransformer/kernels/sampling_penalty_kernels.h"

namespace fastertransformer {

// Generic (one element per thread) temperature penalty kernel. A half2
// specialization of this template is defined right after it.
//
// logits is indexed as a row-major [m, vocab_size_padd] buffer. For every valid
// vocab position (vocab_idx < vocab_size) the kernel computes
//     logits = (logits + bias) * temperature_inverse
// and every padded position (vocab_idx >= vocab_size) is overwritten with the
// most negative finite value of T.
// bias is optional (nullptr => no bias) and is indexed by vocab position only.
// The loop advances by blockDim.x * gridDim.x, so any 1-D launch shape covers the
// whole buffer.
template<typename T>
__global__ void applyTemperaturePenalty(T*          logits,
                                        const T*    bias,
                                        const float temperature_inverse,
                                        const int   m,
                                        const int   vocab_size,
                                        const int   vocab_size_padd)
{
    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < m * vocab_size_padd;
         index += blockDim.x * gridDim.x) {
        const int vocab_idx = index % vocab_size_padd;
        if (vocab_idx < vocab_size) {
            // Only read bias for valid vocab entries: padded slots are masked below,
            // and the bias buffer is not guaranteed to cover the padding.
            const T bias_val = bias == nullptr ? (T)(0.0f) : bias[vocab_idx];
            logits[index]    = (logits[index] + bias_val) * (T)temperature_inverse;
        }
        else {
            logits[index] = -MAX_T_VAL;
        }
    }
}

// half2 specialization: each thread processes two adjacent vocab entries.
// Requires vocab_size and vocab_size_padded to be even (asserted below).
// Valid entries become (logit + bias) * temperature_inverse. Padded entries are
// set to -65504 (the most negative finite half), matching the generic kernel.
template<>
__global__ void applyTemperaturePenalty(half2*       logits,
                                        const half2* bias,
                                        const float  temperature_inverse,
                                        const int    batch_size,
                                        const int    vocab_size,
                                        const int    vocab_size_padded)
{
    assert(vocab_size % 2 == 0);
    assert(vocab_size_padded % 2 == 0);
    const half2 mask_val = __float2half2_rn(-65504.0f);
    const half2 temp_inv = __float2half2_rn(temperature_inverse);

    const int half_vocab_size        = vocab_size / 2;
    const int half_vocab_size_padded = vocab_size_padded / 2;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded;
         index += blockDim.x * gridDim.x) {
        const int vocab_idx = index % half_vocab_size_padded;
        half2     logit     = mask_val;
        if (vocab_idx < half_vocab_size) {
            // Plain load instead of __ldg: this kernel writes logits, so the buffer is
            // not read-only for the kernel's lifetime.
            logit = logits[index];
            if (bias != nullptr) {
                logit = __hadd2(logit, bias[vocab_idx]);
            }
            logit = __hmul2(logit, temp_inv);
        }
        // Store unconditionally so that padded slots actually receive mask_val.
        logits[index] = logit;
    }
}

// Host wrapper: applies an optional bias and one batch-wide temperature to
// logits[batch_size, vocab_size_padd] on `stream`, and masks padded vocab slots.
// Uses the vectorized half2 kernel when T is half and both vocab sizes are even;
// otherwise it uses the generic kernel. The 1e-6f term guards against division
// by zero for temperature == 0.
// Empty inputs are a no-op; launching them would use a zero-sized block/grid and
// divide by block.x == 0 on the host.
template<typename T>
void invokeApplyTemperaturePenalty(T*           logits,
                                   const T*     bias,
                                   const float  temperature,
                                   const int    batch_size,
                                   const int    vocab_size,
                                   const int    vocab_size_padd,
                                   cudaStream_t stream)
{
    if (batch_size <= 0 || vocab_size_padd <= 0) {
        return;
    }
    dim3 block(min(vocab_size_padd, 1024));
    dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536));
    // Both kernels take the inverse temperature as float, so keep it in float
    // rather than round-tripping it through T.
    const float temperature_inverse = 1.f / (temperature + 1e-6f);
    if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) {
        applyTemperaturePenalty<<<grid, block, 0, stream>>>(reinterpret_cast<half2*>(logits),
                                                            reinterpret_cast<const half2*>(bias),
                                                            temperature_inverse,
                                                            batch_size,
                                                            vocab_size,
                                                            vocab_size_padd);
    }
    else {
        applyTemperaturePenalty<T>
            <<<grid, block, 0, stream>>>(logits, bias, temperature_inverse, batch_size, vocab_size, vocab_size_padd);
    }
}

template void invokeApplyTemperaturePenalty(float*       logits,
                                            const float* bias,
                                            const float  temperature,
                                            const int    batch_size,
                                            const int    vocab_size,
                                            const int    vocab_size_padd,
                                            cudaStream_t stream);

template void invokeApplyTemperaturePenalty(half*        logits,
                                            const half*  bias,
                                            const float  temperature,
                                            const int    batch_size,
                                            const int    vocab_size,
                                            const int    vocab_size_padd,
                                            cudaStream_t stream);

// Per-sample temperature penalty (generic element-wise version).
// logits is indexed as a row-major [batch_size, vocab_size_padd] buffer and
// temperatures[b] is the temperature of row b. Valid entries become
// (logit + bias) / (temperature + 1e-6); padded entries are set to -MAX_T_VAL.
// Dynamic shared memory: batch_size floats (the inverse temperatures).
template<typename T>
__global__ void batchApplyTemperaturePenalty(T*           logits,
                                             const T*     bias,
                                             const float* temperatures,
                                             const int    batch_size,
                                             const int    vocab_size,
                                             const int    vocab_size_padd)
{
    // TODO: Add macro or device function to get MAX_T_VAL.
    const bool              IS_FP16   = std::is_same<T, half>::value;
    const T                 MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX;
    extern __shared__ float inv_temperatures[];
    // Loop by blockDim.x so every entry is initialized even when
    // batch_size > blockDim.x (otherwise rows beyond blockDim.x would read
    // uninitialized shared memory).
    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
        inv_temperatures[i] = 1.0f / (temperatures[i] + 1e-6f);
    }
    __syncthreads();

    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * vocab_size_padd;
         index += blockDim.x * gridDim.x) {
        const int batch_idx = index / vocab_size_padd;
        const int vocab_idx = index % vocab_size_padd;
        T         logit     = (vocab_idx < vocab_size) ? logits[index] : -MAX_T_VAL;
        if (vocab_idx < vocab_size) {
            if (bias != nullptr) {
                logit += bias[vocab_idx];
            }
            logit *= inv_temperatures[batch_idx];
        }
        logits[index] = logit;
    }
}

// Per-sample temperature penalty, vectorized over half2 (two vocab entries per
// thread). Requires even vocab_size and vocab_size_padded (asserted).
// Dynamic shared memory: batch_size half2 values (the inverse temperatures).
// Valid entries become (logit + bias) * inv_temperature; padded entries are
// set to -65504, matching the generic batch kernel.
__global__ void batchApplyTemperaturePenalty_h2(half2*       logits,
                                                const half2* bias,
                                                const float* temperatures,
                                                const int    batch_size,
                                                const int    vocab_size,
                                                const int    vocab_size_padded)
{
    assert(vocab_size % 2 == 0);
    assert(vocab_size_padded % 2 == 0);
    extern __shared__ half2 h2_inv_temperatures[];
    // Loop by blockDim.x so all batch_size entries are initialized even if
    // batch_size exceeds the block size.
    for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
        h2_inv_temperatures[i] = __float2half2_rn(1.f / (temperatures[i] + 1e-6f));
    }
    __syncthreads();

    const half2 mask_val               = __float2half2_rn(-65504.0f);
    const int   half_vocab_size        = vocab_size / 2;
    const int   half_vocab_size_padded = vocab_size_padded / 2;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded;
         index += blockDim.x * gridDim.x) {
        const int batch_idx = index / half_vocab_size_padded;
        const int vocab_idx = index % half_vocab_size_padded;
        half2     logit     = mask_val;
        if (vocab_idx < half_vocab_size) {
            // Plain load: logits is written by this kernel, so __ldg (read-only
            // cache) is not appropriate.
            logit = logits[index];
            if (bias != nullptr) {
                logit = __hadd2(logit, bias[vocab_idx]);
            }
            logit = __hmul2(logit, h2_inv_temperatures[batch_idx]);
        }
        // Store unconditionally so that padded slots are masked.
        logits[index] = logit;
    }
}

// Host wrapper: applies an optional bias and a per-sample temperature
// (temperatures[batch_size], device memory) to logits[batch_size, vocab_size_padd]
// on `stream`, and masks padded vocab slots. Uses the half2 kernel when T is
// half and both vocab sizes are even; otherwise it uses the generic kernel.
// Empty inputs are a no-op; launching them would use a zero-sized block/grid and
// divide by block.x == 0 on the host.
template<typename T>
void invokeBatchApplyTemperaturePenalty(T*           logits,
                                        const T*     bias,
                                        const float* temperatures,
                                        const int    batch_size,
                                        const int    vocab_size,
                                        const int    vocab_size_padd,
                                        cudaStream_t stream)
{
    if (batch_size <= 0 || vocab_size_padd <= 0) {
        return;
    }
    dim3 block(min(vocab_size_padd, 1024));
    dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536));
    if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) {
        // One half2 inverse temperature per batch entry in dynamic shared memory.
        size_t smem_size = sizeof(half2) * batch_size;
        batchApplyTemperaturePenalty_h2<<<grid, block, smem_size, stream>>>(reinterpret_cast<half2*>(logits),
                                                                            reinterpret_cast<const half2*>(bias),
                                                                            temperatures,
                                                                            batch_size,
                                                                            vocab_size,
                                                                            vocab_size_padd);
    }
    else {
        // One float inverse temperature per batch entry in dynamic shared memory.
        size_t smem_size = sizeof(float) * batch_size;
        batchApplyTemperaturePenalty<T>
            <<<grid, block, smem_size, stream>>>(logits, bias, temperatures, batch_size, vocab_size, vocab_size_padd);
    }
}

template void invokeBatchApplyTemperaturePenalty(float*       logits,
                                                 const float* bias,
                                                 const float* temperatures,
                                                 const int    batch_size,
                                                 const int    vocab_size,
                                                 const int    vocab_size_padd,
                                                 cudaStream_t stream);

template void invokeBatchApplyTemperaturePenalty(half*        logits,
                                                 const half*  bias,
                                                 const float* temperatures,
                                                 const int    batch_size,
                                                 const int    vocab_size,
                                                 const int    vocab_size_padd,
                                                 cudaStream_t stream);

// Applies a repetition penalty to every token already present in output_ids.
// One block per batch row: blockIdx.x selects the row of logits (row stride
// vocab_size_padd) and the column of output_ids (indexed
// output_ids[t * batch_size + blockIdx.x] for t in [0, step)).
// Positions t in [input_length, max_input_len) are input padding and are skipped.
// Token ids outside [0, vocab_size) are ignored.
// start_ids and local_batch_size are not used by this kernel.
//
// Dynamic shared memory: step floats (penalized values) followed by step ints
// (target vocab indices), i.e. step * (sizeof(float) + sizeof(int)) bytes.
//
// Phase 1 reads the original logits and computes penalized values. Phase 2 writes
// them back after a block barrier, so a token occurring several times is
// penalized exactly once.
template<typename T, RepetitionPenaltyType penalty_type>
__global__ void applyRepetitionPenalty(T*          logits,
                                       const float penalty,
                                       const int*  start_ids,
                                       int*        output_ids,
                                       const int   batch_size,
                                       const int   local_batch_size,
                                       const int   vocab_size,
                                       const int   vocab_size_padd,
                                       const int*  input_lengths,
                                       const int   max_input_len,
                                       const int   step)
{
    extern __shared__ float penalty_logits[];
    int*                    penalty_indices = (int*)(penalty_logits + step);

    logits                 = logits + blockIdx.x * vocab_size_padd;
    const int input_length = input_lengths != nullptr ? input_lengths[blockIdx.x] : max_input_len;

    // Phase 1. Every slot of penalty_indices is written; -1 marks "nothing to write
    // back", so phase 2 never reads uninitialized shared memory.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        int        penalty_index    = -1;
        const bool is_input_padding = index >= input_length && index < max_input_len;
        if (!is_input_padding) {
            // output_ids shape: (input_len + output_len, batch_size)
            const int token_id = output_ids[index * batch_size + blockIdx.x];
            if (token_id >= 0 && token_id < vocab_size) {
                penalty_index = token_id;
                float logit   = (float)logits[penalty_index];
                if (penalty_type == RepetitionPenaltyType::Additive) {
                    penalty_logits[index] = logit - penalty;
                }
                else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
                    penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty;
                }
                else if (penalty_type == RepetitionPenaltyType::None) {
                    penalty_logits[index] = logit;
                }
                else {
                    // Unsupported type
                    assert(false);
                }
            }
        }
        penalty_indices[index] = penalty_index;
    }

    // All reads of the original logits must finish before any write-back. The
    // barrier is unconditional: a single-warp block is not implicitly synchronized
    // under independent thread scheduling (Volta+).
    __syncthreads();

    // Phase 2. Write the penalized values back.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        const int penalty_index = penalty_indices[index];
        if (penalty_index >= 0) {
            logits[penalty_index] = penalty_logits[index];
        }
    }
}

// Host wrapper for applyRepetitionPenalty: launches one block per local batch
// row on `stream`, with min(step, 1024) threads and
// step * (sizeof(float) + sizeof(int)) bytes of dynamic shared memory.
// RepetitionPenaltyType::None launches nothing. Empty inputs (step <= 0 or
// local_batch_size <= 0) are a no-op; launching them would be an invalid
// zero-sized launch configuration.
// NOTE(review): the dynamic shared memory grows as 8 * step bytes; above the
// default 48 KB per-block limit (step > 6144) the launch fails unless the kernel
// opts in via cudaFuncSetAttribute — confirm the expected range of step.
template<typename T>
void invokeApplyRepetitionPenalty(T*                          logits,
                                  const float                 penalty,
                                  const int*                  start_ids,
                                  int*                        output_ids,
                                  const int                   batch_size,
                                  const int                   local_batch_size,
                                  const int                   vocab_size,
                                  const int                   vocab_size_padd,
                                  const int*                  input_lengths,
                                  const int                   max_input_len,
                                  const int                   step,
                                  const RepetitionPenaltyType penalty_type,
                                  cudaStream_t                stream)
{
    if (step <= 0 || local_batch_size <= 0) {
        return;
    }
    dim3   block(min(step, 1024));
    dim3   grid(local_batch_size);
    size_t smem_size = step * (sizeof(float) + sizeof(int));

    if (penalty_type == RepetitionPenaltyType::Additive) {
        applyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(logits,
                                                                                                       penalty,
                                                                                                       start_ids,
                                                                                                       output_ids,
                                                                                                       batch_size,
                                                                                                       local_batch_size,
                                                                                                       vocab_size,
                                                                                                       vocab_size_padd,
                                                                                                       input_lengths,
                                                                                                       max_input_len,
                                                                                                       step);
    }
    else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
        applyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>
            <<<grid, block, smem_size, stream>>>(logits,
                                                 penalty,
                                                 start_ids,
                                                 output_ids,
                                                 batch_size,
                                                 local_batch_size,
                                                 vocab_size,
                                                 vocab_size_padd,
                                                 input_lengths,
                                                 max_input_len,
                                                 step);
    }
    else if (penalty_type == RepetitionPenaltyType::None) {
        // do nothing
    }
}

template void invokeApplyRepetitionPenalty(float*                      logits,
                                           const float                 penalty,
                                           const int*                  start_ids,
                                           int*                        output_ids,
                                           const int                   batch_size,
                                           const int                   local_batch_size,
                                           const int                   vocab_size,
                                           const int                   vocab_size_padd,
                                           const int*                  input_lengths,
                                           const int                   max_input_len,
                                           const int                   step,
                                           const RepetitionPenaltyType penalty_type,
                                           cudaStream_t                stream);

template void invokeApplyRepetitionPenalty(half*                       logits,
                                           const float                 penalty,
                                           const int*                  start_ids,
                                           int*                        output_ids,
                                           const int                   batch_size,
                                           const int                   local_batch_size,
                                           const int                   vocab_size,
                                           const int                   vocab_size_padd,
                                           const int*                  input_lengths,
                                           const int                   max_input_len,
                                           const int                   step,
                                           const RepetitionPenaltyType penalty_type,
                                           cudaStream_t                stream);

// Per-sample repetition penalty. One block per batch row: blockIdx.x selects
// penalties[b], the row of logits (row stride vocab_size) and the column of
// output_ids (indexed output_ids[t * batch_size + b] for t in [0, step)).
// Positions t in [input_length, max_input_length) are input padding and are
// skipped.
//
// Dynamic shared memory: step floats (penalized values) followed by step ints
// (target vocab indices).
template<typename T, RepetitionPenaltyType penalty_type>
__global__ void batchApplyRepetitionPenalty(T*           logits,
                                            const float* penalties,
                                            const int*   output_ids,
                                            const int    batch_size,
                                            const int    vocab_size,
                                            const int*   input_lengths,
                                            const int    max_input_length,
                                            const int    step)
{
    extern __shared__ float penalty_logits[];
    int*                    penalty_indices = (int*)(penalty_logits + step);
    const int               batch_idx       = blockIdx.x;
    const float             penalty         = penalties[batch_idx];
    const int               input_length    = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;

    logits += batch_idx * vocab_size;

    // Phase 1. Find indices to penalize and keep the penalized values.
    // A vocab id can appear multiple times but should be penalized once.
    // Every slot of penalty_indices is written; -1 means "nothing to write back".
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        int penalty_index = -1;
        // Skip the padding tokens in input sequences.
        if (!(index >= input_length && index < max_input_length)) {
            // output_ids shape: (input_len + output_len, batch_size)
            const int token_id = output_ids[index * batch_size + batch_idx];
            assert(token_id < vocab_size);
            // The assert disappears under NDEBUG; keep a runtime range check so a
            // bad id can never cause an out-of-bounds access.
            if (token_id >= 0 && token_id < vocab_size) {
                penalty_index = token_id;
                float logit   = (float)logits[penalty_index];
                if (penalty_type == RepetitionPenaltyType::Additive) {
                    penalty_logits[index] = logit - penalty;
                }
                else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
                    penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty;
                }
                else if (penalty_type == RepetitionPenaltyType::None) {
                    penalty_logits[index] = logit;
                }
                else {
                    // Unsupported type
                    assert(false);
                }
            }
        }
        penalty_indices[index] = penalty_index;
    }

    // Every phase-1 read of the original logits must complete before any phase-2
    // write. This holds for single-warp blocks too (no implicit warp synchrony on
    // Volta+), so the barrier is unconditional.
    __syncthreads();

    // Phase 2. Replace a logit value by the penalized one.
    for (int index = threadIdx.x; index < step; index += blockDim.x) {
        const int penalty_index = penalty_indices[index];
        if (penalty_index >= 0) {
            logits[penalty_index] = penalty_logits[index];
        }
    }
}

template<typename T>
void invokeBatchApplyRepetitionPenalty(T*                    logits,
                                       const float*          penalties,
                                       const int*            output_ids,
                                       const int             batch_size,
                                       const int             local_batch_size,
                                       const int             vocab_size,
                                       const int*            input_lengths,
                                       const int             max_input_length,
                                       const int             step,
                                       RepetitionPenaltyType penalty_type,
                                       cudaStream_t          stream)
{
    // Inputs
    //   logits [local_batch_size, vocab_size] : logit values.
    //   penalties [local_batch_size] : repetition penalty factors.
    //   output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size).
    //   input_lengths [local_batch_size], input lengths (optional).
    //      Padding tokens at [input_length, max_input_length) of input will not be penalized.
    //
    // Launches one block per local batch row with min(step, 1024) threads and
    // step * (sizeof(float) + sizeof(int)) bytes of dynamic shared memory.
    // Empty inputs are a no-op (a zero-sized launch configuration is invalid).
    // NOTE(review): above the default 48 KB dynamic shared memory limit
    // (step > 6144) the launch fails without a cudaFuncSetAttribute opt-in.
    if (step <= 0 || local_batch_size <= 0) {
        return;
    }
    dim3   block(min(step, 1024));
    dim3   grid(local_batch_size);
    size_t smem_size = step * (sizeof(float) + sizeof(int));
    if (penalty_type == RepetitionPenaltyType::Additive) {
        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
            logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
    }
    else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
        batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
            logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
    }
    else if (penalty_type == RepetitionPenaltyType::None) {
        // do nothing
    }
}

template void invokeBatchApplyRepetitionPenalty(float*                logits,
                                                const float*          penalties,
                                                const int*            output_ids,
                                                const int             batch_size,
                                                const int             local_batch_size,
                                                const int             vocab_size,
                                                const int*            input_lengths,
                                                const int             max_input_length,
                                                const int             step,
                                                RepetitionPenaltyType penalty_type,
                                                cudaStream_t          stream);

template void invokeBatchApplyRepetitionPenalty(half*                 logits,
                                                const float*          penalties,
                                                const int*            output_ids,
                                                const int             batch_size,
                                                const int             local_batch_size,
                                                const int             vocab_size,
                                                const int*            input_lengths,
                                                const int             max_input_length,
                                                const int             step,
                                                RepetitionPenaltyType penalty_type,
                                                cudaStream_t          stream);

// Masks the end token of every batch entry that has not yet generated
// min_lengths[b] tokens, by setting logits[b * vocab_size_padded + end_ids[b]] to
// the most negative finite value of T. One thread per batch entry; threads past
// batch_size (the tail of the last block) do nothing.
template<typename T>
__global__ void batchApplyMinLengthPenalty(T*         logits,
                                           const int* min_lengths,
                                           const int* end_ids,
                                           const int* sequence_lengths,
                                           const int  max_input_length,
                                           const int  vocab_size_padded,
                                           const int  batch_size)
{
    const int bid = threadIdx.x + blockIdx.x * blockDim.x;  // batch index
    // The grid is rounded up to whole blocks, so guard against reading/writing
    // past the last batch entry.
    if (bid >= batch_size) {
        return;
    }
    // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
    // which is equal to the length of k/v caches.
    if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) {
        T mask_val                                     = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
        logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
    }
}

// Host wrapper: launches batchApplyMinLengthPenalty on `stream` with one thread
// per batch entry (blocks of up to 1024 threads). An empty batch is a no-op,
// which also avoids dividing by a zero block size.
template<typename T>
void invokeMinLengthPenalty(T*           logits,
                            const int*   min_lengths,
                            const int*   end_ids,
                            const int*   sequence_lengths,
                            const int    max_input_length,
                            const int    batch_size,
                            const int    vocab_size_padded,
                            cudaStream_t stream)

{
    if (batch_size <= 0) {
        return;
    }
    const int block_size = min(batch_size, 1024);
    const int grid_size  = (batch_size + block_size - 1) / block_size;
    batchApplyMinLengthPenalty<<<grid_size, block_size, 0, stream>>>(
        logits, min_lengths, end_ids, sequence_lengths, max_input_length, vocab_size_padded, batch_size);
}

template void invokeMinLengthPenalty(float*       logits,
                                     const int*   min_lengths,
                                     const int*   end_ids,
                                     const int*   sequence_lengths,
                                     const int    max_input_length,
                                     const int    batch_size,
                                     const int    vocab_size_padded,
                                     cudaStream_t stream);

template void invokeMinLengthPenalty(half*        logits,
                                     const int*   min_lengths,
                                     const int*   end_ids,
                                     const int*   sequence_lengths,
                                     const int    max_input_length,
                                     const int    batch_size,
                                     const int    vocab_size_padded,
                                     cudaStream_t stream);

}  // namespace fastertransformer