layernorm_kernels.cu 20.7 KB
Newer Older
1
2
#include "type_convert.cuh"
#include "dispatch_utils.h"
Aidyn-A's avatar
Aidyn-A committed
3
#include "cub_helpers.h"
zhuwenwen's avatar
zhuwenwen committed
4
#include "quantization/vectorization_utils.cuh"
5
6

#include <torch/cuda.h>
7
#include <c10/cuda/CUDAGuard.h>
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
namespace vllm {
10
11

// TODO(woosuk): Further optimize this kernel.
/* RMS-normalizes one row of `input` per thread block and scales it by
   `weight`:
     out[row, i] = (scalar_t)(x[i] * rsqrt(sum(x^2) / hidden_size + epsilon))
                   * weight[i]

   Launch layout: 1-D grid with one block per row (row == blockIdx.x) and a
   1-D block of at most 1024 threads (the capacity of the BlockReduce below).
   Static shared memory only.

   NUM_DIMS selects how the flat row index is decomposed into the strided
   leading dimensions of `input`; `out` is addressed as a dense
   [rows, hidden_size] buffer.

   The write-back loop only covers hidden_size / VEC_SIZE whole vectors, so
   the caller must pick a VEC_SIZE that divides hidden_size.
   NOTE(review): the reinterpret_casts to vec_n_t presumably require
   input_row, weight and out_row to be aligned to VEC_SIZE * sizeof(scalar_t)
   -- confirm against vec_n_t's declaration. */
template <typename scalar_t, int VEC_SIZE, int NUM_DIMS>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const int64_t input_stride_d2,        // input.stride(-2)
    const int64_t input_stride_d3,        // input.stride(-3)
    const int64_t input_stride_d4,        // input.stride(-4)
    const int64_t input_shape_d2,         // input.size(-2)
    const int64_t input_shape_d3,         // input.size(-3)
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Without this, an unsupported rank would leave input_row uninitialized.
  static_assert(NUM_DIMS >= 2 && NUM_DIMS <= 4,
                "rms_norm_kernel supports 2-D, 3-D and 4-D inputs only");

  __shared__ float s_variance;
  float variance = 0.0f;

  // Widen the block index once so every offset below is computed in 64 bits.
  const int64_t row = blockIdx.x;

  const scalar_t* input_row;
  if constexpr (NUM_DIMS == 2) {
    // 2D for layernorm normal case [batch_size, hidden]
    input_row = input + row * input_stride_d2;
  } else if constexpr (NUM_DIMS == 3) {
    // 3D for q/k norm [batch_size, num_heads, head_size]
    const int64_t batch_idx = row / input_shape_d2;
    const int64_t head_idx = row % input_shape_d2;
    input_row =
        input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
  } else {
    // 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
    const int64_t rows_per_batch = input_shape_d3 * input_shape_d2;
    const int64_t batch_idx = row / rows_per_batch;
    const int64_t remaining = row % rows_per_batch;
    const int64_t seq_idx = remaining / input_shape_d2;
    const int64_t head_idx = remaining % input_shape_d2;
    input_row = input + batch_idx * input_stride_d4 +
                seq_idx * input_stride_d3 + head_idx * input_stride_d2;
  }

  // Pass 1: each thread accumulates a partial sum of squares in fp32.
  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      float x = static_cast<float>(vec.val[i]);
      variance += x * x;
    }
  };
  auto scalar_op = [&variance](const scalar_t& val) {
    float x = static_cast<float>(val);
    variance += x * x;
  };
  vllm::vectorize_read_with_alignment<VEC_SIZE>(
      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);

  // Block-wide sum; blockDim.x is passed as the number of valid inputs.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Only thread 0 consumes the reduced value; it publishes the inverse RMS
  // to the rest of the block through shared memory.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();
  const float inv_rms = s_variance;

  // Pass 2: normalize, scale by weight and store, VEC_SIZE elements at a
  // time. The row offset is 64-bit so outputs with >= 2^32 elements do not
  // wrap.
  scalar_t* out_row = out + row * hidden_size;
  auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
  auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
  auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
  const int num_vecs = hidden_size / VEC_SIZE;
  for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
    vec_n_t<scalar_t, VEC_SIZE> dst;
    vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
    vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
#pragma unroll
    for (int j = 0; j < VEC_SIZE; j++) {
      float x = static_cast<float>(src1.val[j]);
      // Round the normalized value to scalar_t before applying the weight.
      dst.val[j] = ((scalar_t)(x * inv_rms)) * src2.val[j];
    }
    v_out[i] = dst;
  }
}

