/*
 * This file contains the CUDA kernels for the fused quantized layernorm.
 * The kernels correspond to the kernels in layernorm_kernels.cu, except they
 * also produce quantized output directly.
 * Currently, only static fp8 quantization is supported.
 */

#include "type_convert.cuh"
#include "quantization/w8a8/fp8/common.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include <cstdint>

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
/*
 * RMS-normalizes one row of `input` per thread block, multiplies it by
 * `weight`, and stores the result in `out` as statically scaled fp8.
 *
 * Launch layout: blockIdx.x selects the row; the threads of the block stride
 * over the hidden dimension. The block reduction is instantiated for at most
 * 1024 threads, so blockDim.x must not exceed 1024.
 *
 * `input` rows start `input_stride` elements apart; `out` rows start
 * `hidden_size` elements apart. `num_tokens` is not read by the kernel.
 */
template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel(
    fp8_type* __restrict__ out,          // [..., hidden_size]
    const scalar_t* __restrict__ input,  // [..., hidden_size]
    const int input_stride,
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offsets are computed in 64 bits: blockIdx.x * stride exceeds the
  // 32-bit range once the tensor holds more than 2^31 elements.
  const int64_t row = blockIdx.x;
  const scalar_t* input_row = input + row * input_stride;
  const int64_t out_base = row * hidden_size;

  // Pass 1: each thread accumulates a partial sum of squares over its share
  // of the row.
  constexpr int VEC_SIZE = 8;
  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      float x = static_cast<float>(vec.val[i]);
      variance += x * x;
    }
  };
  auto scalar_op = [&variance](const scalar_t& val) {
    float x = static_cast<float>(val);
    variance += x * x;
  };
  vllm::vectorize_read_with_alignment<VEC_SIZE>(
      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);

  // Combine the per-thread partial sums; only the first blockDim.x slots of
  // the 1024-wide reduction hold valid data.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes 1 / sqrt(mean(x^2) + epsilon) for the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  // Pass 2: normalize, apply the weight and quantize.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input_row[idx];
    // The normalized value is rounded to scalar_t before the weight multiply.
    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
    out[out_base + idx] =
        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
  }
}

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   Each thread block handles the row selected by blockIdx.x:
     1. residual <- input + residual (written back in place), while
        accumulating the sum of squares of the updated residual;
     2. out <- fp8(residual * rsqrt(mean_sq + epsilon) * weight / scale).

   `input`, `residual` and `weight` are accessed as vectors of `width`
   elements, so `hidden_size` and `input_stride` are divided by `width` here
   and the pointers are reinterpreted as vector pointers. The block reduction
   is instantiated for at most 1024 threads. `num_tokens` is not read. */
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
    fp8_type* __restrict__ out,    // [..., hidden_size]
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  const int vec_input_stride = input_stride / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  // Row bases are kept in 64 bits: blockIdx.x * stride (and the later
  // `* width` for the output index) exceeds the 32-bit range once the tensor
  // holds more than 2^31 elements.
  const int64_t row = blockIdx.x;
  const int64_t input_base = row * vec_input_stride;
  const int64_t residual_base = row * vec_hidden_size;

  // Pass 1: fused residual add plus per-thread partial sum of squares.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t stride_id = input_base + idx;
    const int64_t id = residual_base + idx;
    _f16Vec<scalar_t, width> temp = input_v[stride_id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  // Combine the per-thread partial sums; only the first blockDim.x slots of
  // the 1024-wide reduction hold valid data.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes 1 / sqrt(mean(x^2) + epsilon) for the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  // Pass 2: re-read the updated residual, normalize, apply the weight and
  // quantize each lane of the vector.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = residual_base + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
#pragma unroll
    for (int i = 0; i < width; ++i) {
      out[id * width + i] =
          scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
    }
  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.

   Each thread block handles the row selected by blockIdx.x:
     1. residual <- input + residual (written back in place), while
        accumulating the sum of squares of the updated residual;
     2. out <- fp8(residual * rsqrt(mean_sq + epsilon) * weight / scale).

   `input` rows start `input_stride` elements apart; `residual` and `out`
   rows start `hidden_size` elements apart. The block reduction is
   instantiated for at most 1024 threads. `num_tokens` is not read.
 */
template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
    fp8_type* __restrict__ out,    // [..., hidden_size]
    scalar_t* __restrict__ input,  // [..., hidden_size]
    const int input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row bases are kept in 64 bits: blockIdx.x * stride exceeds the 32-bit
  // range once the tensor holds more than 2^31 elements.
  const int64_t row = blockIdx.x;
  const int64_t input_base = row * input_stride;
  const int64_t row_base = row * hidden_size;

  // Pass 1: fused residual add plus per-thread partial sum of squares.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[input_base + idx];
    z += residual[row_base + idx];
    float x = (float)z;
    variance += x * x;
    residual[row_base + idx] = z;
  }

  // Combine the per-thread partial sums; only the first blockDim.x slots of
  // the 1024-wide reduction hold valid data.
  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);

  // Thread 0 publishes 1 / sqrt(mean(x^2) + epsilon) for the whole block.
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  // Pass 2: normalize the updated residual, apply the weight and quantize.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[row_base + idx];
    // The normalized value is rounded to scalar_t before the weight multiply.
    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
    out[row_base + idx] =
        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
  }
}

}  // namespace vllm

