bert_preprocess_kernels.cu 7.64 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bert_preprocess_kernels.h"
lvhan028's avatar
lvhan028 committed
18
19
20
#include "src/turbomind/utils/cuda_bf16_fallbacks.cuh"
#include "src/turbomind/utils/cuda_fp8_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
Li Zhang's avatar
Li Zhang committed
21

lvhan028's avatar
lvhan028 committed
22
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

/**
 * Serial scan over the batch that fills the padding-offset table and,
 * optionally, the cumulative sequence-length table.
 *
 * For sequence i, `sequence_length[i]` consecutive entries of
 * `tmp_mask_offset` receive the amount of padding (`max_seq_len - length`)
 * accumulated over sequences 0..i-1.
 *
 * @param h_valid_word_num  If non-null, element 0 receives the sum of all
 *                          sequence lengths.
 * @param tmp_mask_offset   Output, one entry per valid token.
 * @param cu_seqlens        If non-null, entry i receives the sum of the
 *                          lengths of sequences 0..i-1 and entry `batch_size`
 *                          receives the total (batch_size + 1 entries written).
 * @param sequence_length   Input, `batch_size` entries.
 * @param batch_size        Number of sequences.
 * @param max_seq_len       Padded length used to compute the per-sequence padding.
 *
 * The body uses no thread or block index, so every launched thread performs
 * the complete scan; the launcher in this file uses <<<1, 1>>>.
 */
__global__ void getPaddingOffsetAndCuSeqLensKernel(size_t*    h_valid_word_num,
                                                   int*       tmp_mask_offset,
                                                   int*       cu_seqlens,
                                                   const int* sequence_length,
                                                   const int  batch_size,
                                                   const int  max_seq_len)
{
    const bool want_cu_seqlens = (cu_seqlens != nullptr);

    int tokens_so_far  = 0;  // running sum of sequence lengths
    int padding_so_far = 0;  // running sum of (max_seq_len - length)
    int write_pos      = 0;  // next free slot in tmp_mask_offset

    for (int b = 0; b < batch_size; ++b) {
        const int len = sequence_length[b];

        if (want_cu_seqlens) {
            cu_seqlens[b] = tokens_so_far;
        }

        // All tokens of sequence b share the padding accumulated before it.
        for (int remaining = len; remaining > 0; --remaining) {
            tmp_mask_offset[write_pos++] = padding_so_far;
        }

        tokens_so_far += len;
        padding_so_far += max_seq_len - len;
    }

    if (want_cu_seqlens) {
        cu_seqlens[batch_size] = tokens_so_far;
    }

    if (h_valid_word_num != nullptr) {
        h_valid_word_num[0] = static_cast<size_t>(tokens_so_far);
    }
}

/**
 * Launches getPaddingOffsetAndCuSeqLensKernel on `stream` and, when a
 * counter slot is supplied, waits for the result and copies the token total
 * to `h_token_num`.
 *
 * @param h_pinned_token_num  Optional. Passed to the kernel, which stores the
 *                            token total in element 0; read back on the host
 *                            afterwards.
 *                            NOTE(review): must therefore be addressable from
 *                            both host and device (the name suggests pinned,
 *                            mapped host memory) — confirm at the call site.
 * @param h_token_num         Receives the token total; only written when
 *                            `h_pinned_token_num` is non-null.
 * @param tmp_mask_offset     Kernel output (per-token padding offset).
 * @param cu_seqlens          Optional kernel output (cumulative lengths).
 * @param sequence_lengths    Kernel input, `batch_size` entries.
 * @param batch_size          Number of sequences.
 * @param max_seq_len         Padded sequence length.
 * @param stream              Stream the kernel is enqueued on.
 */
void invokeGetPaddingOffsetAndCuSeqLens(size_t*      h_pinned_token_num,
                                        size_t*      h_token_num,
                                        int*         tmp_mask_offset,
                                        int*         cu_seqlens,
                                        const int*   sequence_lengths,
                                        const int    batch_size,
                                        const int    max_seq_len,
                                        cudaStream_t stream)
{
    if (h_pinned_token_num) {
        // Start from a known value so a stale count from an earlier call is
        // never reported if the kernel does not get to write the slot.
        h_pinned_token_num[0] = 0;
    }

    getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>(
        h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len);

    if (h_pinned_token_num) {
        // Wait for the kernel through the stream rather than by spinning until
        // the counter becomes non-zero: a total of zero tokens (empty batch or
        // all-zero lengths) is a valid result, and a failed launch never
        // writes the counter, so a "wait for != 0" loop would hang in both
        // cases.
        // NOTE(review): the return code is not inspected here; presumably
        // sync_check_cuda_error() below is where failures get reported.
        cudaStreamSynchronize(stream);
        h_token_num[0] = h_pinned_token_num[0];
    }

    sync_check_cuda_error();
}

/**
 * Scatters packed rows back into their padded positions.
 *
 * Block `b` copies the `n` elements of source row `b` to destination row
 * `b + padding_offset[b]`; the threads of the block stride over the row.
 *
 * @param src             Packed input, rows of `n` elements, indexed by blockIdx.x.
 * @param dst             Padded output, rows of `n` elements.
 * @param padding_offset  Per-source-row shift added to obtain the destination row.
 * @param n               Elements per row.
 *
 * Expects a 1-D grid with one block per source row; any 1-D block size works.
 */
