layernorm_kernels.cu 18.2 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
zhangshao's avatar
zhangshao committed
4
5
6
7
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
8
#include "dispatch_utils.h"
Woosuk Kwon's avatar
Woosuk Kwon committed
9
#include "reduction_utils.cuh"
10
11
12
13
14
15
16
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

17
18
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
19
#endif
20

Woosuk Kwon's avatar
Woosuk Kwon committed
21
namespace vllm {
22
23

// TODO(woosuk): Further optimize this kernel.
// RMSNorm, one block per row (token): blockIdx.x selects the row and the
// block's threads stride over its hidden_size elements.
//   out[row, i] = (scalar_t)(x[i] * rsqrt(sum(x^2) / hidden_size + epsilon))
//                 * weight[i]
// The sum of squares is accumulated in float and combined across the block
// with blockReduceSum (defined elsewhere); thread 0 publishes the inverse
// RMS through shared memory.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offset in 64 bits: `blockIdx.x * hidden_size` would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for tensors with more
  // than 2^32 elements.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;
  const scalar_t* in_row = input + row_offset;
  scalar_t* out_row = out + row_offset;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)in_row[idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Make s_variance visible to every thread before the second pass.
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)in_row[idx];
    out_row[idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

50
51
52
53
54
55
56
57
/* _typeConvert<torch_type> maps a torch scalar type onto its HIP/CUDA device
   type(s) and supplies the float conversions the vectorized kernels need.
   These helpers have to be hand-written for now: the relevant conversion
   operators/constructors are not implemented consistently by HIP/CUDA, so a
   generic cast-based conversion is not possible.

   Every specialization declares `static constexpr bool exists`:
     - false -> the optimized kernel is not used for that torch type;
     - true  -> the specialization must also provide `hip_type`,
                `packed_hip_type` and the four `convert` overloads
                (scalar and packed, each direction), as in the
                specializations below.
 */
template <typename torch_type>
struct _typeConvert {
  // Default: no device counterpart, so the vectorized path is disabled.
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  // Scalar: half <-> float.
  __device__ static inline float convert(hip_type h) {
    return __half2float(h);
  }
  __device__ static inline hip_type convert(float f) {
    return __float2half_rn(f);
  }

  // Packed (2 lanes): half2 <-> float2.
  __device__ static inline float2 convert(packed_hip_type h2) {
    return __half22float2(h2);
  }
  __device__ static inline packed_hip_type convert(float2 f2) {
    return __float22half2_rn(f2);
  }
};

  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
// NOTE(review): __CUDA_ARCH__ is undefined during the host compilation pass,
// so the host pass sees exists == false for BFloat16 while sm_80+ device
// passes see true — confirm the host-selected kernel overload matches the
// one compiled for the device.
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  // Scalar: bfloat16 <-> float.
  __device__ static inline float convert(hip_type b) {
    return __bfloat162float(b);
  }
  __device__ static inline hip_type convert(float f) {
    return __float2bfloat16(f);
  }

  // Packed (2 lanes): bfloat162 <-> float2.
  __device__ static inline float2 convert(packed_hip_type b2) {
    return __bfloat1622float2(b2);
  }
  __device__ static inline packed_hip_type convert(float2 f2) {
    return __float22bfloat162_rn(f2);
  }
};
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif    // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
          // 12000))
110
111
112
113
114
115