85
86
87
88
/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   Per row (one thread block per row, row == blockIdx.x):
     residual[row] <- input[row] + residual[row]
     input[row]    <- residual[row] * rsqrt(sum(residual[row]^2) / hidden_size
                                            + epsilon) * weight
   `input` rows are `input_stride` elements apart; `residual` is addressed as
   a dense [rows, hidden_size] buffer.

   Preconditions visible in this kernel: all accesses go through
   _f16Vec<scalar_t, width>, and both hidden_size and input_stride are
   divided by `width`, so the caller must ensure they are multiples of
   `width`. The block may have at most 1024 threads (BlockReduce capacity).
   NOTE(review): the vector accesses presumably also require the three base
   pointers to be aligned to width * sizeof(scalar_t) -- the host launcher
   checks this before selecting width == 8. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  const int64_t vec_input_stride = input_stride / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  // Row base offsets, in units of vectors. Both are 64-bit so that large
  // tensors do not overflow the index (the residual offset used to be a
  // 32-bit int while the input offset was already int64_t).
  const int64_t residual_row =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  const int64_t input_row = static_cast<int64_t>(blockIdx.x) * vec_input_stride;

  // Pass 1: fused add, store the sum back to residual, accumulate squares.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_row + idx;
    const int64_t strided_id = input_row + idx;
    _f16Vec<scalar_t, width> temp = input_v[strided_id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  // Block-wide sum; blockDim.x is passed as the number of valid inputs.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes the inverse RMS to the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: re-read the summed residual, normalize, scale, write to input.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_row + idx;
    const int64_t strided_id = input_row + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[strided_id] = temp;
  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.

   Per row (one thread block per row, row == blockIdx.x):
     residual[row] <- input[row] + residual[row]          (added in scalar_t)
     input[row]    <- (scalar_t)(residual[row] * inv_rms) * weight
   where inv_rms = rsqrt(sum(residual[row]^2) / hidden_size + epsilon).
   `input` rows are `input_stride` elements apart; `residual` is addressed as
   a dense [rows, hidden_size] buffer. The block may have at most 1024
   threads (BlockReduce capacity).
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // 64-bit row offsets: `blockIdx.x * hidden_size` would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for large tensors.
  const int64_t input_row = static_cast<int64_t>(blockIdx.x) * input_stride;
  const int64_t residual_row = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: fused add, store the sum back to residual, accumulate squares.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[input_row + idx];
    z += residual[residual_row + idx];
    float x = (float)z;
    variance += x * x;
    residual[residual_row + idx] = z;
  }

  // Block-wide sum; blockDim.x is passed as the number of valid inputs.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes the inverse RMS to the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: normalize the summed residual, scale by weight, write to input.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[residual_row + idx];
    input[input_row + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

zhuwenwen's avatar
zhuwenwen committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   _f16VecPN layers the two operations polynomial normalization (poly norm)
   needs on top of _f16Vec: a per-vector sum of the 2nd, 4th and 6th powers,
   and an in-place evaluation of the normalized cubic polynomial.
   Both walk `data` two elements at a time, converting each pair to float2
   through the base class' Converter. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  // Returns the tuple (sum x^2, sum x^4, sum x^6) over all `width` elements,
  // accumulated in fp32.
  __device__ auto sum_pows() const {
    float acc_p2 = 0.0f;
    float acc_p4 = 0.0f;
    float acc_p6 = 0.0f;

#pragma unroll
    for (int k = 0; k < width; k += 2) {
      const float2 pair = Converter::convert(T2{data[k], data[k + 1]});

      const float lo2 = pair.x * pair.x;
      const float hi2 = pair.y * pair.y;
      const float lo4 = lo2 * lo2;
      const float hi4 = hi2 * hi2;
      const float lo6 = lo4 * lo2;
      const float hi6 = hi4 * hi2;

      acc_p2 += lo2 + hi2;
      acc_p4 += lo4 + hi4;
      acc_p6 += lo6 + hi6;
    }
    return std::make_tuple(acc_p2, acc_p4, acc_p6);
  }

  // Overwrites every element v with
  //   w2_inv_std * v + w1_inv_std2 * v^2 + w0_inv_std3 * v^3 + bias,
  // evaluated in fp32 and converted back through Converter.
  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3, const float bias) {
    auto eval = [&](const float v) {
      const float v2 = v * v;
      const float v3 = v2 * v;
      return w2_inv_std * v + w1_inv_std2 * v2 + w0_inv_std3 * v3 + bias;
    };

#pragma unroll
    for (int k = 0; k < width; k += 2) {
      float2 pair = Converter::convert(T2{data[k], data[k + 1]});
      pair.x = eval(pair.x);
      pair.y = eval(pair.y);

      auto packed = Converter::convert(pair);
      data[k] = packed.x;
      data[k + 1] = packed.y;
    }
  }
};

