decoding_kernels.cu 20.7 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

lvhan028's avatar
lvhan028 committed
17
18
19
20
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
Li Zhang's avatar
Li Zhang committed
21

lvhan028's avatar
lvhan028 committed
22
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
23
24
25

// Token-embedding lookup followed by an additive absolute position encoding.
//
// For every element (row, col) of the [local_token_num, hidden_units] output:
//   from_tensor[row, col] = embedding_table[id, col] * scale + position_encoding[pos, col]
// with
//   id  = all_ids[step * token_num + ite * local_token_num + row]
//   pos = step - padding_count[row]                       if padding_count != nullptr
//       = step - (max_input_length - input_lengths[row])  else if input_lengths != nullptr
//       = step                                            otherwise
//
// Launch: any 1-D grid/block; the loop advances by blockDim.x * gridDim.x, so every
// element is visited exactly once regardless of the launch size.
// All flat offsets are 64-bit: position * hidden_units previously went through an int
// and could overflow for long positions or wide hidden sizes.
template<typename T>
__global__ void embeddingLookupPosEncoding(T*            from_tensor,
                                           const T*      embedding_table,
                                           const T*      position_encoding,
                                           const int*    all_ids,
                                           const int*    padding_count,
                                           const int*    input_lengths,
                                           const int     local_token_num,
                                           const int64_t hidden_units,
                                           const int     step,
                                           const int     max_input_length,
                                           const int     token_num,
                                           const int     ite,
                                           const T       scale)
{
    // Start of this (step, ite) slice inside all_ids.
    const int64_t id_offset = (int64_t)step * token_num + (int64_t)ite * local_token_num;

    const bool use_padding_count = padding_count != nullptr;
    const bool use_input_len     = input_lengths != nullptr;

    const int64_t total  = (int64_t)local_token_num * hidden_units;
    const int64_t stride = (int64_t)blockDim.x * gridDim.x;

    for (int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; index < total; index += stride) {
        const int64_t row_index = index / hidden_units;
        const int64_t col_index = index % hidden_units;

        // Row position: `step` minus the row's padding amount, when padding info is given.
        // padding_count takes precedence over input_lengths.
        int64_t position = step;
        if (use_padding_count) {
            position -= padding_count[row_index];
        }
        else if (use_input_len) {
            position -= max_input_length - input_lengths[row_index];
        }

        const int64_t token_id = all_ids[id_offset + row_index];
        T             val      = embedding_table[token_id * hidden_units + col_index] * scale;
        val                    = val + position_encoding[position * hidden_units + col_index];

        from_tensor[index] = val;
    }
}

// Token-embedding lookup with no absolute position encoding added.
//
// Each element (word, col) of the [local_token_num, hidden_units] output becomes
//   from_tensor[word, col] = source[word, col] * scale
// where `token` is all_ids[step * token_num + ite * local_token_num + word] (or `word`
// itself when all_ids is null), batch = word / seq_len, and the source is picked by the
// compile-time PROMPT_SRC:
//   0 --> no prompts: embedding_table[token * hidden_units + col];
//   1 --> from loaded prompts: tokens with prompt_id = token - p_prompt_tuning_id_start >= 0
//         read p_prompt_tuning_batch_weights[batch][prompt_id * hidden_units + col];
//   2 --> from request prompts: such tokens read request_prompt_embedding at
//         batch * request_prompt_max_length * hidden_units + prompt_id * hidden_units + col.
// Tokens with prompt_id < 0 always use embedding_table.
// Launch: any 1-D grid/block; the loop advances by blockDim.x * gridDim.x.
template<typename T, int PROMPT_SRC>
__global__ void embeddingLookup(T*                    from_tensor,
                                const T*              embedding_table,
                                const int*            all_ids,
                                pPromptTuningParam<T> prompt_param,
                                const int             local_token_num,
                                const int64_t         hidden_units,
                                const int             step,
                                const int             token_num,
                                const int             ite,
                                const int             seq_len,
                                const T               scale)
{
    const int     first_id = step * token_num + ite * local_token_num;
    const int64_t n_elems  = local_token_num * hidden_units;

    for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n_elems; i += blockDim.x * gridDim.x) {
        const int word      = i / hidden_units;
        const int col       = i % hidden_units;
        const int batch     = word / seq_len;
        const int token     = (all_ids != nullptr) ? all_ids[first_id + word] : word;
        const int prompt_id = token - prompt_param.p_prompt_tuning_id_start;

        const bool from_prompt = PROMPT_SRC > 0 && prompt_id >= 0;

        T value = (T)0.0f;
        if (!from_prompt) {
            value = embedding_table[token * hidden_units + col];
        }
        else if (PROMPT_SRC == 1) {
            // Per-batch pointers into the loaded prompt tables.
            value = prompt_param.p_prompt_tuning_batch_weights[batch][prompt_id * hidden_units + col];
        }
        else {
            // Dense request prompt buffer, one slab of request_prompt_max_length rows per batch.
            const int64_t offset =
                batch * prompt_param.request_prompt_max_length * hidden_units + prompt_id * hidden_units + col;
            value = prompt_param.request_prompt_embedding[offset];
        }
        from_tensor[i] = value * scale;
    }
}

