layernorm_kernels.cu 19.8 KB
Newer Older
1
#include <torch/all.h>
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
zhangshao's avatar
zhangshao committed
4
5
6
7
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
8
#include "dispatch_utils.h"
Woosuk Kwon's avatar
Woosuk Kwon committed
9
#include "reduction_utils.cuh"
10
11
12
13
14
15
16
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

17
18
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
19
#endif
20
21
22
23
24
25
26
27
28
29
/* Interprets the environment variable `env_var` as an on/off switch.
   Returns false when the variable is unset or is exactly the string "0";
   any other value (including the empty string) yields true. */
static inline bool get_env_(const char *env_var) {
  const char *raw = std::getenv(env_var);
  const bool is_set = (raw != nullptr);
  return is_set && strcmp(raw, "0") != 0;
}
/* Read once during static initialization. When true, rms_norm and
   fused_add_rms_norm skip their *_kernel_eval fast paths. */
static const bool use_old = get_env_("USE_VLLM_OLD_OP");
Woosuk Kwon's avatar
Woosuk Kwon committed
30
namespace vllm {
31
32

// TODO(woosuk): Further optimize this kernel.
/* RMS normalization of one row per thread block:
     out[row, i] = (scalar_t)(x * rsqrt(mean(x^2) + epsilon)) * weight[i]
   with x = (float)input[row, i] and the mean taken over hidden_size.

   Launch layout: blockIdx.x selects the row; the threads of the block stride
   over the hidden_size columns of that row. `num_tokens` is not read here.
   blockReduceSum comes from reduction_utils.cuh; only thread 0 consumes its
   result, so it is assumed to deliver the block-wide sum at least there. */
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offset in 64 bits: blockIdx.x * hidden_size is otherwise evaluated in
  // 32-bit arithmetic and wraps once the tensor holds 2^32 or more elements.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: per-thread partial sum of squares over the row.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[row_offset + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Publish s_variance to every thread of the block before it is read.
  __syncthreads();

  // Pass 2: scale, round to scalar_t, then apply the per-column weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    out[row_offset + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

59
60
61
62
63
64
65
66
/* Conversion traits between torch scalar types and the native HIP/CUDA
   half-precision types, plus the float <-> native conversions used by the
   vectorized kernels. They are spelled out by hand because HIP and CUDA do
   not implement the relevant conversion operators/constructors consistently,
   so a generic cast-based conversion is not possible.

   Every instantiation exposes `static constexpr bool exists`:
     - false: the optimized kernel is not used for that torch type;
     - true:  the struct must be fully defined, as in the specializations
              that follow.
   This primary template is the "not supported" default. */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists{false};
};
73

74
75
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
76
/* c10::Half support: maps to __half / __half2 and forwards each conversion
   to the matching half-precision intrinsic. */
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  // half -> float
  __device__ static inline float convert(hip_type h) {
    return __half2float(h);
  }
  // half2 -> float2 (both lanes at once)
  __device__ static inline float2 convert(packed_hip_type h2) {
    return __half22float2(h2);
  }
  // float -> half via the `_rn` intrinsic
  __device__ static inline hip_type convert(float f) {
    return __float2half_rn(f);
  }
  // float2 -> half2 via the `_rn` intrinsic
  __device__ static inline packed_hip_type convert(float2 f2) {
    return __float22half2_rn(f2);
  }
};

94
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
95
96
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
97
/* c10::BFloat16 support: maps to __nv_bfloat16 / __nv_bfloat162 and forwards
   each conversion to the matching bfloat16 intrinsic. */
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  // bfloat16 -> float
  __device__ static inline float convert(hip_type b) {
    return __bfloat162float(b);
  }
  // bfloat162 -> float2 (both lanes at once)
  __device__ static inline float2 convert(packed_hip_type b2) {
    return __bfloat1622float2(b2);
  }
  // float -> bfloat16
  __device__ static inline hip_type convert(float f) {
    return __float2bfloat16(f);
  }
  // float2 -> bfloat162 via the `_rn` intrinsic
  __device__ static inline packed_hip_type convert(float2 f2) {
    return __float22bfloat162_rn(f2);
  }
};
116
117
118
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif    // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
          // 12000))
119
120
121
122
123
124

