layernorm_kernels.cu 12.9 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
zhuwenwen's avatar
zhuwenwen committed
4

5
#include "dispatch_utils.h"
6
7
8
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
9
10
  #include <cub/util_type.cuh>
  #include <cub/cub.cuh>
11
12
13
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>
14
15
  #include <hipcub/util_type.hpp>
  #include <hipcub/hipcub.hpp>
16

17
18
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
19
#endif
20

Woosuk Kwon's avatar
Woosuk Kwon committed
21
namespace vllm {
22
23

// TODO(woosuk): Further optimize this kernel.
/* RMS normalization of one token row per block:
     out[t, i] = (scalar_t)(input[t, i] * rsqrt(sum_j input[t, j]^2 / hidden_size
                                                 + epsilon)) * weight[i]
   Launch layout: gridDim.x == number of rows (one block per row) and
   blockDim.x <= 1024, because the BlockReduce below is sized for 1024 threads.
   Rows are addressed with a fixed stride of `hidden_size` elements, so the
   tensors are expected to be densely packed. `num_tokens` is not read here.
   The row offset is computed in 64-bit so that tensors with more than
   INT_MAX elements do not wrap around. */
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // blockIdx.x * hidden_size would otherwise be evaluated in 32-bit unsigned
  // arithmetic and overflow for large batches.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[row_offset + idx];
    variance += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  // The block-wide aggregate returned by Reduce is only valid in thread 0.
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[row_offset + idx];
    out[row_offset + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

54
55
56
57
58
59
60
61
/* Maps a torch element type (c10::Half, c10::BFloat16, ...) to the matching
   HIP/CUDA device types and provides scalar and packed conversions to and
   from float. These converters are written by hand because HIP and CUDA do
   not implement the relevant conversion operators/constructors consistently,
   so a generic conversion via type casts is not possible.

   Every specialization declares `static constexpr bool exists`:
     - false: the optimized (vectorized) kernel is not used for that torch
       type;
     - true:  the specialization must also define `hip_type`,
       `packed_hip_type` and the four `convert` overloads, following the
       specializations below.

   This primary template is the fallback for types without a converter. */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists{false};
};
68

69
70
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
71
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  // Widening: half -> float and half2 -> float2.
  __device__ static inline float convert(hip_type h) { return __half2float(h); }
  __device__ static inline float2 convert(packed_hip_type h2) {
    return __half22float2(h2);
  }

  // Narrowing with round-to-nearest-even: float -> half, float2 -> half2.
  __device__ static inline hip_type convert(float f) {
    return __float2half_rn(f);
  }
  __device__ static inline packed_hip_type convert(float2 f2) {
    return __float22half2_rn(f2);
  }
};

89
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
90
91
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
92
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  // Widening: bf16 -> float and bf16x2 -> float2.
  __device__ static inline float convert(hip_type b) {
    return __bfloat162float(b);
  }
  __device__ static inline float2 convert(packed_hip_type b2) {
    return __bfloat1622float2(b2);
  }

  // Narrowing: float -> bf16 and float2 -> bf16x2 (round-to-nearest-even).
  __device__ static inline hip_type convert(float f) {
    return __float2bfloat16(f);
  }
  __device__ static inline packed_hip_type convert(float2 f2) {
    return __float22bfloat162_rn(f2);
  }
};
111
112
113
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif    // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
          // 12000))
114
115
116
117
118
119

/* Plain-old-data, 16-byte-aligned vector of `width` FP16/BF16 values used by
   the vectorized specialization of fused_add_rms_norm_kernel. The 16-byte
   alignment lets whole-vector loads/stores use 128-bit global memory ops
   (e.g. width 8 of a 2-byte type is exactly 16 bytes). Only the operations
   that kernel needs are provided.
   Even widths are processed two lanes at a time through the packed type
   (T2); width 1 uses per-element operations. */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* A power of two is not strictly required for correctness, but it is what
     the optimized path is built around, so other widths are rejected. */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  // Element-wise addition, performed in the half-precision type.
  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& rhs) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) data[k] += rhs.data[k];
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        T2 acc{data[k], data[k + 1]};
        acc += T2{rhs.data[k], rhs.data[k + 1]};
        data[k] = acc.x;
        data[k + 1] = acc.y;
      }
    }
    return *this;
  }

  // Element-wise multiplication, performed in the half-precision type.
  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& rhs) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) data[k] *= rhs.data[k];
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        T2 acc{data[k], data[k + 1]};
        acc *= T2{rhs.data[k], rhs.data[k + 1]};
        data[k] = acc.x;
        data[k + 1] = acc.y;
      }
    }
    return *this;
  }

  // Scale every element by a float: widen to float, multiply, narrow back.
  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) {
        const float scaled = Converter::convert(data[k]) * scale;
        data[k] = Converter::convert(scaled);
      }
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        float2 wide = Converter::convert(T2{data[k], data[k + 1]});
        wide.x *= scale;
        wide.y *= scale;
        const T2 narrow = Converter::convert(wide);
        data[k] = narrow.x;
        data[k + 1] = narrow.y;
      }
    }
    return *this;
  }

  // Sum of squares of all elements, accumulated in float.
  __device__ float sum_squares() const {
    float total = 0.0f;
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) {
        const float v = Converter::convert(data[k]);
        total += v * v;
      }
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        const float2 v = Converter::convert(T2{data[k], data[k + 1]});
        total += v.x * v.x + v.y * v.y;
      }
    }
    return total;
  }
};