// Launches embeddingLookup<T, PROMPT_SRC> on `stream`, taking `grid`, `block` and every
// kernel argument from the enclosing scope (it is used inside invokeEmbeddingLookupPosEncoding
// and #undef'd right after it). The do { } while (0) wrapper makes `EMBEDDING_LOOKUP(n);`
// expand to exactly one statement, so it is also safe in an unbraced if/else.
#define EMBEDDING_LOOKUP(PROMPT_SRC)                                                                                   \
    do {                                                                                                               \
        embeddingLookup<T, PROMPT_SRC><<<grid, block, 0, stream>>>(from_tensor,                                        \
                                                                   embedding_table,                                    \
                                                                   all_ids,                                            \
                                                                   prompt_param,                                       \
                                                                   local_token_num,                                    \
                                                                   hidden_units,                                       \
                                                                   step,                                               \
                                                                   token_num,                                          \
                                                                   ite,                                                \
                                                                   seq_len,                                            \
                                                                   scale);                                             \
    } while (0)

/* Shared implementation behind the invokeEmbeddingLookupPosEncoding* wrappers
 * (e.g. invokeEmbeddingLookupPosEncodingPadCount below).
 *
 * Dispatch:
 *   - position_encoding != nullptr -> embeddingLookupPosEncoding<T>; prompt tuning must be
 *     disabled in that case (enforced with FT_CHECK_WITH_INFO);
 *   - otherwise embeddingLookup<T, PROMPT_SRC> with PROMPT_SRC = 2 when
 *     prompt_param.use_request_p_prompt_embedding, 1 when
 *     prompt_param.p_prompt_tuning_batch_weights != nullptr, else 0.
 *
 * Launch: min(local_token_num, 65536) blocks of min(hidden_units, 1024) threads on `stream`;
 * the kernels loop over any remaining elements. Nothing is launched when local_token_num or
 * hidden_units is not positive, because a zero-sized grid/block is an invalid launch
 * configuration (and there is no work to do).
 */
template<typename T>
void invokeEmbeddingLookupPosEncoding(T*                    from_tensor,
                                      const T*              embedding_table,
                                      const T*              position_encoding,
                                      const int*            all_ids,
                                      const int*            padding_count,
                                      const int*            input_lengths,
                                      pPromptTuningParam<T> prompt_param,
                                      const int             local_token_num,
                                      const int             hidden_units,
                                      const T               scale,
                                      const int             step,
                                      const int             max_input_length,
                                      const int             token_num,
                                      const int             ite,
                                      const int             seq_len,
                                      cudaStream_t          stream)
{
    // Validate first so an unsupported configuration is still reported for empty inputs.
    if (position_encoding != nullptr) {
        FT_CHECK_WITH_INFO(prompt_param.use_request_p_prompt_embedding == false
                               && prompt_param.p_prompt_tuning_batch_weights == nullptr,
                           fmtstr("embeddingLookupPosEncoding still not support prompt tuning"));
    }
    if (local_token_num <= 0 || hidden_units <= 0) {
        return;
    }

    dim3 grid(min(local_token_num, 65536));
    dim3 block(min(hidden_units, 1024));
    if (position_encoding != nullptr) {
        embeddingLookupPosEncoding<T><<<grid, block, 0, stream>>>(from_tensor,
                                                                  embedding_table,
                                                                  position_encoding,
                                                                  all_ids,
                                                                  padding_count,
                                                                  input_lengths,
                                                                  local_token_num,
                                                                  hidden_units,
                                                                  step,
                                                                  max_input_length,
                                                                  token_num,
                                                                  ite,
                                                                  scale);
    }
    else if (prompt_param.use_request_p_prompt_embedding) {
        EMBEDDING_LOOKUP(2);
    }
    else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
        EMBEDDING_LOOKUP(1);
    }
    else {
        EMBEDDING_LOOKUP(0);
    }
}

