// layernorm_kernels_impl.cuh
// Fused LayerNorm / RMSNorm CUDA kernel implementations (adapted from TRT-LLM / vLLM).
// NOTE(review): this header has no include guard — consider adding #pragma once.
#include <cuda_bf16.h>

#define ENABLE_BF16 1

#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

// from TRTLLM
// Applies the affine LayerNorm transform to one (possibly packed) element:
//   out = (val - mean) * inv_stddev [* gamma[i]] [+ beta[i]]
// s_variance is the precomputed reciprocal standard deviation. gamma/beta may
// be nullptr, in which case the corresponding step is skipped.
template<typename Tf, typename T>
__inline__ __device__ Tf
compute_layernorm(Tf val, float s_mean, float s_variance, const T *gamma, const T *beta, int i) {
    Tf out = (val - s_mean) * s_variance;
    if (gamma) {
        out = out * cuda_cast<Tf>(gamma[i]);
    }
    if (beta) {
        out = out + cuda_cast<Tf>(beta[i]);
    }
    return out;
}

// from TRTLLM
/* Computes the layernorm https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
 * normed_output <- ( (input - E[input]) / Sqrt(Var[input] + eps) ) * gamma + beta
 * input is [tokens, hidden_dim]. Mean and Variance are per-row (i.e. per-token)
 *
 * One CTA handles one row.
 *
 * with USE_DIFF_OF_SQUARES set to false:
 * First pass (loop) computes the mean.
 * Second computes the variance via Var[x] = E[(x - E[x])²].
 * Third pass computes and writes normed_output
 *
 * with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
 * First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
 * Second pass computes and writes normed_output
 *
 * use_shmem controls if we cache input values into shared memory
 *
 * Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
 *           amax per row. A final pass scales to int8 accordingly, and writes output to
 *           normed_output_quant.
 */
template<typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(const T *input,
                                 const T *gamma,
                                 const T *beta,
                                 T *normed_output,
                                 const float eps,
                                 int tokens,
                                 int hidden_dim,
                                 const scale_type *scale_orig_quant_per_tensor,
                                 scale_type *scale_orig_quant_per_token,
                                 int8_t *normed_output_quant,
                                 bool use_shmem) {
    // One CTA normalizes one token (row); see the algorithm description in the
    // comment above this kernel. USE_DIFF_OF_SQUARES selects the single-pass
    // Var[x] = E[x²] - E[x]² formulation.
    constexpr auto num_elems_T = num_elems<T>::value;
    using int8_packed_t  = typename packed_as<int8_t, num_elems_T>::type;
    using float_packed_t = typename packed_as<float, num_elems_T>::type;
    using T_scalar       = typename packed_as<T, 1>::type;

    // Optional row cache in dynamic shared memory, enabled via use_shmem.
    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T *row_cache = reinterpret_cast<T *>(_shmem);

    __shared__ float s_mean;     // E[x], broadcast to the block
    __shared__ float s_variance; // 1 / sqrt(Var[x] + eps), broadcast to the block

    const int tid        = threadIdx.x;
    const int row        = blockIdx.x;
    const int num_packed = hidden_dim / num_elems_T; // row length in packed-T units

    float mean           = 0.0f;
    float variance       = 0.0f;
    float thread_sum     = 0.0f;
    float thread_var_sum = 0.0f;

    // Pass 1: per-thread partial sum (and sum of squares in the fused variant).
    for (int i = tid; i < num_packed; i += blockDim.x) {
        const T val = input[row * num_packed + i];
        if (use_shmem) {
            row_cache[i] = val;
        }

        const float_packed_t val_f = cuda_cast<float_packed_t>(val);
        thread_sum += cuda_sum<float>(val_f);
        if (USE_DIFF_OF_SQUARES) {
            thread_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }

    if (USE_DIFF_OF_SQUARES) {
        // One combined block reduction for both partial sums.
        float sums[2] = {thread_sum, thread_var_sum};
        blockReduceSumV2<float, 2>(sums);
        mean     = sums[0];
        variance = sums[1];
    } else {
        mean = blockReduceSum(thread_sum);
    }

    if (tid == 0) {
        mean   = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES) {
            variance   = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();

    // Pass 2 (three-pass variant only): Var[x] = E[(x - E[x])²].
    if (!USE_DIFF_OF_SQUARES) {
        for (int i = tid; i < num_packed; i += blockDim.x) {
            const T val                   = use_shmem ? row_cache[i] : input[row * num_packed + i];
            const float_packed_t centered = cuda_cast<float_packed_t>(val) - s_mean;
            thread_var_sum += cuda_sum<float>(centered * centered);
        }
        variance = blockReduceSum(thread_var_sum);

        if (tid == 0) {
            s_variance = rsqrtf(variance / hidden_dim + eps);
        }
        __syncthreads();
    }

    const bool with_per_token_scaling  = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant =
        cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
    T_scalar amax = 1e-6f; // floor so 127/amax below never divides by zero

    // Final pass: normalize, then either write out directly, quantize with the
    // per-tensor scale, or track the per-row amax for dynamic scaling.
    for (int i = tid; i < num_packed; i += blockDim.x) {
        const int idx              = row * num_packed + i;
        const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? row_cache[i] : input[idx]);
        const T val                = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));

        if (with_per_token_scaling) {
            amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
            if (use_shmem) {
                row_cache[i] = val; // cache the normalized value for the quant pass
            }
        } else if (with_per_tensor_scaling) {
            reinterpret_cast<int8_packed_t *>(normed_output_quant)[idx] =
                cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
        } else {
            normed_output[idx] = val;
        }
    }

    // Dynamic per-token quantization: scale the row by 127/amax and record the
    // dequantization factor amax/127.
    if (with_per_token_scaling) {
        const float abs_max_f               = blockAllReduceMax(cuda_cast<float>(amax));
        const float dynamic_per_token_scale = 127.f / abs_max_f;
        for (int i = tid; i < num_packed; i += blockDim.x) {
            const int idx        = row * num_packed + i;
            float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? row_cache[i] : input[idx]);
            if (!use_shmem) {
                // Shared memory already holds normalized values; otherwise recompute.
                val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
            }

            reinterpret_cast<int8_packed_t *>(normed_output_quant)[idx] =
                cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
        }
        if (tid == 0) {
            scale_orig_quant_per_token[row] = abs_max_f / 127.f;
        }
    }
}