/* Vectorized poly norm. Per row (one thread block per row,
   row == blockIdx.x), every element x of the row becomes
     w2 * x   * rsqrt(sum(x^2) / hidden_size + epsilon)
   + w1 * x^2 * rsqrt(sum(x^4) / hidden_size + epsilon)
   + w0 * x^3 * rsqrt(sum(x^6) / hidden_size + epsilon)
   + bias
   with (w0, w1, w2) = weight[0..2] and bias = bias[0].

   `input` and `out` are addressed as dense [rows, hidden_size] buffers in
   units of _f16VecPN<scalar_t, width>; hidden_size is divided by `width`,
   so the caller must ensure it is a multiple of `width`. The block may have
   at most 1024 threads (BlockReduce capacity). */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);

  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
  const int vec_hidden_size = hidden_size / width;
  // 64-bit row offset (in vectors); a 32-bit `int` index overflows for very
  // large tensors.
  const int64_t row_base = static_cast<int64_t>(blockIdx.x) * vec_hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  // Pass 1: per-thread partial sums of x^2, x^4 and x^6.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_base + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise addition so the three sums share a single reduction.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it folds each weight into its
  // inverse-RMS factor and publishes the results through shared memory.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);

  // Pass 2: evaluate the normalized polynomial and store.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_base + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[id] = temp;
  }
}

/* Generic poly_norm_kernel
   The width field is not used here but necessary for other specializations.

   Per row (one thread block per row, row == blockIdx.x), every element x of
   the row becomes
     (scalar_t)( x   * w2 * rsqrt(sum(x^2) / hidden_size + epsilon)
               + x^2 * w1 * rsqrt(sum(x^4) / hidden_size + epsilon)
               + x^3 * w0 * rsqrt(sum(x^6) / hidden_size + epsilon)
               + bias )
   with (w0, w1, w2) = weight[0..2] and bias = bias[0]. `input` and `out`
   are addressed as dense [rows, hidden_size] buffers. The block may have at
   most 1024 threads (BlockReduce capacity).
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // 64-bit row offset: `blockIdx.x * hidden_size` would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for large tensors.
  const int64_t row_base = static_cast<int64_t>(blockIdx.x) * hidden_size;

  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  // Pass 1: per-thread partial sums of x^2, x^4 and x^6.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_base + idx];
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  // Component-wise addition so the three sums share a single reduction.
  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  // Only thread 0 consumes the reduced sums; it folds each weight into its
  // inverse-RMS factor and publishes the results through shared memory.
  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: evaluate the normalized polynomial and store.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_base + idx];
    float x2 = x * x;
    float x3 = x2 * x;

    out[row_base + idx] =
        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
                   s_bias);
  }
}

386
}  // namespace vllm
387

388
389
390
/* Host launcher for vllm::rms_norm_kernel.

   out    : contiguous tensor receiving the result, [..., hidden_size]
   input  : tensor to normalize, [..., hidden_size]; if its last dimension
            is not unit-stride it is replaced by a contiguous copy (note
            that this rebinds the caller's tensor reference)
   weight : contiguous scale vector, [hidden_size]
   epsilon: value added inside the rsqrt

   Launches one block per row on the current CUDA stream of input's device.
   Empty inputs return without launching anything. */
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  if (input.stride(-1) != 1) {
    input = input.contiguous();
  }
  TORCH_CHECK(input.stride(-1) == 1);
  TORCH_CHECK(weight.is_contiguous());

  // Nothing to do for empty tensors. This also avoids dividing by
  // hidden_size == 0 below and launching a zero-sized grid or block, which
  // is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int hidden_size = input.size(-1);

  int num_tokens = input.numel() / hidden_size;
  int num_dims = input.dim();
  int64_t input_stride_d2 = input.stride(-2);
  int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
  int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
  int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
  int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;

  // Base addresses, used below to pick a vector width that every row start
  // can honor.
  const auto inp_addr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  const auto out_addr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  const auto wt_addr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());

  // For large num_tokens, use smaller blocks to increase SM concurrency.
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 grid(num_tokens);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_RANK234(num_dims, [&] {
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
      // Widest vector (up to 16 bytes) that evenly divides a row.
      int calculated_vec_size = std::gcd(16 / sizeof(scalar_t), hidden_size);

      // The kernel's write-back pass reinterprets the input row, weight and
      // output row as vectors, so each base pointer and each row stride must
      // be a multiple of the vector width. Halve the width until that holds;
      // width 1 always qualifies.
      // (Assumes 16 / sizeof(scalar_t) is a power of two, so every halved
      // value still divides hidden_size.)
      auto vec_is_aligned = [&](int v) {
        const std::uintptr_t vec_bytes = v * sizeof(scalar_t);
        return inp_addr % vec_bytes == 0 && out_addr % vec_bytes == 0 &&
               wt_addr % vec_bytes == 0 && input_stride_d2 % v == 0 &&
               input_stride_d3 % v == 0 && input_stride_d4 % v == 0;
      };
      while (calculated_vec_size > 1 && !vec_is_aligned(calculated_vec_size)) {
        calculated_vec_size /= 2;
      }

      // One thread per vector of the row, capped at max_block_size.
      const int block_size =
          std::min(hidden_size / calculated_vec_size, max_block_size);
      dim3 block(block_size);
      VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
        vllm::rms_norm_kernel<scalar_t, vec_size, tensor_rank>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                input_stride_d2, input_stride_d3, input_stride_d4,
                input_shape_d2, input_shape_d3, weight.data_ptr<scalar_t>(),
                epsilon, num_tokens, hidden_size);
      });
    });
  });
}
432