template<typename T>
__global__ void rebuild_sequence_length_padding(const T* src, T* dst, const int* padding_offset, const int n)
{
    const int packed_row = blockIdx.x;
    const int padded_row = packed_row + padding_offset[packed_row];

    // Row offsets are formed in 64 bits: row * n exceeds INT_MAX for large
    // (tokens x hidden) buffers and would wrap if computed as int.
    const T* src_row = src + static_cast<long long>(packed_row) * n;
    T*       dst_row = dst + static_cast<long long>(padded_row) * n;

    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        dst_row[i] = src_row[i];
    }
}

/**
 * Enqueues rebuild_sequence_length_padding on `stream`, one 256-thread block
 * per packed row.
 *
 * @param dst             Padded output, [batch_size * max_seq_len, hidden_dim].
 * @param src             Packed input, [token_num, hidden_dim].
 * @param padding_offset  Per-token row shift, `token_num` entries read.
 * @param token_num       Number of packed rows (grid size).
 * @param hidden_dim      Elements per row.
 * @param stream          Stream to launch on.
 *
 * Returns without launching when there is nothing to copy.
 */
template<typename T>
void invokeRebuildPadding(
    T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream)
{
    // A grid of zero blocks is an invalid launch configuration and would leave
    // a CUDA error pending; an empty input simply means no work.
    if (token_num <= 0 || hidden_dim <= 0) {
        return;
    }

    // src: [token_num, hidden_dim]
    // dst: [batch_size*max_seq_len, hidden_dim]
    rebuild_sequence_length_padding<<<token_num, 256, 0, stream>>>(src, dst, padding_offset, hidden_dim);
}

// Re-declaration of the launcher template followed by its explicit
// instantiations, one per supported element type. The bf16 and fp8 variants
// are only emitted when the corresponding ENABLE_* macro is defined.
template<typename T>
void invokeRebuildPadding(T*           dst,
                          const T*     src,
                          const int*   padding_offset,
                          const int    token_num,
                          const int    hidden_dim,
                          cudaStream_t stream);

template void
invokeRebuildPadding<float>(float*, const float*, const int*, const int, const int, cudaStream_t);

template void
invokeRebuildPadding<half>(half*, const half*, const int*, const int, const int, cudaStream_t);

#ifdef ENABLE_BF16
template void invokeRebuildPadding<__nv_bfloat16>(
    __nv_bfloat16*, const __nv_bfloat16*, const int*, const int, const int, cudaStream_t);
#endif  // ENABLE_BF16

#ifdef ENABLE_FP8
template void invokeRebuildPadding<__nv_fp8_e4m3>(
    __nv_fp8_e4m3*, const __nv_fp8_e4m3*, const int*, const int, const int, cudaStream_t);
#endif  // ENABLE_FP8

/**
 * Gathers rows from a padded buffer into a packed buffer.
 *
 * Block `b` copies the `n` elements of source row `b + padding_offset[b]`
 * to target row `b`; the threads of the block stride over the row.
 *
 * @param tgt             Packed output, rows of `n` elements, indexed by blockIdx.x.
 * @param src             Padded input, rows of `n` elements.
 * @param padding_offset  Per-target-row shift added to obtain the source row.
 * @param n               Elements per row.
 *
 * Expects a 1-D grid with one block per target row; any 1-D block size works.
 */
template<typename T>
__global__ void remove_padding(T* tgt, const T* src, const int* padding_offset, const int n)
{
    const int packed_row = blockIdx.x;
    const int padded_row = packed_row + padding_offset[packed_row];

    // Row offsets are formed in 64 bits: row * n exceeds INT_MAX for large
    // (tokens x hidden) buffers and would wrap if computed as int.
    const T* src_row = src + static_cast<long long>(padded_row) * n;
    T*       tgt_row = tgt + static_cast<long long>(packed_row) * n;

    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        tgt_row[i] = src_row[i];
    }
}

/**
 * Enqueues remove_padding on `stream`, one 256-thread block per packed row.
 *
 * @param dst             Packed output, `token_num` rows of `hidden_dim` elements.
 * @param src             Padded input the rows are gathered from.
 * @param padding_offset  Per-token row shift, `token_num` entries read.
 * @param token_num       Number of packed rows (grid size).
 * @param hidden_dim      Elements per row.
 * @param stream          Stream to launch on.
 *
 * Returns without launching when there is nothing to copy.
 */
template<typename T>
void invokeRemovePadding(
    T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream)
{
    // A grid of zero blocks is an invalid launch configuration and would leave
    // a CUDA error pending; an empty input simply means no work.
    if (token_num <= 0 || hidden_dim <= 0) {
        return;
    }

    remove_padding<<<token_num, 256, 0, stream>>>(dst, src, padding_offset, hidden_dim);
}

// Explicit instantiations of the launcher, one per supported element type.
// The fp8 and bf16 variants are only emitted when the corresponding ENABLE_*
// macro is defined.
template void
invokeRemovePadding<float>(float*, const float*, const int*, const int, const int, cudaStream_t);

template void
invokeRemovePadding<half>(half*, const half*, const int*, const int, const int, cudaStream_t);

#ifdef ENABLE_FP8
template void invokeRemovePadding<__nv_fp8_e4m3>(
    __nv_fp8_e4m3*, const __nv_fp8_e4m3*, const int*, const int, const int, cudaStream_t);
#endif  // ENABLE_FP8

#ifdef ENABLE_BF16
template void invokeRemovePadding<__nv_bfloat16>(
    __nv_bfloat16*, const __nv_bfloat16*, const int*, const int, const int, cudaStream_t);
#endif  // ENABLE_BF16

lvhan028's avatar
lvhan028 committed
186
}  // namespace turbomind