/*
 * Host entry point for the fused RMSNorm + static fp8 quantization.
 *
 * Launches one thread block per row of `input` (numel / hidden_size rows)
 * with min(hidden_size, 1024) threads, on the current CUDA stream of
 * `input`'s device. `out` must be contiguous; `input` rows may be strided
 * (the row stride is taken from input.stride(-2)). `scale` is read on the
 * device as a single float. An empty `input` is a no-op.
 */
void rms_norm_static_fp8_quant(torch::Tensor& out,     // [..., hidden_size]
                               torch::Tensor& input,   // [..., hidden_size]
                               torch::Tensor& weight,  // [hidden_size]
                               torch::Tensor& scale,   // [1]
                               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  // Nothing to do for an empty tensor. Returning here also avoids dividing by
  // a zero hidden_size and launching a kernel with a zero-sized grid, which
  // CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }
  int hidden_size = input.size(-1);
  int input_stride = input.stride(-2);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  // Capped at 1024 to match the BlockReduce width used by the kernel.
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      input_stride, weight.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), epsilon, num_tokens,
                      hidden_size);
            });
      });
}

// Dispatches on the dtypes of `input` and `out` and launches
// vllm::fused_add_rms_norm_static_fp8_quant_kernel with the given vector
// `width`. Expects these names in the expanding scope: out, input, residual,
// weight, scale, epsilon, input_stride, num_tokens, hidden_size, grid, block
// and stream.
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
  VLLM_DISPATCH_FLOATING_TYPES(                                              \
      input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] {    \
        VLLM_DISPATCH_FP8_TYPES(                                             \
            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {   \
              vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t,     \
                                                               width, fp8_t> \
                  <<<grid, block, 0, stream>>>(                              \
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),     \
                      input_stride, residual.data_ptr<scalar_t>(),           \
                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),  \
                      epsilon, num_tokens, hidden_size);                     \
            });                                                              \
      });
/*
 * Host entry point for the fused residual-add + RMSNorm + static fp8
 * quantization.
 *
 * Launches one thread block per row of `input` (numel / hidden_size rows) on
 * the current CUDA stream of `input`'s device. `residual` is updated in place
 * by the kernel. `out` and `residual` must be contiguous; `input` rows may be
 * strided (the row stride is taken from input.stride(-2)). An empty `input`
 * is a no-op.
 */
void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,       // [..., hidden_size],
    torch::Tensor& input,     // [..., hidden_size]
    torch::Tensor& residual,  // [..., hidden_size]
    torch::Tensor& weight,    // [hidden_size]
    torch::Tensor& scale,     // [1]
    double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(residual.is_contiguous());
  // Nothing to do for an empty tensor. Returning here also avoids dividing by
  // a zero hidden_size and launching a kernel with a zero-sized grid, which
  // CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }
  int hidden_size = input.size(-1);
  int input_stride = input.stride(-2);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
  // The vectorized kernel also needs every row to start on a vector boundary,
  // hence the divisibility checks on hidden_size and input_stride. It is
  // skipped when batch-invariant mode is requested.
  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
      !batch_invariant_launch) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}