"tests/vscode:/vscode.git/clone" did not exist on "951fdd66d36ffed75a55e9bfa9f2221dce4fbcdc"
layernorm_kernels.cu 17.9 KB
Newer Older
1
2
#include "type_convert.cuh"
#include "dispatch_utils.h"
Aidyn-A's avatar
Aidyn-A committed
3
#include "cub_helpers.h"
4
#include "core/batch_invariant.hpp"
5
6

#include <torch/cuda.h>
7
#include <c10/cuda/CUDAGuard.h>
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
namespace vllm {
10
11

// TODO(woosuk): Further optimize this kernel.
/* RMS normalization of one row per thread block:
     out = input * rsqrt(mean(input^2) + epsilon) * weight

   The row is selected by blockIdx.x and the block's threads stride over the
   hidden dimension. Rows of `input` are `input_stride` elements apart, rows
   of `out` are `hidden_size` elements apart.

   The block-wide sum uses a cub::BlockReduce instantiated for 1024 threads
   and is told that blockDim.x partial sums are valid, so blockDim.x must not
   exceed 1024. `num_tokens` is not read by the kernel. */
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,          // [..., hidden_size]
    const scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offsets are formed in 64 bits. `blockIdx.x * hidden_size` would be
  // evaluated in 32-bit unsigned arithmetic and wrap for large tensors.
  const int64_t in_row = blockIdx.x * input_stride;
  const int64_t out_row = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: per-thread partial sum of squares, accumulated in float.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[in_row + idx];
    variance += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 turns the reduced sum into the inverse RMS and publishes it to
  // the rest of the block through shared memory.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: scale in float, round to scalar_t, then apply the weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[in_row + idx];
    out[out_row + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

43
44
45
46
/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   Per row (one thread block per row, selected by blockIdx.x):
     residual <- input + residual                       (in place)
     input    <- residual * rsqrt(mean(residual^2) + epsilon) * weight

   All three buffers are accessed as vectors of `width` elements, so
   `hidden_size` and `input_stride` are divided by `width` (truncating) and
   the pointers are reinterpreted as vector pointers; the caller is expected
   to have verified divisibility and alignment. blockDim.x must not exceed
   the 1024 threads the block reduction is instantiated for. `num_tokens` is
   not read by the kernel. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  const int64_t vec_input_stride = input_stride / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  // Row offsets (in vector units) are formed in 64 bits so that
  // blockIdx.x * vec_hidden_size cannot overflow a 32-bit index.
  const int64_t residual_row =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  const int64_t input_row = blockIdx.x * vec_input_stride;

  // Pass 1: add the residual, store the sum back into `residual`, and
  // accumulate the per-thread sum of squares.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_row + idx;
    const int64_t strided_id = input_row + idx;
    _f16Vec<scalar_t, width> temp = input_v[strided_id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes the inverse RMS to the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: normalize the updated residual and write the result to `input`.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_row + idx;
    const int64_t strided_id = input_row + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[strided_id] = temp;
  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.

   Per row (one thread block per row, selected by blockIdx.x):
     residual <- input + residual                       (in place, scalar_t)
     input    <- residual * rsqrt(mean(residual^2) + epsilon) * weight

   Rows of `input` are `input_stride` elements apart, rows of `residual` are
   `hidden_size` elements apart. blockDim.x must not exceed the 1024 threads
   the block reduction is instantiated for. `num_tokens` is not read by the
   kernel.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offsets are formed in 64 bits. `blockIdx.x * hidden_size` would be
  // evaluated in 32-bit unsigned arithmetic and wrap for large tensors.
  const int64_t input_row = blockIdx.x * input_stride;
  const int64_t residual_row = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: the add is done in scalar_t, the sum of squares in float.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[input_row + idx];
    z += residual[residual_row + idx];
    float x = (float)z;
    variance += x * x;
    residual[residual_row + idx] = z;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes the inverse RMS to the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: normalize the updated residual and write the result to `input`.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[residual_row + idx];
    input[input_row + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
/* Vector type used by the FP16/BF16 specialization of poly_norm_kernel,
   where packed and vectorized operations help with the memory latency
   bottleneck.

   _f16VecPN derives from _f16Vec and contributes the two operations that
   polynomial normalization (poly norm) needs and the base type does not
   provide: the sum-of-powers computation and the in-place polynomial
   normalization. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  // Returns the tuple (sum x^2, sum x^4, sum x^6) over the `width` stored
  // elements, with all arithmetic done in float.
  __device__ auto sum_pows() const {
    float acc2 = 0.0f;
    float acc4 = 0.0f;
    float acc6 = 0.0f;

#pragma unroll
    for (int k = 0; k < width; k += 2) {
      // Elements are converted to float two at a time.
      const float2 pair = Converter::convert(T2{data[k], data[k + 1]});

      const float lo2 = pair.x * pair.x;
      const float lo4 = lo2 * lo2;
      const float lo6 = lo4 * lo2;

      const float hi2 = pair.y * pair.y;
      const float hi4 = hi2 * hi2;
      const float hi6 = hi4 * hi2;

      acc2 += lo2 + hi2;
      acc4 += lo4 + hi4;
      acc6 += lo6 + hi6;
    }
    return std::make_tuple(acc2, acc4, acc6);
  }

  // Replaces every element x with
  //   w2_inv_std * x + w1_inv_std2 * x^2 + w0_inv_std3 * x^3 + bias,
  // evaluated in float and converted back to the storage type.
  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3, const float bias) {
#pragma unroll
    for (int k = 0; k < width; k += 2) {
      float2 pair = Converter::convert(T2{data[k], data[k + 1]});

      const float lo2 = pair.x * pair.x;
      const float lo3 = lo2 * pair.x;
      pair.x =
          w2_inv_std * pair.x + w1_inv_std2 * lo2 + w0_inv_std3 * lo3 + bias;

      const float hi2 = pair.y * pair.y;
      const float hi3 = hi2 * pair.y;
      pair.y =
          w2_inv_std * pair.y + w1_inv_std2 * hi2 + w0_inv_std3 * hi3 + bias;

      const auto packed = Converter::convert(pair);
      data[k] = packed.x;
      data[k + 1] = packed.y;
    }
  }
};

/* Vectorized poly_norm_kernel for types that have a _typeConvert
   specialization, selected when width > 0.

   Per row (one thread block per row, selected by blockIdx.x):
     out = w2 * x   * rsqrt(mean(x^2) + epsilon)
         + w1 * x^2 * rsqrt(mean(x^4) + epsilon)
         + w0 * x^3 * rsqrt(mean(x^6) + epsilon)
         + bias
   with w0..w2 = weight[0..2] and bias = bias[0].

   `input` and `out` are accessed as vectors of `width` elements with rows
   `hidden_size / width` vectors apart; the caller is expected to have
   verified divisibility and alignment. blockDim.x must not exceed the 1024
   threads the block reduction is instantiated for. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);

  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
  const int vec_hidden_size = hidden_size / width;
  // Row offset (in vector units) is formed in 64 bits so that
  // blockIdx.x * vec_hidden_size cannot overflow a 32-bit index.
  const int64_t row = static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  // Pass 1: per-thread partial sums of x^2, x^4 and x^6.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise addition so the three sums share one block reduction.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Thread 0 folds each weight into its inverse-RMS factor and publishes the
  // four coefficients to the block.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);

  // Pass 2: re-read the input, apply the polynomial, and store to `out`.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[id] = temp;
  }
}

/* Generic poly_norm_kernel
   The width field is not used here but necessary for other specializations.

   Per row (one thread block per row, selected by blockIdx.x):
     out = w2 * x   * rsqrt(mean(x^2) + epsilon)
         + w1 * x^2 * rsqrt(mean(x^4) + epsilon)
         + w0 * x^3 * rsqrt(mean(x^6) + epsilon)
         + bias
   with w0..w2 = weight[0..2] and bias = bias[0]. Rows of `input` and `out`
   are `hidden_size` elements apart. blockDim.x must not exceed the 1024
   threads the block reduction is instantiated for.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  // Row offset is formed in 64 bits. `blockIdx.x * hidden_size` would be
  // evaluated in 32-bit unsigned arithmetic and wrap for large tensors.
  const int64_t row = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: per-thread partial sums of x^2, x^4 and x^6.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row + idx];
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise addition so the three sums share one block reduction.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Thread 0 folds each weight into its inverse-RMS factor and publishes the
  // four coefficients to the block.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: evaluate the polynomial in float and round once to scalar_t.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row + idx];
    float x2 = x * x;
    float x3 = x2 * x;

    out[row + idx] =
        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
                   s_bias);
  }
}

344
}  // namespace vllm
345

346
347
348
/* Host entry point for RMS normalization.

   Launches vllm::rms_norm_kernel on the current CUDA stream of `input`'s
   device with one block per row (numel / hidden_size rows) and
   min(hidden_size, 1024) threads per block.

   Requirements checked here: `out` and `weight` are contiguous and the last
   dimension of `input` has unit stride. Rows of `input` are taken to be
   input.stride(-2) elements apart. An empty `input` is a no-op. */
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.stride(-1) == 1);
  TORCH_CHECK(weight.is_contiguous());

  // Nothing to do for an empty tensor. Returning here also avoids dividing
  // by hidden_size == 0 and launching with a zero-sized grid or block.
  if (input.numel() == 0) {
    return;
  }

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  int64_t input_stride = input.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}
368

369
370
371
372
373
374
375
376
/* Dispatches on input.scalar_type() and launches
   vllm::fused_add_rms_norm_kernel<scalar_t, width>.
   Expands in the caller's scope and expects these names to be visible there:
   input, residual, weight, input_stride, epsilon, num_tokens, hidden_size,
   grid, block and stream. */
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                    \
  VLLM_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {               \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                    \
            <<<grid, block, 0, stream>>>(                                   \
                input.data_ptr<scalar_t>(), input_stride,                   \
                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
                epsilon, num_tokens, hidden_size);                          \
      });

/* Host entry point for the fused residual-add + RMS normalization.

   In place: `residual` receives input + residual, and `input` receives the
   RMS-normalized, weight-scaled result. The launch uses one block per row on
   the current CUDA stream of `input`'s device.

   Requirements checked here: `residual` and `weight` are contiguous and the
   last dimension of `input` has unit stride (the kernels index it that way).
   Rows of `input` are taken to be input.stride(-2) elements apart. An empty
   `input` is a no-op. */
void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  TORCH_CHECK(input.stride(-1) == 1);
  TORCH_CHECK(residual.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());