/* Fixed-width POD vector of FP16/BF16 device values. It lets
   fused_add_rms_norm_kernel use packed (2-lane) arithmetic and wide memory
   ops, and implements only the operations that kernel needs.
   alignas(16) allows a 16-byte vector (e.g. width 8 of a 2-byte type) to be
   moved with a single 128-bit global memory op. It also rounds sizeof up to
   a multiple of 16 bytes.
 */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* A power-of-2 width is not theoretically required, but it is the shape
     the optimized path is designed around. */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;         // one device scalar
  using T2 = typename Converter::packed_hip_type;  // two packed scalars
  T1 data[width];

  // Element-wise add; even widths go through the packed 2-lane add.
  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) data[k] += other.data[k];
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        T2 pair{data[k], data[k + 1]};
        pair += T2{other.data[k], other.data[k + 1]};
        data[k] = pair.x;
        data[k + 1] = pair.y;
      }
    }
    return *this;
  }

  // Element-wise multiply; even widths go through the packed 2-lane multiply.
  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) data[k] *= other.data[k];
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        T2 pair{data[k], data[k + 1]};
        pair *= T2{other.data[k], other.data[k + 1]};
        data[k] = pair.x;
        data[k + 1] = pair.y;
      }
    }
    return *this;
  }

  // Scale every element by a float. The product is computed in float and
  // rounded back to the device type.
  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) {
        const float scaled = Converter::convert(data[k]) * scale;
        data[k] = Converter::convert(scaled);
      }
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        float2 wide = Converter::convert(T2{data[k], data[k + 1]});
        wide.x *= scale;
        wide.y *= scale;
        const T2 narrow = Converter::convert(wide);
        data[k] = narrow.x;
        data[k + 1] = narrow.y;
      }
    }
    return *this;
  }

  // Sum of squares of all elements, accumulated in float.
  __device__ float sum_squares() const {
    float acc = 0.0f;
    if constexpr (width % 2 != 0) {
#pragma unroll
      for (int k = 0; k < width; ++k) {
        const float v = Converter::convert(data[k]);
        acc += v * v;
      }
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        const float2 v = Converter::convert(T2{data[k], data[k + 1]});
        acc += v.x * v.x + v.y * v.y;
      }
    }
    return acc;
  }
};

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   One block per row (token); the block's threads stride over the row in
   units of _f16Vec<scalar_t, width>:
     residual <- input + residual
     input    <- residual * rsqrt(sum(residual^2) / hidden_size + epsilon)
                 * weight
   The host-side fused_add_rms_norm only launches this with width 8 after
   checking that hidden_size % 8 == 0 and all three pointers are 16-byte
   aligned. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  using VecT = _f16Vec<scalar_t, width>;
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<VecT>);
  static_assert(sizeof(VecT) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  // Row offset (in vectors) computed in 64 bits. The previous
  // `int id = blockIdx.x * vec_hidden_size + idx` overflowed int for very
  // large inputs.
  const int64_t row_offset =
      static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v = reinterpret_cast<VecT*>(input) + row_offset;
  auto* __restrict__ residual_v =
      reinterpret_cast<VecT*>(residual) + row_offset;
  auto* __restrict__ weight_v = reinterpret_cast<const VecT*>(weight);

  // Pass 1: residual <- input + residual and accumulate sum of squares.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    VecT temp = input_v[idx];
    temp += residual_v[idx];
    variance += temp.sum_squares();
    residual_v[idx] = temp;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else {
    variance = blockReduceSum<float, 256>(variance);
  }
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Publish s_variance to the whole block before pass 2.
  __syncthreads();

  // Pass 2: input <- normalized residual * weight.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    VecT temp = residual_v[idx];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[idx] = temp;
  }
}