/* Fixed-width POD vector of FP16/BF16 values used by the specialized
   fused_add_rms_norm_kernel to issue vectorized loads/stores and packed
   arithmetic. Only the operations that kernel needs are provided.
   The 16-byte alignment is what allows 128-bit global memory operations.

   For even widths every operation works on packed pairs (T2); an odd width
   (only width == 1 passes the static_assert) falls back to element-wise
   code. */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* A power-of-two width is not strictly required, but anything else is
     almost never what you want for performance. */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  // Element-wise addition in the native 16-bit type.
  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    constexpr bool kScalarPath = (width % 2 != 0);
    if constexpr (kScalarPath) {
#pragma unroll
      for (int k = 0; k < width; ++k) data[k] += other.data[k];
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        T2 lhs{data[k], data[k + 1]};
        lhs += T2{other.data[k], other.data[k + 1]};
        data[k] = lhs.x;
        data[k + 1] = lhs.y;
      }
    }
    return *this;
  }

  // Element-wise multiplication in the native 16-bit type.
  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    constexpr bool kScalarPath = (width % 2 != 0);
    if constexpr (kScalarPath) {
#pragma unroll
      for (int k = 0; k < width; ++k) data[k] *= other.data[k];
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        T2 lhs{data[k], data[k + 1]};
        lhs *= T2{other.data[k], other.data[k + 1]};
        data[k] = lhs.x;
        data[k + 1] = lhs.y;
      }
    }
    return *this;
  }

  // Scales every element by a float: widen to float, multiply, narrow back.
  __device__ _f16Vec& operator*=(const float scale) {
    constexpr bool kScalarPath = (width % 2 != 0);
    if constexpr (kScalarPath) {
#pragma unroll
      for (int k = 0; k < width; ++k) {
        float scaled = Converter::convert(data[k]) * scale;
        data[k] = Converter::convert(scaled);
      }
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        float2 wide = Converter::convert(T2{data[k], data[k + 1]});
        wide.x *= scale;
        wide.y *= scale;
        T2 narrow = Converter::convert(wide);
        data[k] = narrow.x;
        data[k + 1] = narrow.y;
      }
    }
    return *this;
  }

  // Sum of squares of all elements, accumulated in float.
  __device__ float sum_squares() const {
    constexpr bool kScalarPath = (width % 2 != 0);
    float result = 0.0f;
    if constexpr (kScalarPath) {
#pragma unroll
      for (int k = 0; k < width; ++k) {
        float x = Converter::convert(data[k]);
        result += x * x;
      }
    } else {
#pragma unroll
      for (int k = 0; k < width; k += 2) {
        float2 z = Converter::convert(T2{data[k], data[k + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    }
    return result;
  }
};

/* FP16/BF16 specialization of fused_add_rms_norm_kernel, selected when
   width > 0 and _typeConvert<scalar_t>::exists. It uses packed, vectorized
   operations to ease the memory-latency bottleneck.

   Per row (one thread block per row, blockIdx.x selects the row):
     residual <- input + residual
     input    <- residual * rsqrt(mean(residual^2) + epsilon) * weight

   Preconditions:
     - The pointers are reinterpreted as _f16Vec<scalar_t, width>*, which is
       16-byte aligned, so input, residual and weight must be aligned
       accordingly.
     - hidden_size should be a multiple of width: vec_hidden_size uses integer
       division, so a remainder would be left unprocessed.
     - The blockReduceSum width is chosen from num_tokens and must match the
       max_block_size computed in fused_add_rms_norm. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  // 64-bit row base: blockIdx.x * vec_hidden_size stored in an int overflows
  // for sufficiently large tensors.
  const int64_t row_base = static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  // Pass 1: residual += input, accumulating the sum of squares of the result.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_base + idx;
    _f16Vec<scalar_t, width> temp = input_v[id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Publish s_variance to every thread of the block before it is read.
  __syncthreads();

  // Pass 2: normalize the updated residual, apply weight, store into input.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    const int64_t id = row_base + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[id] = temp;
  }
}

/* Generic fused_add_rms_norm_kernel, selected when width == 0 or the scalar
   type has no _typeConvert support. The width field is not used here but is
   necessary for the other specialization.

   Per row (one thread block per row, blockIdx.x selects the row):
     residual <- input + residual                       (in scalar_t)
     input    <- (scalar_t)(residual * rsqrt(mean(residual^2) + epsilon))
                 * weight

   The blockReduceSum width is chosen from num_tokens and must match the
   max_block_size computed in fused_add_rms_norm. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Row offset in 64 bits: blockIdx.x * hidden_size is otherwise evaluated in
  // 32-bit arithmetic and wraps once the tensor holds 2^32 or more elements.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: residual += input, accumulating the sum of squares of the result.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[row_offset + idx];
    z += residual[row_offset + idx];
    float x = (float)z;
    variance += x * x;
    residual[row_offset + idx] = z;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  // Publish s_variance to every thread of the block before it is read.
  __syncthreads();

  // Pass 2: normalize the updated residual, apply weight, store into input.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[row_offset + idx];
    input[row_offset + idx] = ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

301
}  // namespace vllm
302

zhangshao's avatar
zhangshao committed
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
/* Shuffle-down tree reduction over `reducesize` consecutive lanes.
   The offsets used are reducesize/2, reducesize/4, ..., 1, so the first lane
   of the group ends up with the sum of the group's values; the other lanes
   hold partial results. WARP_SHFL_DOWN is not defined in this file
   (presumably it comes from THC/THCDeviceUtils.cuh). */
template <typename T,int reducesize=C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
  for (int delta = reducesize >> 1; delta >= 1; delta /= 2) {
    val += WARP_SHFL_DOWN(val, delta);
  }
  return val;
}

/* Block-wide sum in two stages: a shuffle reduction inside every warp, then a
   second shuffle reduction over the per-warp totals staged in `shared`.

   val:    this thread's contribution.
   shared: scratch buffer with at least block_size / C10_WARP_SIZE elements
           (callers in this file pass a __shared__ array of that size).
   Returns the block total in thread 0; other threads receive partial values
   and must not rely on them.

   Must be called by every thread of the block (it contains __syncthreads()),
   and the block is expected to be launched with block_size threads. */
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
  constexpr int share_size = block_size / C10_WARP_SIZE;
  val = WarpReduceSum_NEW<T>(val);
  if constexpr (block_size == C10_WARP_SIZE) {
    // Single-warp block: the warp reduction already is the block reduction.
    return val;
  } else {
    const int lid = threadIdx.x % C10_WARP_SIZE;
    const int wid = threadIdx.x / C10_WARP_SIZE;
    if (lid == 0 && wid < share_size) {
      shared[wid] = val;
    }
    __syncthreads();
    if (wid == 0) {
      // The branch is uniform across warp 0, so every lane of the warp takes
      // part in the shuffles below. Running them under `lid < share_size`
      // would leave lanes out of a shuffle that (if WARP_SHFL_DOWN maps to
      // __shfl_down_sync with a full mask) names them as participants.
      // Lanes without a per-warp total contribute zero.
      T partial = (lid < share_size) ? shared[lid] : T(0);
      partial = WarpReduceSum_NEW<T, share_size>(partial);
      if (lid < share_size) {
        val = partial;
      }
    }
    return val;
  }
}

/* Fused residual-add + RMS norm, one row per block, one Vec-wide vector per
   thread (there is no loop over columns).

   Per row (blockIdx.x):
     residual <- residual + input                          (in scalar_t)
     input    <- residual * rsqrt(mean(residual^2) + eps) * gamma
   with the normalization evaluated in T_ACC.

   Preconditions (the callers in this file arrange them):
     - launched with blockDim.x == block_size;
     - cols is a multiple of Vec and cols <= block_size * Vec, otherwise the
       trailing columns are not processed;
     - input and residual rows are suitably aligned for LoadT accesses.

   Threads with threadIdx.x >= cols / Vec touch no global memory but still
   take part in the block reduction and barriers. */
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
  constexpr int share_size = block_size / C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int i = blockIdx.x;
  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  const bool active = j < tcol;
  // Element offset of this thread's vector, in 64 bits so that large tensors
  // do not overflow a 32-bit int.
  const int64_t idx = (static_cast<int64_t>(i) * tcol + j) * Vec;

  // Keep the staged values in LoadT objects so the vector copies operate on
  // correctly aligned storage instead of punning plain scalar arrays.
  LoadT input_vec;
  LoadT residual_vec;
  T_ACC val = 0;

  // Only threads that own a vector may load: an unguarded load would read
  // past the row (and past the tensor for the last rows) whenever the block
  // has more threads than the row has vectors.
  if (active) {
    input_vec = *reinterpret_cast<const LoadT*>(input + idx);
    residual_vec = *reinterpret_cast<const LoadT*>(residual + idx);
    #pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      residual_vec.val[ii] += input_vec.val[ii];
      val += static_cast<T_ACC>(residual_vec.val[ii]) * static_cast<T_ACC>(residual_vec.val[ii]);
    }
  }

  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  if (j == 0) s_rstd = c10::cuda::compat::rsqrt(val / cols + eps);
  // Publish s_rstd to every thread of the block before it is read.
  __syncthreads();
  const T_ACC trstd = s_rstd;

  if (active) {
    #pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      const int jj = j * Vec + ii;
      input_vec.val[ii] = static_cast<T_ACC>(residual_vec.val[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
    }
    *reinterpret_cast<LoadT*>(residual + idx) = residual_vec;
    *reinterpret_cast<LoadT*>(input + idx) = input_vec;
  }
}