433
434
435
436
437
438
439
440
// Launches vllm::fused_add_rms_norm_kernel<scalar_t, width> for the dtype of
// `input`. Expects `input`, `residual`, `weight`, `input_stride`, `epsilon`,
// `num_tokens`, `hidden_size`, `grid`, `block` and `stream` to be in scope.
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                    \
  VLLM_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {               \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                    \
            <<<grid, block, 0, stream>>>(                                   \
                input.data_ptr<scalar_t>(), input_stride,                   \
                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
                epsilon, num_tokens, hidden_size);                          \
      });

/* Host launcher for the fused residual-add + RMS norm kernels.

   input    : [..., hidden_size]; overwritten with the normalized result.
              Rows are addressed as row * input.stride(-2), and the last
              dimension must be unit-stride.
   residual : contiguous [..., hidden_size]; overwritten with
              input + residual.
   weight   : contiguous scale vector, [hidden_size].
   epsilon  : value added inside the rsqrt.

   All three tensors must share one dtype. Launches one block per row on the
   current CUDA stream of input's device. Empty inputs return without
   launching anything. */
void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  TORCH_CHECK(weight.scalar_type() == input.scalar_type());
  TORCH_CHECK(input.scalar_type() == residual.scalar_type());
  TORCH_CHECK(residual.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());

  // Nothing to do for empty tensors. This also avoids dividing by
  // hidden_size == 0 below and launching a zero-sized grid or block, which
  // is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  // Both kernels index elements within a row as consecutive addresses.
  TORCH_CHECK(input.stride(-1) == 1);

  int hidden_size = input.size(-1);
  int64_t input_stride = input.stride(-2);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  constexpr int vector_width = 8;
  constexpr int req_alignment_bytes =
      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
                         // falls back to non-vectorized version anyway)
  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
                          res_ptr % req_alignment_bytes == 0 &&
                          wt_ptr % req_alignment_bytes == 0;
  // Row starts stay aligned only if both row pitches are whole vectors.
  bool offsets_are_multiple_of_vector_width =
      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}

// Launches vllm::poly_norm_kernel<scalar_t, width> for the dtype of `input`.
// Expects `out`, `input`, `weight`, `bias`, `epsilon`, `hidden_size`, `grid`,
// `block` and `stream` to be in scope.
#define LAUNCH_FUSED_POLY_NORM(width)                                         \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
    vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>(      \
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),                 \
        weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon,      \
        hidden_size);                                                         \
  });

/* Host launcher for the poly norm kernels.

   out    : contiguous tensor receiving the result, [..., hidden_size];
            must not share its data pointer with `input`
   input  : contiguous tensor to normalize, [..., hidden_size]
   weight : contiguous, at least 3 elements (the kernels read weight[0..2])
   bias   : at least 1 element (the kernels read bias[0])
   epsilon: value added inside each rsqrt

   Launches one block per row on the current CUDA stream of input's device.
   Empty inputs return without launching anything. */
void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.data_ptr() != input.data_ptr());
  // The kernels read weight[0..2] and bias[0] through raw pointers.
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(weight.numel() >= 3, "poly_norm expects at least 3 weights");
  TORCH_CHECK(bias.numel() >= 1, "poly_norm expects at least 1 bias value");

  // Nothing to do for empty tensors. This also avoids dividing by
  // hidden_size == 0 below and launching a zero-sized grid or block, which
  // is an invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  // weight and bias are only read element-wise, so they need no vector
  // alignment.
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_POLY_NORM(0);
  }
}