#undef EMBEDDING_LOOKUP

// Embedding lookup where, when position_encoding is given, each row's position is
// step - pad_count[row] (or just step if pad_count is null). When position_encoding is
// null this falls through to the plain / prompt-tuning lookup.
// Thin wrapper over invokeEmbeddingLookupPosEncoding: input_lengths stays null, so the
// max_input_length argument is never read and is passed as 0.
template<typename T>
void invokeEmbeddingLookupPosEncodingPadCount(T*                    from_tensor,
                                              const T*              embedding_table,
                                              const T*              position_encoding,
                                              const int*            all_ids,
                                              const int*            pad_count,
                                              pPromptTuningParam<T> prompt_param,
                                              const int             local_token_num,
                                              const int             hidden_units,
                                              const T               scale,
                                              const int             step,
                                              const int             token_num,
                                              const int             ite,
                                              const int             seq_len,
                                              cudaStream_t          stream)
{
    const int* no_input_lengths     = nullptr;  // padding comes only from pad_count here
    const int  unused_max_input_len = 0;        // only consulted together with input_lengths

    invokeEmbeddingLookupPosEncoding<T>(from_tensor,
                                        embedding_table,
                                        position_encoding,
                                        all_ids,
                                        pad_count,
                                        no_input_lengths,
                                        prompt_param,
                                        local_token_num,
                                        hidden_units,
                                        scale,
                                        step,
                                        unused_max_input_len,
                                        token_num,
                                        ite,
                                        seq_len,
                                        stream);
}

// Explicit instantiations of invokeEmbeddingLookupPosEncodingPadCount for float and half,
// plus __nv_bfloat16 when ENABLE_BF16 is defined. Only the parameter types are spelled out;
// they must stay in sync with the template definition above.
#define INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(T)                                                                   \
    template void invokeEmbeddingLookupPosEncodingPadCount<T>(T*,                                                      \
                                                              const T*,                                                \
                                                              const T*,                                                \
                                                              const int*,                                              \
                                                              const int*,                                              \
                                                              pPromptTuningParam<T>,                                   \
                                                              const int,                                               \
                                                              const int,                                               \
                                                              const T,                                                 \
                                                              const int,                                               \
                                                              const int,                                               \
                                                              const int,                                               \
                                                              const int,                                               \
                                                              cudaStream_t)
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(float);
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(half);
#ifdef ENABLE_BF16
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16);
#endif
#undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT

// Pads a row-major [hidden_unit, vocab_size] weight to [hidden_unit, vocab_size_padded]
// (columns >= vocab_size are zero) and a [vocab_size] bias to [vocab_size_padded]
// (entries >= vocab_size are zero).
// Both loops advance by blockDim.x * gridDim.x, so any 1-D launch covers every element.
template<typename T>
__global__ void paddingEmbedding(T*            padded_embedding_kernel,
                                 T*            padded_embedding_bias,
                                 const T*      embedding_kernel,
                                 const T*      embedding_bias,
                                 const int64_t hidden_unit,
                                 const int64_t vocab_size,
                                 const int64_t vocab_size_padded)
{
    const int64_t n_weight = hidden_unit * vocab_size_padded;

    // Weight: copy in-range columns, zero the padded tail of every row.
    for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n_weight; idx += blockDim.x * gridDim.x) {
        const int row = idx / vocab_size_padded;
        const int col = idx % vocab_size_padded;
        padded_embedding_kernel[idx] = (col < vocab_size) ? embedding_kernel[row * vocab_size + col] : (T)(0.0f);
    }

    // Bias: same treatment along the vocabulary axis.
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < vocab_size_padded; idx += blockDim.x * gridDim.x) {
        padded_embedding_bias[idx] = (idx < vocab_size) ? embedding_bias[idx] : (T)(0.0f);
    }
}