/* RMS norm, one row per block, one Vec-wide vector per thread (there is no
   loop over columns).

   Per row (blockIdx.x):
     output <- input * rsqrt(mean(input^2) + eps) * gamma
   with the normalization evaluated in T_ACC.

   Preconditions (the callers in this file arrange them):
     - launched with blockDim.x == block_size;
     - cols is a multiple of Vec and cols <= block_size * Vec, otherwise the
       trailing columns are not processed;
     - input and output rows are suitably aligned for LoadT accesses.

   Threads with threadIdx.x >= cols / Vec touch no global memory but still
   take part in the block reduction and barriers. */
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
  constexpr int share_size = block_size / C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int i = blockIdx.x;
  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  const bool active = j < tcol;
  // Element offset of this thread's vector, in 64 bits so that large tensors
  // do not overflow a 32-bit int.
  const int64_t idx = (static_cast<int64_t>(i) * tcol + j) * Vec;

  // Keep the staged values in a LoadT object so the vector copies operate on
  // correctly aligned storage instead of punning a plain scalar array.
  LoadT input_vec;
  T_ACC val = 0;

  // Only threads that own a vector may load: an unguarded load would read
  // past the row (and past the tensor for the last rows) whenever the block
  // has more threads than the row has vectors.
  if (active) {
    input_vec = *reinterpret_cast<const LoadT*>(input + idx);
    #pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      val += static_cast<T_ACC>(input_vec.val[ii]) * static_cast<T_ACC>(input_vec.val[ii]);
    }
  }

  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  if (j == 0) s_rstd = c10::cuda::compat::rsqrt(val / cols + eps);
  // Publish s_rstd to every thread of the block before it is read.
  __syncthreads();
  const T_ACC trstd = s_rstd;

  if (active) {
    #pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      const int jj = j * Vec + ii;
      input_vec.val[ii] = static_cast<T_ACC>(input_vec.val[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
    }
    *reinterpret_cast<LoadT*>(output + idx) = input_vec;
  }
}

