#include "hip/hip_runtime.h"
#include <hip/hip_bf16.h>

#define ENABLE_BF16 1

#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

// from TRTLLM
template<typename Tf, typename T>
__inline__ __device__ Tf
compute_layernorm(Tf val, float s_mean, float s_variance, const T *gamma, const T *beta, int i) {
    Tf ret = (val - s_mean) * s_variance;
    if (gamma != nullptr) {
        ret = ret * cuda_cast<Tf>(gamma[i]);
    }
    if (beta != nullptr) {
        ret = ret + cuda_cast<Tf>(beta[i]);
    }
    return ret;
}

// from TRTLLM
/* Computes the layernorm https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
 * normed_output <- ( (input - E[input]) / Sqrt(Var[input] + eps) ) * gamma + beta
 * input is [tokens, hidden_dim]. Mean and variance are computed per row (i.e. per token).
 *
 * One CTA handles one row.
 *
 * With USE_DIFF_OF_SQUARES set to false:
 * The first pass (loop) computes the mean.
 * The second pass computes the variance via Var[x] = E[(x - E[x])²].
 * The third pass computes and writes normed_output.
 *
 * With USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
 * The first pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]².
 * The second pass computes and writes normed_output.
 *
 * use_shmem controls whether the input values are cached in shared memory.
 *
 * Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
 *           amax per row. A final pass scales to int8 accordingly and writes the output to
 *           normed_output_quant.
 *
 * A host-side launch sketch follows the kernel definition below.
 */
template<typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(const T *input,
                                 const T *gamma,
                                 const T *beta,
                                 T *normed_output,
                                 const float eps,
                                 int tokens,
                                 int hidden_dim,
                                 const scale_type *scale_orig_quant_per_tensor,
                                 scale_type *scale_orig_quant_per_token,
                                 int8_t *normed_output_quant,
                                 bool use_shmem) {
    constexpr auto num_elems_T = num_elems<T>::value;
    using int8_packed_t        = typename packed_as<int8_t, num_elems_T>::type;
    using float_packed_t       = typename packed_as<float, num_elems_T>::type;
    using T_scalar             = typename packed_as<T, 1>::type;

    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T *shmem = reinterpret_cast<T *>(_shmem);
    __shared__ float s_mean;
    __shared__ float s_variance;

    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    float mean          = 0.0f;
    float variance      = 0.0f;
    float local_sum     = 0.0f;
    float local_var_sum = 0.0f;

    const int n_elems = hidden_dim / num_elems_T;
    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const T val = input[bidx * n_elems + i];
        if (use_shmem) {
            shmem[i] = val;
        }

        const float_packed_t val_f = cuda_cast<float_packed_t>(val);
        local_sum += cuda_sum<float>(val_f);
        if (USE_DIFF_OF_SQUARES) {
            local_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }

    if (USE_DIFF_OF_SQUARES) {
        float packed[2] = {local_sum, local_var_sum};
        blockReduceSumV2<float, 2>(packed);
        mean     = packed[0];
        variance = packed[1];
    } else {
        mean = blockReduceSum(local_sum);
    }

    if (threadIdx.x == 0) {
        mean   = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES) {
            variance   = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();

    if (!USE_DIFF_OF_SQUARES) {
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const T val         = use_shmem ? shmem[i] : input[bidx * n_elems + i];
            float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
            local_var_sum += cuda_sum<float>(diff * diff);
        }
        variance = blockReduceSum(local_var_sum);

        if (threadIdx.x == 0) {
            s_variance = rsqrtf(variance / hidden_dim + eps);
        }
        __syncthreads();
    }

    const bool with_per_token_scaling  = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant =
        cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
    T_scalar amax = 1e-6f;

    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const int index            = bidx * n_elems + i;
        const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
        const T val                = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));

        if (with_per_token_scaling) {
            amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
            if (use_shmem) {
                shmem[i] = val;
            }
        } else if (with_per_tensor_scaling) {
            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
        } else {
            normed_output[index] = val;
        }
    }

    if (with_per_token_scaling) {
        float abs_max_f                     = blockAllReduceMax(cuda_cast<float>(amax));
        const float dynamic_per_token_scale = 127.f / abs_max_f;
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const int index      = bidx * n_elems + i;
            float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
            if (!use_shmem) {
                val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
            }

            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
        }
        if (tidx == 0) {
            scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
        }
    }
}
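
// -----------------------------------------------------------------------------
// Illustrative host-side launch sketch (not part of the original kernels): a
// minimal way to invoke generalLayerNorm without quantization, assuming one
// CTA per token, that hidden_dim is divisible by the packing width
// num_elems<T>::value, and a 48 KB static shared-memory budget. The helper
// name, the 1024-thread cap, and the shared-memory heuristic are assumptions,
// not values taken from the original code. scale_type should match what the
// kernel's per-tensor path expects (typically a half type in this file).
// -----------------------------------------------------------------------------
template<typename T, typename scale_type>
void launch_general_layernorm(const T *input,
                              const T *gamma,
                              const T *beta,
                              T *normed_output,
                              float eps,
                              int tokens,
                              int hidden_dim,
                              hipStream_t stream) {
    constexpr int packed_width = num_elems<T>::value;
    const int n_elems          = hidden_dim / packed_width;    // packed elements per row
    const dim3 grid(tokens);                                   // one block (CTA) per token
    const dim3 block(n_elems < 1024 ? n_elems : 1024);         // grid-stride loop covers the rest
    // The kernel caches one row of packed T values when use_shmem is true.
    const size_t shmem_bytes = static_cast<size_t>(n_elems) * sizeof(T);
    const bool use_shmem     = shmem_bytes <= 48 * 1024;
    generalLayerNorm<T, scale_type><<<grid, block, use_shmem ? shmem_bytes : 0, stream>>>(
        input, gamma, beta, normed_output, eps, tokens, hidden_dim,
        /*scale_orig_quant_per_tensor=*/nullptr,
        /*scale_orig_quant_per_token=*/nullptr,
        /*normed_output_quant=*/nullptr,
        use_shmem);
}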

template<typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm_fuse_sum(const T *input,
                                          const T *gamma,
                                          const T *beta,
                                          T *normed_output,
                                          const float eps,
                                          int tokens,
                                          int hidden_dim,
                                          scale_type *input_sum,
                                          const scale_type *scale_orig_quant_per_tensor,
                                          scale_type *scale_orig_quant_per_token,
                                          int8_t *normed_output_quant,
                                          bool use_shmem) {
    constexpr auto num_elems_T = num_elems<T>::value;
    using int8_packed_t        = typename packed_as<int8_t, num_elems_T>::type;
    using float_packed_t       = typename packed_as<float, num_elems_T>::type;
    using T_scalar             = typename packed_as<T, 1>::type;

    extern __shared__ __align__(sizeof(float)) char _shmem[];
    T *shmem = reinterpret_cast<T *>(_shmem);
    __shared__ float s_mean;
    __shared__ float s_variance;

    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    float mean          = 0.0f;
    float variance      = 0.0f;
    float local_sum     = 0.0f;
    float local_var_sum = 0.0f;

    const int n_elems = hidden_dim / num_elems_T;
    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const T val = input[bidx * n_elems + i];
        if (use_shmem) {
            shmem[i] = val;
        }

        const float_packed_t val_f = cuda_cast<float_packed_t>(val);
        local_sum += cuda_sum<float>(val_f);
        if (USE_DIFF_OF_SQUARES) {
            local_var_sum += cuda_sum<float>(val_f * val_f);
        }
    }

    if (USE_DIFF_OF_SQUARES) {
        float packed[2] = {local_sum, local_var_sum};
        blockReduceSumV2<float, 2>(packed);
        mean     = packed[0];
        variance = packed[1];
    } else {
        mean = blockReduceSum(local_sum);
    }

    if (threadIdx.x == 0) {
        mean   = mean / hidden_dim;
        s_mean = mean;
        if (USE_DIFF_OF_SQUARES) {
            variance   = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
            s_variance = rsqrtf(variance + eps);
        }
    }
    __syncthreads();

    if (!USE_DIFF_OF_SQUARES) {
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const T val         = use_shmem ? shmem[i] : input[bidx * n_elems + i];
            float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
            local_var_sum += cuda_sum<float>(diff * diff);
        }
        variance = blockReduceSum(local_var_sum);

        if (threadIdx.x == 0) {
            s_variance = rsqrtf(variance / hidden_dim + eps);
        }
        __syncthreads();
    }

    const bool with_per_token_scaling  = scale_orig_quant_per_token != nullptr;
    const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
    const float_packed_t scale_orig_quant =
        cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
    T_scalar amax = 1e-6f;
    T_scalar sum  = 0.0f;

    for (int i = tidx; i < n_elems; i += blockDim.x) {
        const int index            = bidx * n_elems + i;
        const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
        const T val                = cuda_cast<T>(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i));

        if (with_per_token_scaling) {
            amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
            sum += cuda_sum<float>(val);
            if (use_shmem) {
                shmem[i] = val;
            }
        } else if (with_per_tensor_scaling) {
            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
        } else {
            normed_output[index] = val;
        }
    }

    if (with_per_token_scaling) {
        float abs_max_f                     = blockAllReduceMax(cuda_cast<float>(amax));
        float sum_f                         = blockAllReduceSum(cuda_cast<float>(sum));
        const float dynamic_per_token_scale = 127.f / abs_max_f;
        for (int i = tidx; i < n_elems; i += blockDim.x) {
            const int index      = bidx * n_elems + i;
            float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
            if (!use_shmem) {
                val_f = compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i);
            }

            reinterpret_cast<int8_packed_t *>(normed_output_quant)[index] =
                cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
        }
        if (tidx == 0) {
            scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
            input_sum[bidx]                  = sum_f;
        }
    }
}

// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t, typename out_type, bool use_quant>
__global__ void rms_norm_kernel(out_type *__restrict__ out,          // [..., hidden_size]
                                const scalar_t *__restrict__ input,  // [..., hidden_size]
                                const scalar_t *__restrict__ weight, // [hidden_size]
                                const float epsilon,
                                const int num_tokens,
                                const int hidden_size) {
    __shared__ float s_variance;
    float variance = 0.0f;

    // First pass: accumulate the per-token sum of squares.
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
        const float x = (float)input[blockIdx.x * hidden_size + idx];
        variance += x * x;
    }
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / hidden_size + epsilon);
    }
    __syncthreads();

    // Second pass: scale by the inverse RMS and the weight; optionally quantize to int8.
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
        float x = (float)input[blockIdx.x * hidden_size + idx];
        if constexpr (use_quant) {
            out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(((float)(x * s_variance)) * (float)(weight[idx]));
        } else {
            out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
        }
    }
}
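
// Illustrative launch sketch (not part of the original kernels): invoking
// rms_norm_kernel in its non-quantizing form (out_type == scalar_t,
// use_quant == false). The helper name and the 1024-thread cap are assumptions.
template<typename scalar_t>
void launch_rms_norm(scalar_t *out,
                     const scalar_t *input,
                     const scalar_t *weight,
                     float epsilon,
                     int num_tokens,
                     int hidden_size,
                     hipStream_t stream) {
    const dim3 grid(num_tokens);                                 // one block per token
    const dim3 block(hidden_size < 1024 ? hidden_size : 1024);   // grid-stride loop covers the rest
    rms_norm_kernel<scalar_t, scalar_t, false><<<grid, block, 0, stream>>>(
        out, input, weight, epsilon, num_tokens, hidden_size);
}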

template<typename T, typename scale_type, bool use_per_token_dequant>
__global__ void dequant_add_residual_rms_norm_quant_kernel(const int32_t *__restrict__ input,
                                                           T *__restrict__ residual,
                                                           int8_t *__restrict__ output,
                                                           const T *__restrict__ gamma,
                                                           const float layernorm_eps,
                                                           const scale_type scale,
                                                           int num_tokens,
                                                           int hidden_size) {
    // LayerNorm in the T5 style: no bias and no subtraction of the mean.
    const int tid = threadIdx.x;

    __shared__ float s_variance;
    float variance = 0.0f;

    // Dequantize the int32 input, add the residual in place, and accumulate the sum of squares.
    float local_var_sum = 0.0f;
    for (int i = tid; i < hidden_size; i += blockDim.x) {
        float diff = 0.0f;
        if constexpr (use_per_token_dequant) {
            diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale[blockIdx.x])) +
                    (float)residual[blockIdx.x * hidden_size + i]);
        } else {
            diff = ((((float)input[blockIdx.x * hidden_size + i]) * __half2float(scale)) +
                    (float)residual[blockIdx.x * hidden_size + i]);
        }
        residual[blockIdx.x * hidden_size + i] = (T)diff;
        local_var_sum += diff * diff;
    }
    variance = blockReduceSum(local_var_sum);

    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / (float)hidden_size + layernorm_eps);
    }
    __syncthreads();

    // Normalize with the inverse RMS, scale by gamma, and quantize to int8.
    for (int i = tid; i < hidden_size; i += blockDim.x) {
        output[blockIdx.x * hidden_size + i] =
            float_to_int8_rn((((float)(residual[blockIdx.x * hidden_size + i])) * s_variance) * (float)(gamma[i]));
    }
}
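
// Illustrative launch sketch (not part of the original kernels): a per-tensor
// invocation of dequant_add_residual_rms_norm_quant_kernel, i.e.
// use_per_token_dequant == false with the scale passed by value as __half
// (matching the __half2float use above, and assuming __half is in scope via
// the existing includes). The helper name and the 1024-thread cap are
// assumptions.
template<typename T>
void launch_dequant_add_residual_rms_norm_quant(const int32_t *input,
                                                T *residual,
                                                int8_t *output,
                                                const T *gamma,
                                                float layernorm_eps,
                                                __half scale,
                                                int num_tokens,
                                                int hidden_size,
                                                hipStream_t stream) {
    const dim3 grid(num_tokens);                                 // one block per token
    const dim3 block(hidden_size < 1024 ? hidden_size : 1024);
    dequant_add_residual_rms_norm_quant_kernel<T, __half, false><<<grid, block, 0, stream>>>(
        input, residual, output, gamma, layernorm_eps, scale, num_tokens, hidden_size);
}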
} // namespace vllm