// Launches paddingEmbedding on `stream` to build the vocab-padded weight and bias.
//
// The element count is computed in 64 bits: hidden_unit * vocab_size_padded as an int
// product overflows for large vocabularies. The launch also pads the bias, so it always
// covers at least vocab_size_padded elements (previously hidden_unit == 0 produced a
// zero-sized grid and the bias was never written). The grid is capped because the kernel's
// loops advance by blockDim.x * gridDim.x and therefore cover any remainder.
template<typename T>
void invokePaddingEmbedding(T*           padded_embedding_kernel,
                            T*           padded_embedding_bias,
                            const T*     embedding_kernel,
                            const T*     embedding_bias,
                            const int    hidden_unit,
                            const int    vocab_size,
                            const int    vocab_size_padded,
                            cudaStream_t stream)
{
    const int64_t weight_elems = (int64_t)hidden_unit * vocab_size_padded;
    const int64_t work         = weight_elems > vocab_size_padded ? weight_elems : (int64_t)vocab_size_padded;
    if (work <= 0) {
        return;  // nothing to pad; a zero-sized grid would be an invalid launch configuration
    }

    constexpr int     kThreads   = 512;
    constexpr int64_t kMaxBlocks = 65536;
    const int64_t     needed     = (work + kThreads - 1) / kThreads;  // integer ceil-div

    dim3 block(kThreads);
    dim3 grid((unsigned int)(needed < kMaxBlocks ? needed : kMaxBlocks));
    paddingEmbedding<<<grid, block, 0, stream>>>(padded_embedding_kernel,
                                                 padded_embedding_bias,
                                                 embedding_kernel,
                                                 embedding_bias,
                                                 hidden_unit,
                                                 vocab_size,
                                                 vocab_size_padded);
}

292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
// template void invokePaddingEmbedding(float*       padded_embedding_kernel,
//                                      float*       padded_embedding_bias,
//                                      const float* embedding_kernel,
//                                      const float* embedding_bias,
//                                      const int    hidden_unit,
//                                      const int    vocab_size,
//                                      const int    vocab_size_padded,
//                                      cudaStream_t stream);

// template void invokePaddingEmbedding(half*        padded_embedding_kernel,
//                                      half*        padded_embedding_bias,
//                                      const half*  embedding_kernel,
//                                      const half*  embedding_bias,
//                                      const int    hidden_unit,
//                                      const int    vocab_size,
//                                      const int    vocab_size_padded,
//                                      cudaStream_t stream);
// #ifdef ENABLE_BF16
// template void invokePaddingEmbedding(__nv_bfloat16*       padded_embedding_kernel,
//                                      __nv_bfloat16*       padded_embedding_bias,
//                                      const __nv_bfloat16* embedding_kernel,
//                                      const __nv_bfloat16* embedding_bias,
//                                      const int            hidden_unit,
//                                      const int            vocab_size,
//                                      const int            vocab_size_padded,
//                                      cudaStream_t         stream);
// #endif
Li Zhang's avatar
Li Zhang committed
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353

// Pads a row-major [vocab_size, hidden_unit] table to [vocab_size_padded, hidden_unit]:
// rows < vocab_size are copied, rows >= vocab_size are zero-filled.
//
// Indexing is 64-bit; the former int loop counter and int bound
// (hidden_unit * vocab_size_padded) overflowed for large vocabularies.
// Launch: any 1-D grid/block; the loop advances by blockDim.x * gridDim.x.
template<typename T>
__global__ void paddingEmbeddingKernel(T*        padded_embedding_kernel,
                                       const T*  embedding_kernel,
                                       const int hidden_unit,
                                       const int vocab_size,
                                       const int vocab_size_padded)
{
    const int64_t total  = (int64_t)hidden_unit * vocab_size_padded;
    const int64_t stride = (int64_t)blockDim.x * gridDim.x;

    for (int64_t id = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; id < total; id += stride) {
        const int64_t row_id = id / hidden_unit;
        if (row_id < vocab_size) {
            // Row width is unchanged, so source and destination offsets coincide
            // (row_id * hidden_unit + col_id == id).
            padded_embedding_kernel[id] = embedding_kernel[id];
        }
        else {
            padded_embedding_kernel[id] = (T)(0.0f);
        }
    }
}