410
411
412
/* Host entry point for RMS normalization: out = rms_norm(input) * weight,
   computed independently for every row of hidden_size elements.

   Two implementations are available:
     - fused_rms_kernel_eval (vectorized, one vector per thread), used when
       USE_VLLM_OLD_OP is not enabled, hidden_size is a multiple of 16 and at
       most 16384, and input, out and weight are all 16-byte aligned;
     - vllm::rms_norm_kernel otherwise.
   For the fast path each (Vec, block_size) pair below satisfies
   hidden_size <= Vec * block_size, which the kernel requires because every
   thread handles exactly one vector.

   Empty tensors are a no-op: this avoids dividing by a zero hidden_size and
   launching a kernel with a zero-sized grid. */
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  int hidden_size = input.size(-1);
  if (hidden_size == 0 || input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // The fast path performs vector loads from `input` and vector stores to
  // `out`, so both must be aligned (the original check omitted `out`).
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && out_ptr % 16 == 0 && wt_ptr % 16 == 0;

  if (!use_old && hidden_size % 16 == 0 && hidden_size <= 16384 &&
      ptrs_are_aligned) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
        "rms_norm_kernel", [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          T_ACC eps = epsilon;
          scalar_t* self_data = input.data_ptr<scalar_t>();
          scalar_t* out_data = out.data_ptr<scalar_t>();
          scalar_t* weight_data = weight.data_ptr<scalar_t>();
          if (hidden_size <= 1024) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 8, 128>
                <<<num_tokens, 128, 0, stream>>>(self_data, out_data,
                                                 weight_data, hidden_size, eps);
          } else if (hidden_size <= 2048) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 8, 256>
                <<<num_tokens, 256, 0, stream>>>(self_data, out_data,
                                                 weight_data, hidden_size, eps);
          } else if (hidden_size <= 4096) {
            // Two tunings for this size range, chosen by the number of rows.
            if (num_tokens > 1200) {
              fused_rms_kernel_eval<scalar_t, T_ACC, 8, 512>
                  <<<num_tokens, 512, 0, stream>>>(
                      self_data, out_data, weight_data, hidden_size, eps);
            } else {
              fused_rms_kernel_eval<scalar_t, T_ACC, 4, 1024>
                  <<<num_tokens, 1024, 0, stream>>>(
                      self_data, out_data, weight_data, hidden_size, eps);
            }
          } else if (hidden_size <= 8192) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 8, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, out_data,
                                                  weight_data, hidden_size, eps);
          } else {
            fused_rms_kernel_eval<scalar_t, T_ACC, 16, 1024>
                <<<num_tokens, 1024, 0, stream>>>(self_data, out_data,
                                                  weight_data, hidden_size, eps);
          }
        });
  } else {
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));

    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
      vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
    });
  }
}
466