/* Same as generalLayerNorm (see the comment above that kernel), but when
 * per-token dynamic scaling is active it additionally writes the per-row sum
 * of the normalized values into input_sum.
 *
 * Fix: the per-thread running sum is accumulated in float instead of
 * T_scalar (half/bf16). The old code converted float -> half on every
 * iteration, losing precision over hidden_dim terms and risking half
 * overflow (max ~65504); the block reduction below is float anyway.
 */
template<typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm_fuse_sum(const T *input,
                                          const T *gamma,
                                          const T *beta,
                                          T *normed_output,
                                          const float eps,
                                          int tokens,
                                          int hidden_dim,
                                          scale_type *input_sum,
                                          const scale_type *scale_orig_quant_per_tensor,
                                          scale_type *scale_orig_quant_per_token,
                                          int8_t *normed_output_quant,
                                          bool use_shmem) {
    constexpr auto num_elems_T = num_elems<T>::value;
    using int8_packed_t        = typename packed_as<int8_t, num_elems_T>::type;
    using float_packed_t       = typename packed_as<float, num_elems_T>::type;
    using T_scalar             = typename packed_as<T, 1>::type;

    // Optional row cache in dynamic shared memory.
    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T *shmem = reinterpret_cast<T *>(_shmem);
    __shared__ float s_mean;     // E[x] for this row
    __shared__ float s_variance; // 1 / sqrt(Var[x] + eps) for this row

    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    float mean          = 0.0f;
    float variance      = 0.0f;
    float local_sum     = 0.0f;
    float local_var_sum = 0.0f;

    const int n_elems = hidden_dim / num_elems_T;
    // Pass 1: per-thread partial sum (and sum of squares in the fused variant).
    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const T val = input[bidx * n_elems + i];
        if (use_shmem) {
            shmem[i] = val;
        }

        const float_packed_t val_f = cuda_cast<float_packed_t>(val);
        local_sum += cuda_sum<float>(val_f);
        if (USE_DIFF_OF_SQUARES) {
            local_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }

    if (USE_DIFF_OF_SQUARES) {
        // One combined block reduction for both partial sums.
        float packed[2] = {local_sum, local_var_sum};
        blockReduceSumV2<float, 2>(packed);
        mean     = packed[0];
        variance = packed[1];
    } else {
        mean = blockReduceSum(local_sum);
    }

    if (threadIdx.x == 0) {
        mean   = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES) {
            variance   = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();

    // Pass 2 (three-pass variant only): Var[x] = E[(x - E[x])²].
    if (!USE_DIFF_OF_SQUARES) {
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const T val         = use_shmem ? shmem[i] : input[bidx * n_elems + i];
            float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
            local_var_sum += cuda_sum<float>(diff * diff);
        }
        variance = blockReduceSum(local_var_sum);

        if (threadIdx.x == 0) {
            s_variance = rsqrtf(variance / hidden_dim + eps);
        }
        __syncthreads();
    }

    const bool with_per_token_scaling  = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant =
        cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
    T_scalar amax = 1e-6f; // floor so 127/amax below never divides by zero
    float sum     = 0.0f;  // float accumulator (was T_scalar — the precision fix)

    // Final pass: normalize; write out directly, quantize per-tensor, or
    // gather amax + sum for per-token dynamic scaling.
    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const int index            = bidx * n_elems + i;
        const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
        const T val                = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));

        if (with_per_token_scaling) {
            amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
            // NOTE(review): despite the name input_sum, this accumulates the
            // *normalized* values — presumably they form the input of the next
            // layer; confirm against callers.
            sum += cuda_sum<float>(val);
            if (use_shmem) {
                shmem[i] = val; // cache the normalized value for the quant pass
            }
        } else if (with_per_tensor_scaling) {
            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
        } else {
            normed_output[index] = val;
        }
    }

    // Dynamic per-token quantization plus the fused per-row sum output.
    if (with_per_token_scaling) {
        float abs_max_f                     = blockAllReduceMax(cuda_cast<float>(amax));
        float sum_f                         = blockAllReduceSum(sum);
        const float dynamic_per_token_scale = 127.f / abs_max_f;
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const int index      = bidx * n_elems + i;
            float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
            if (!use_shmem) {
                // Shared memory already holds normalized values; otherwise recompute.
                val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
            }

            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
        }
        if (tidx == 0) {
            scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
            input_sum[bidx]                  = sum_f;
        }
    }
}