/* Overload of fused_add_rms_norm_kernel for FP16/BF16 tensors, enabled when
   width > 0 and _typeConvert<scalar_t> exists. On top of the generic version
   it uses packed and vectorized (width-element) operations, which help with
   the memory latency bottleneck.
   Per token row (one block per row), in place:
     residual <- input + residual                 (added in half precision)
     input    <- residual * rsqrt(sum(residual^2) / hidden_size + epsilon)
                 * weight
   Preconditions: hidden_size is a multiple of `width`, and input, residual
   and weight are aligned for _f16Vec accesses (the host launcher only selects
   width 8 when all pointers are 16-byte aligned and hidden_size % 8 == 0).
   blockDim.x must be <= 1024 (BlockReduce size). `num_tokens` is not read.
   Row offsets are computed in 64-bit to avoid 32-bit index overflow. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  using Vec = _f16Vec<scalar_t, width>;
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<Vec>);
  static_assert(sizeof(Vec) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  // Offset of this block's row, in vectors. 64-bit so the product cannot wrap.
  const int64_t row_offset =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v = reinterpret_cast<Vec*>(input);
  auto* __restrict__ residual_v = reinterpret_cast<Vec*>(residual);
  auto* __restrict__ weight_v = reinterpret_cast<const Vec*>(weight);

  // Pass 1: residual += input, accumulating the sum of squares of the result.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    Vec temp = input_v[id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  // The block-wide aggregate returned by Reduce is only valid in thread 0.
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: input = normalized residual * weight.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    Vec temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[id] = temp;
  }
}

/* Generic fused_add_rms_norm_kernel, enabled when width == 0 or when no
   _typeConvert specialization exists for scalar_t. The width template
   parameter is not used here but is needed to select between overloads.
   Per token row (one block per row), in place:
     residual <- input + residual                 (added in scalar_t)
     input    <- (scalar_t)(residual * rsqrt(sum(residual^2) / hidden_size
                                             + epsilon)) * weight
   blockDim.x must be <= 1024 (BlockReduce size). `num_tokens` is not read.
   Row offsets are computed in 64-bit to avoid 32-bit index overflow on
   tensors with more than INT_MAX elements. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
  // 64-bit offset of this block's row; blockIdx.x * hidden_size in 32-bit
  // arithmetic would wrap for large batches.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: residual += input, accumulating the sum of squares of the result.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    scalar_t z = input[id];
    z += residual[id];
    const float x = (float)z;
    variance += x * x;
    residual[id] = z;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  // The block-wide aggregate returned by Reduce is only valid in thread 0.
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: input = normalized residual * weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const int64_t id = row_offset + idx;
    const float x = (float)residual[id];
    input[id] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

294
}  // namespace vllm
295

296
297
298
/* Host launcher: out = RMSNorm(input) * weight over the last dimension.
   Launches vllm::rms_norm_kernel with one block per row and up to 1024
   threads per block on the current CUDA stream, under a device guard for
   input's device. Empty inputs are a no-op. */
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  // The kernel addresses rows with a fixed stride of hidden_size elements, so
  // strided views would silently produce wrong results.
  TORCH_CHECK(input.is_contiguous(), "rms_norm: input must be contiguous");
  TORCH_CHECK(out.is_contiguous(), "rms_norm: out must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "rms_norm: weight must be contiguous");

  int hidden_size = input.size(-1);
  // Nothing to normalize. Returning here also avoids an integer division by
  // zero when hidden_size == 0 and a launch with a 0-sized grid, which is an
  // invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  // Capped at 1024 to match the kernel's BlockReduce size.
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}
313

314
315
316
317
318
319
320
321
322
323
324
325
326
/* Dispatches on input's floating-point dtype and launches
   vllm::fused_add_rms_norm_kernel<scalar_t, width>. The expansion site must
   already define `input`, `residual`, `weight`, `epsilon`, `num_tokens`,
   `hidden_size`, `grid`, `block` and `stream`. width == 0 always selects the
   generic kernel overload. */
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                   \
  VLLM_DISPATCH_FLOATING_TYPES(                                            \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {              \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                   \
            <<<grid, block, 0, stream>>>(                                  \
                input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(), \
                weight.data_ptr<scalar_t>(), epsilon, num_tokens,          \
                hidden_size);                                              \
      });

/* Host launcher for the fused residual-add + RMSNorm, in place:
     residual <- input + residual
     input    <- RMSNorm(residual) * weight
   Launches one block per row on the current CUDA stream, under a device guard
   for input's device. Uses the width-8 vectorized kernel when all three
   pointers are 16-byte aligned and hidden_size % 8 == 0; otherwise it uses
   the generic kernel. Empty inputs are a no-op. */
void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  // The kernels address rows with a fixed stride of hidden_size elements, so
  // strided views would silently produce wrong results.
  TORCH_CHECK(input.is_contiguous(),
              "fused_add_rms_norm: input must be contiguous");
  TORCH_CHECK(residual.is_contiguous(),
              "fused_add_rms_norm: residual must be contiguous");
  TORCH_CHECK(weight.is_contiguous(),
              "fused_add_rms_norm: weight must be contiguous");

  int hidden_size = input.size(-1);
  // Nothing to do. Returning here also avoids an integer division by zero
  // when hidden_size == 0 and a launch with a 0-sized grid, which is an
  // invalid launch configuration.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes. With hidden_size % 8 == 0, every row start then stays aligned.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}