  // Nothing to do for an empty tensor. Returning here also avoids dividing
  // by hidden_size == 0 and launching with a zero-sized grid or block.
  if (input.numel() == 0) {
    return;
  }

  int hidden_size = input.size(-1);
  int64_t input_stride = input.stride(-2);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  constexpr int vector_width = 8;
  constexpr int req_alignment_bytes =
      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
                         // falls back to non-vectorized version anyway)
  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
                          res_ptr % req_alignment_bytes == 0 &&
                          wt_ptr % req_alignment_bytes == 0;
  // The vectorized kernel divides both the row length and the row stride by
  // the vector width, so both must be exact multiples of it.
  bool offsets_are_multiple_of_vector_width =
      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
  // When the batch-invariant override is active the non-vectorized kernel is
  // always used.
  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
      !batch_invariant_launch) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464

/* Dispatches on input.scalar_type() and launches
   vllm::poly_norm_kernel<scalar_t, width>.
   Expands in the caller's scope and expects these names to be visible there:
   out, input, weight, bias, epsilon, hidden_size, grid, block and stream. */
#define LAUNCH_FUSED_POLY_NORM(width)                                         \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
    vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>(      \
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),                 \
        weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon,      \
        hidden_size);                                                         \
  });

/* Host entry point for polynomial normalization.

   Launches vllm::poly_norm_kernel on the current CUDA stream of `input`'s
   device with one block per row (numel / hidden_size rows). The kernel reads
   weight[0..2] and bias[0].

   Requirements checked here: `out` and `input` are contiguous and do not
   share a data pointer. An empty `input` is a no-op. */
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());

  // Nothing to do for an empty tensor. Returning here also avoids dividing
  // by hidden_size == 0 and launching with a zero-sized grid or block.
  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(out.data_ptr() != input.data_ptr());

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  // When the batch-invariant override is active the non-vectorized kernel is
  // always used.
  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
  if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_POLY_NORM(0);
  }
}