// TODO(woosuk): Further optimize this kernel.
// RMSNorm over one token per block: out = x * rsqrt(mean(x²) + eps) * weight.
// When use_quant is set, out_type is int8 and the result is rounded to int8
// via float_to_int8_rn; otherwise it is cast back to scalar_t.
template<typename scalar_t, typename out_type, bool use_quant>
__global__ void rms_norm_kernel(out_type *__restrict__ out,          // [..., hidden_size]
                                const scalar_t *__restrict__ input,  // [..., hidden_size]
                                const scalar_t *__restrict__ weight, // [hidden_size]
                                const float epsilon,
                                const int num_tokens,
                                const int hidden_size) {
    __shared__ float s_scale; // rsqrt(mean(x²) + eps), broadcast to the block
    const int row_offset = blockIdx.x * hidden_size;

    // Accumulate the per-thread sum of squares for this row.
    float sq_sum = 0.0f;
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        const float x = (float)input[row_offset + i];
        sq_sum += x * x;
    }
    sq_sum = blockReduceSum<float>(sq_sum);
    if (threadIdx.x == 0) {
        s_scale = rsqrtf(sq_sum / hidden_size + epsilon);
    }
    __syncthreads();

    // Normalize, apply the weight, and write out (optionally int8-quantized).
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        const float x = (float)input[row_offset + i];
        if constexpr (use_quant) {
            out[row_offset + i] = float_to_int8_rn(x * s_scale * (float)weight[i]);
        } else {
            out[row_offset + i] = ((scalar_t)(x * s_scale)) * weight[i];
        }
    }
}
// Fused kernel, one block per token:
//   1. dequantize the int32 `input` row with `scale`,
//   2. add `residual` (which is updated in place with the float result),
//   3. RMSNorm in the T5 style (no bias, no mean subtraction),
//   4. quantize the normalized row to int8 in `output`.
// use_per_token_dequant selects at compile time whether `scale` is indexed
// per token or used as a single per-tensor scalar.
template<typename T, typename scale_type, bool use_per_token_dequant>
__global__ void dequant_add_residual_rms_norm_quant_kernel(const int32_t *__restrict__ input,
                                                           T *__restrict__ residual,
                                                           int8_t *__restrict__ output,
                                                           const T *__restrict__ gamma,
                                                           const float layernorm_eps,
                                                           const scale_type scale,
                                                           int num_tokens,
                                                           int hidden_size) {
    const int tid        = threadIdx.x;
    const int row_offset = blockIdx.x * hidden_size;

    __shared__ float s_inv_rms; // rsqrt(mean(x²) + eps), broadcast to the block

    // The dequantization scale is loop-invariant for the row; hoist it.
    float dequant_scale;
    if constexpr (use_per_token_dequant) {
        dequant_scale = __half2float(scale[blockIdx.x]);
    } else {
        dequant_scale = __half2float(scale);
    }

    // Dequantize + residual-add, storing back into `residual`, while
    // accumulating the sum of squares for the norm.
    float sq_sum = 0.0f;
    for (int i = tid; i < hidden_size; i += blockDim.x) {
        const float x = (float)input[row_offset + i] * dequant_scale + (float)residual[row_offset + i];
        residual[row_offset + i] = (T)x;
        sq_sum += x * x;
    }

    const float total_sq_sum = blockReduceSum(sq_sum);
    if (threadIdx.x == 0) {
        s_inv_rms = rsqrtf(total_sq_sum / (float)hidden_size + layernorm_eps);
    }
    __syncthreads();

    // Normalize the updated residual, apply gamma, and quantize to int8.
    for (int i = tid; i < hidden_size; i += blockDim.x) {
        output[row_offset + i] =
            float_to_int8_rn((float)residual[row_offset + i] * s_inv_rms * (float)gamma[i]);
    }
}
} // namespace vllm