467
468
469
470
471
472
473
474
475
476
/* Dispatches on input.scalar_type() and launches
   vllm::fused_add_rms_norm_kernel<scalar_t, width>.
   `width` is the vector width template argument; 0 selects the generic
   kernel. The macro is not hygienic: it expects the expansion site to have
   `input`, `residual`, `weight`, `epsilon`, `num_tokens`, `hidden_size`,
   `grid`, `block` and `stream` in scope under exactly these names. */
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

zhangshao's avatar
zhangshao committed
477
478


479
480
481
/* Host entry point for the fused residual-add + RMS norm. In place, per row:
     residual <- input + residual
     input    <- rms_norm(residual) * weight

   Two implementations are available:
     - fused_add_rms_kernel_eval (vectorized, one vector per thread), used
       when USE_VLLM_OLD_OP is not enabled, hidden_size is a multiple of 16
       within [2048, 8192], and input, residual and weight are all 16-byte
       aligned;
     - vllm::fused_add_rms_norm_kernel otherwise (width 8 when the pointers
       are aligned and hidden_size % 8 == 0, generic width 0 if not).

   Empty tensors are a no-op: this avoids dividing by a zero hidden_size and
   launching a kernel with a zero-sized grid. */
void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  int hidden_size = input.size(-1);
  if (hidden_size == 0 || input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;

  if (!use_old && hidden_size % 16 == 0 && hidden_size >= 2048 &&
      hidden_size <= 8192 && ptrs_are_aligned) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
        "fused_add_rms_norm_kernel", [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          T_ACC eps = epsilon;
          scalar_t* self_data = input.data_ptr<scalar_t>();
          scalar_t* other_data = residual.data_ptr<scalar_t>();
          scalar_t* weight_data = weight.data_ptr<scalar_t>();
          // The gate above restricts hidden_size to [2048, 8192], so only
          // these tunings are reachable. Each satisfies
          // hidden_size <= Vec * block_size, which the kernel requires.
          // If the gate is widened, matching branches must be added here.
          if (hidden_size <= 2048) {
            fused_add_rms_kernel_eval<scalar_t, T_ACC, 8, 256>
                <<<num_tokens, 256, 0, stream>>>(
                    self_data, other_data, weight_data, hidden_size, eps);
          } else if (hidden_size <= 4096) {
            // Two tunings for this size range, chosen by the number of rows.
            if (num_tokens > 1200) {
              fused_add_rms_kernel_eval<scalar_t, T_ACC, 8, 512>
                  <<<num_tokens, 512, 0, stream>>>(
                      self_data, other_data, weight_data, hidden_size, eps);
            } else {
              fused_add_rms_kernel_eval<scalar_t, T_ACC, 4, 1024>
                  <<<num_tokens, 1024, 0, stream>>>(
                      self_data, other_data, weight_data, hidden_size, eps);
            }
          } else {
            fused_add_rms_kernel_eval<scalar_t, T_ACC, 8, 1024>
                <<<num_tokens, 1024, 0, stream>>>(
                    self_data, other_data, weight_data, hidden_size, eps);
          }
        });
  } else {
    dim3 grid(num_tokens);
    /* This kernel is memory-latency bound in many scenarios.
      When num_tokens is large, a smaller block size allows
      for increased block occupancy on CUs and better latency
      hiding on global mem ops.
      The threshold must stay in sync with the blockReduceSum width
      selection inside fused_add_rms_norm_kernel. */
    const int max_block_size = (num_tokens < 256) ? 1024 : 256;
    dim3 block(std::min(hidden_size, max_block_size));
    /*If the tensor types are FP16/BF16, try to use the optimized kernel
      with packed + vectorized ops.
      Max optimization is achieved with a width-8 vector of FP16/BF16s
      since we can load at most 128 bits at once in a global memory op.
      However, this requires each tensor's data to be aligned to 16
      bytes.
    */
    if (ptrs_are_aligned && hidden_size % 8 == 0) {
      LAUNCH_FUSED_ADD_RMS_NORM(8);
    } else {
      LAUNCH_FUSED_ADD_RMS_NORM(0);
    }
  }
}