/* Generic fused_add_rms_norm_kernel, selected when width == 0 or scalar_t
   has no _typeConvert specialization. The width field is not used here but
   is necessary for other specializations.

   One block per row (token); threads stride over hidden_size:
     residual <- input + residual           (added in scalar_t)
     input    <- (scalar_t)(residual * rsqrt(sum(residual^2) / hidden_size
                                              + epsilon)) * weight
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offset in 64 bits: `blockIdx.x * hidden_size` would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap for tensors with more
  // than 2^32 elements.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;
  scalar_t* input_row = input + row_offset;
  scalar_t* residual_row = residual + row_offset;

  // Pass 1: write the summed residual back and accumulate sum of squares.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input_row[idx];
    z += residual_row[idx];
    float x = (float)z;
    variance += x * x;
    residual_row[idx] = z;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else {
    variance = blockReduceSum<float, 256>(variance);
  }
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Publish s_variance to the whole block before pass 2.
  __syncthreads();

  // Pass 2: input <- normalized residual * weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual_row[idx];
    input_row[idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

292
}  // namespace vllm
293

zhangshao's avatar
zhangshao committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
// Tree sum across a warp using shuffle-down. After log2(reducesize) halving
// steps, lane 0 holds the sum of `val` over lanes [0, reducesize).
// reducesize is expected to be a power of two no larger than the warp size.
// WARP_SHFL_DOWN comes from PyTorch (THCDeviceUtils.cuh). It is presumably a
// full-warp-mask shuffle, so every lane of the warp should call this — verify.
template <typename T, int reducesize = C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
  for (int delta = reducesize >> 1; delta >= 1; delta >>= 1) {
    const T neighbor = WARP_SHFL_DOWN(val, delta);
    val += neighbor;
  }
  return val;
}

// Block-wide sum of `val` for a block of at most block_size threads.
//   Stage 1: every warp reduces its own lanes with WarpReduceSum_NEW.
//   Stage 2: lane 0 of each warp stores its partial in `shared`; warp 0 then
//            reduces the partials.
// The complete sum is only guaranteed in thread 0 (warp 0, lane 0).
// `shared` must hold at least block_size / C10_WARP_SIZE elements.
// Every thread of the block must call this: it contains __syncthreads() and
// warp shuffles.
template <typename T, int block_size = 512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
  constexpr int share_size = block_size / C10_WARP_SIZE;
  val = WarpReduceSum_NEW<T>(val);
  if constexpr (block_size == C10_WARP_SIZE) {
    return val;
  } else {
    const int lid = threadIdx.x % C10_WARP_SIZE;
    const int wid = threadIdx.x / C10_WARP_SIZE;
    // Make sure earlier readers of `shared` are done before overwriting it.
    __syncthreads();
    if (lid == 0 && wid < share_size) {
      shared[wid] = val;
    }
    __syncthreads();
    if (wid == 0) {
      // Only warps that exist in this launch wrote a partial. Read 0 for the
      // remaining slots instead of uninitialized shared memory.
      const int num_warps =
          static_cast<int>((blockDim.x + C10_WARP_SIZE - 1) / C10_WARP_SIZE);
      const int num_partials = num_warps < share_size ? num_warps : share_size;
      const T partial = (lid < num_partials) ? shared[lid] : static_cast<T>(0);
      // All lanes of warp 0 execute the shuffles together. WARP_SHFL_DOWN is
      // assumed to use a full-warp mask, so restricting this to
      // lid < share_size was undefined whenever share_size < warp size. Lanes
      // >= share_size carry 0 and do not feed lane 0's result.
      val = WarpReduceSum_NEW<T, share_size>(partial);
    }
    return val;
  }
}

// Fused residual-add + RMSNorm, one row (token) per block (eval fast path).
// Thread j (< cols / Vec) owns elements [j*Vec, (j+1)*Vec) of row blockIdx.x:
//   residual <- residual + input                  (added in scalar_t)
//   input    <- residual * rsqrt(sum(residual^2) / cols + eps) * gamma
// Launch with blockDim.x == block_size. Requires:
//   - cols % Vec == 0;
//   - cols / Vec <= block_size (there is no stride loop);
//   - input and residual aligned for aligned_vector<scalar_t, Vec>.
template <typename scalar_t, typename T_ACC, int Vec = 4, int block_size = 512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input, scalar_t* residual,
                                          scalar_t* gamma, int cols,
                                          T_ACC eps) {
  constexpr int share_size = block_size / C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  // Threads past the end of the row must NOT return early.
  // BlockReduceSum_NEW contains __syncthreads() and warp shuffles that every
  // thread has to reach. A fully exited warp would also leave its val_shared
  // slot unwritten, so the block sum would include uninitialized memory
  // (e.g. cols = 3072 with Vec = 4 uses only 768 of 1024 threads).
  const bool active = j < tcol;
  // 64-bit element offset so num_tokens * cols cannot overflow int.
  const int64_t idx = (static_cast<int64_t>(blockIdx.x) * tcol + j) * Vec;

  LoadT input_vec;
  LoadT residual_vec;
  T_ACC val = 0;
  if (active) {
    input_vec = *reinterpret_cast<const LoadT*>(input + idx);
    residual_vec = *reinterpret_cast<const LoadT*>(residual + idx);
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      residual_vec.val[ii] += input_vec.val[ii];
      val += static_cast<T_ACC>(residual_vec.val[ii]) *
             static_cast<T_ACC>(residual_vec.val[ii]);
    }
  }
  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  // Only thread 0 is guaranteed to hold the full block sum.
  if (j == 0) s_rstd = c10::cuda::compat::rsqrt(val / cols + eps);
  __syncthreads();
  if (!active) return;

  const T_ACC trstd = s_rstd;
#pragma unroll
  for (int ii = 0; ii < Vec; ii++) {
    const int jj = j * Vec + ii;
    input_vec.val[ii] = static_cast<scalar_t>(
        static_cast<T_ACC>(residual_vec.val[ii]) * trstd *
        static_cast<T_ACC>(gamma[jj]));
  }
  *reinterpret_cast<LoadT*>(residual + idx) = residual_vec;
  *reinterpret_cast<LoadT*>(input + idx) = input_vec;
}

// RMSNorm, one row (token) per block (eval fast path).
// Thread j (< cols / Vec) owns elements [j*Vec, (j+1)*Vec) of row blockIdx.x:
//   output <- input * rsqrt(sum(input^2) / cols + eps) * gamma
// Launch with blockDim.x == block_size. Requires:
//   - cols % Vec == 0;
//   - cols / Vec <= block_size (there is no stride loop);
//   - input and output aligned for aligned_vector<scalar_t, Vec>.
template <typename scalar_t, typename T_ACC, int Vec = 4, int block_size = 512>
__global__ void fused_rms_kernel_eval(scalar_t* input, scalar_t* output,
                                      scalar_t* gamma, int cols, T_ACC eps) {
  constexpr int share_size = block_size / C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  // Threads past the end of the row must NOT return early.
  // BlockReduceSum_NEW contains __syncthreads() and warp shuffles that every
  // thread has to reach. A fully exited warp would also leave its val_shared
  // slot unwritten, so the block sum would include uninitialized memory.
  const bool active = j < tcol;
  // 64-bit element offset so num_tokens * cols cannot overflow int.
  const int64_t idx = (static_cast<int64_t>(blockIdx.x) * tcol + j) * Vec;

  LoadT input_vec;
  T_ACC val = 0;
  if (active) {
    input_vec = *reinterpret_cast<const LoadT*>(input + idx);
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      val += static_cast<T_ACC>(input_vec.val[ii]) *
             static_cast<T_ACC>(input_vec.val[ii]);
    }
  }
  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  // Only thread 0 is guaranteed to hold the full block sum.
  if (j == 0) s_rstd = c10::cuda::compat::rsqrt(val / cols + eps);
  __syncthreads();
  if (!active) return;

  const T_ACC trstd = s_rstd;
  LoadT output_vec;
#pragma unroll
  for (int ii = 0; ii < Vec; ii++) {
    const int jj = j * Vec + ii;
    output_vec.val[ii] = static_cast<scalar_t>(
        static_cast<T_ACC>(input_vec.val[ii]) * trstd *
        static_cast<T_ACC>(gamma[jj]));
  }
  *reinterpret_cast<LoadT*>(output + idx) = output_vec;
}

396
397
398
// Host entry point: out = RMSNorm(input) * weight, one CUDA block per token.
//
// Fast path: fused_rms_kernel_eval with 1024 threads and Vec-wide
// aligned_vector accesses. It is used when hidden_size is a multiple of 16
// in [2048, 8192] and both `input` and `out` are aligned for the widest
// vector used (Vec <= 8; PyTorch's aligned_vector<scalar_t, Vec> is declared
// alignas(sizeof(scalar_t) * Vec)). Otherwise vllm::rms_norm_kernel is used.
// NOTE(review): rows are addressed densely through data_ptr(); presumably
// callers pass contiguous tensors — confirm.
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  // Nothing to normalize. Returning here also avoids dividing by a zero
  // hidden_size and launching a zero-block grid, which is an invalid launch
  // configuration.
  if (input.numel() == 0) {
    return;
  }
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The fast kernel dereferences input/out as aligned_vector<scalar_t, Vec>
  // with Vec up to 8. A misaligned base pointer (e.g. a sliced tensor) would
  // fault, so such inputs fall back to the generic kernel. gamma is read
  // element-wise and needs no extra alignment.
  const std::uintptr_t vec_alignment = 8 * input.element_size();
  const bool vec_ptrs_aligned =
      reinterpret_cast<std::uintptr_t>(input.data_ptr()) % vec_alignment ==
          0 &&
      reinterpret_cast<std::uintptr_t>(out.data_ptr()) % vec_alignment == 0;

  if (hidden_size % 16 == 0 && hidden_size >= 2048 && hidden_size <= 8192 &&
      vec_ptrs_aligned) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
        "rms_norm_kernel", [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          T_ACC eps = epsilon;
          scalar_t* self_data = input.data_ptr<scalar_t>();
          scalar_t* out_data = out.data_ptr<scalar_t>();
          scalar_t* weight_data = weight.data_ptr<scalar_t>();
          // Choose Vec so that hidden_size / Vec <= 1024 threads per block.
          if (hidden_size == 2048) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 2, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, out_data,
                                                  weight_data, hidden_size,
                                                  eps);
          } else if (hidden_size <= 4096) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 4, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, out_data,
                                                  weight_data, hidden_size,
                                                  eps);
          } else {
            fused_rms_kernel_eval<scalar_t, T_ACC, 8, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, out_data,
                                                  weight_data, hidden_size,
                                                  eps);
          }
        });
  } else {
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));

    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
      vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
    });
  }
}
438

439
440
441
442
443
444
445
446
447
448
// Dispatches on input.scalar_type() via VLLM_DISPATCH_FLOATING_TYPES and
// launches vllm::fused_add_rms_norm_kernel<scalar_t, width> on `stream`.
// width > 0 resolves to the vectorized overload when
// _typeConvert<scalar_t>::exists. width == 0, or a type without a
// _typeConvert specialization, resolves to the generic overload.
// Not hygienic: expects input, residual, weight, epsilon, num_tokens,
// hidden_size, grid, block and stream to be in scope where it is expanded.
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

zhangshao's avatar
zhangshao committed
449
450


451
452
453
// Host entry point for the fused residual-add + RMSNorm, in place:
//   residual <- input + residual
//   input    <- RMSNorm(residual) * weight
// Fast path: fused_add_rms_kernel_eval with 1024 threads and Vec-wide
// aligned_vector accesses. It is used when hidden_size is a multiple of 16
// in [2048, 8192] and `input` and `residual` are aligned for the widest
// vector (Vec <= 8; PyTorch's aligned_vector<scalar_t, Vec> is declared
// alignas(sizeof(scalar_t) * Vec)).
// Otherwise vllm::fused_add_rms_norm_kernel is launched. It uses the
// vectorized width-8 variant when all three pointers are 16-byte aligned and
// hidden_size % 8 == 0.
// NOTE(review): rows are addressed densely through data_ptr(); presumably
// callers pass contiguous tensors — confirm.
void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  // Nothing to normalize. Returning here also avoids dividing by a zero
  // hidden_size and launching a zero-block grid, which is an invalid launch
  // configuration.
  if (input.numel() == 0) {
    return;
  }
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());

  // The fast kernel dereferences input/residual as
  // aligned_vector<scalar_t, Vec> with Vec up to 8. Misaligned base pointers
  // (e.g. sliced tensors) must take the generic path instead of faulting.
  // gamma is read element-wise there.
  const std::uintptr_t eval_alignment = 8 * input.element_size();
  const bool eval_ptrs_aligned =
      inp_ptr % eval_alignment == 0 && res_ptr % eval_alignment == 0;

  if (hidden_size % 16 == 0 && hidden_size >= 2048 && hidden_size <= 8192 &&
      eval_ptrs_aligned) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
        "fused_add_rms_norm_kernel", [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          T_ACC eps = epsilon;
          scalar_t* self_data = input.data_ptr<scalar_t>();
          scalar_t* other_data = residual.data_ptr<scalar_t>();
          scalar_t* weight_data = weight.data_ptr<scalar_t>();
          // Choose Vec so that hidden_size / Vec <= 1024 threads per block.
          if (hidden_size == 2048) {
            fused_add_rms_kernel_eval<scalar_t, T_ACC, 2, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, other_data,
                                                  weight_data, hidden_size,
                                                  eps);
          } else if (hidden_size <= 4096) {
            fused_add_rms_kernel_eval<scalar_t, T_ACC, 4, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, other_data,
                                                  weight_data, hidden_size,
                                                  eps);
          } else {
            fused_add_rms_kernel_eval<scalar_t, T_ACC, 8, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, other_data,
                                                  weight_data, hidden_size,
                                                  eps);
          }
        });
  } else {
    dim3 grid(num_tokens);
    /* This kernel is memory-latency bound in many scenarios.
       When num_tokens is large, a smaller block size allows
       for increased block occupancy on CUs and better latency
       hiding on global mem ops. */
    const int max_block_size = (num_tokens < 256) ? 1024 : 256;
    dim3 block(std::min(hidden_size, max_block_size));
    /* If the tensor types are FP16/BF16, try to use the optimized kernel
       with packed + vectorized ops.
       Max optimization is achieved with a width-8 vector of FP16/BF16s
       since we can load at most 128 bits at once in a global memory op.
       However, this requires each tensor's data to be aligned to 16
       bytes.
     */
    bool ptrs_are_aligned =
        inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
    if (ptrs_are_aligned && hidden_size % 8 == 0) {
      LAUNCH_FUSED_ADD_RMS_NORM(8);
    } else {
      LAUNCH_FUSED_ADD_RMS_NORM(0);
    }
  }
}