// Launches paddingEmbeddingKernel on `stream` to pad the embedding table along the
// vocabulary (row) axis.
//
// The element count is 64-bit (the int product hidden_unit * vocab_size_padded overflows for
// large vocabularies), the grid uses integer ceil-division, and empty inputs launch nothing,
// since a zero-sized grid is an invalid launch configuration. The grid is capped because the
// kernel loop advances by blockDim.x * gridDim.x and covers any remainder.
template<typename T>
void invokePaddingEmbeddingKernel(T*           padded_embedding_kernel,
                                  const T*     embedding_kernel,
                                  const int    hidden_unit,
                                  const int    vocab_size,
                                  const int    vocab_size_padded,
                                  cudaStream_t stream)
{
    const int64_t total = (int64_t)hidden_unit * vocab_size_padded;
    if (total <= 0) {
        return;
    }

    constexpr int     kThreads   = 512;
    constexpr int64_t kMaxBlocks = 65536;
    const int64_t     needed     = (total + kThreads - 1) / kThreads;

    dim3 block(kThreads);
    dim3 grid((unsigned int)(needed < kMaxBlocks ? needed : kMaxBlocks));
    paddingEmbeddingKernel<<<grid, block, 0, stream>>>(
        padded_embedding_kernel, embedding_kernel, hidden_unit, vocab_size, vocab_size_padded);
}

354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
// template void invokePaddingEmbeddingKernel(float*       padded_embedding_kernel,
//                                            const float* embedding_kernel,
//                                            const int    hidden_unit,
//                                            const int    vocab_size,
//                                            const int    vocab_size_padded,
//                                            cudaStream_t stream);

// template void invokePaddingEmbeddingKernel(half*        padded_embedding_kernel,
//                                            const half*  embedding_kernel,
//                                            const int    hidden_unit,
//                                            const int    vocab_size,
//                                            const int    vocab_size_padded,
//                                            cudaStream_t stream);

// #ifdef ENABLE_BF16
// template void invokePaddingEmbeddingKernel(__nv_bfloat16*       padded_embedding_kernel,
//                                            const __nv_bfloat16* embedding_kernel,
//                                            const int            hidden_unit,
//                                            const int            vocab_size,
//                                            const int            vocab_size_padded,
//                                            cudaStream_t         stream);
// #endif
Li Zhang's avatar
Li Zhang committed
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394

// Adds `val` in place to each of the first `size` elements of `buf`.
// Launch: any 1-D grid/block; each thread advances by blockDim.x * gridDim.x.
template<typename T>
__global__ void plusScalar(T* buf, const T val, const int size)
{
    for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < size; k += blockDim.x * gridDim.x) {
        buf[k] = buf[k] + val;
    }
}

// Asynchronously adds `val` to each of the first `size` elements of the buffer `buf`
// by launching plusScalar on `stream`.
//
// The launch configuration is <<<grid, block>>>. It was previously passed as <<<block, grid>>>,
// which made the block dimension ceil(size / 256) and exceeded the 1024-thread limit (a
// silently failed launch) once size > 256 * 1024. Non-positive sizes launch nothing, since a
// zero-sized grid/block is an invalid configuration.
template<typename T>
void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream)
{
    if (size <= 0) {
        return;
    }
    constexpr int kThreads = 256;
    dim3          block(min(kThreads, size));
    dim3          grid((size - 1) / kThreads + 1);  // integer ceil-div, no overflow near INT_MAX
    plusScalar<<<grid, block, 0, stream>>>(buf, val, size);
}

template void invokePlusScalar(int* buf, const int val, const int size, cudaStream_t stream);

lvhan028's avatar
lvhan028 committed
395
}  // namespace turbomind