activation_kernels.cu 26.7 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
#include <ATen/cuda/CUDAContext.h>
2
#include <torch/all.h>
3
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
6
#include <cmath>

7
#include "cuda_compat.h"
8
#include "cuda_vec_utils.cuh"
9
10
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
11
namespace vllm {

// Applies ACT_FN to one operand and multiplies by the other (scalar path).
//   act_first == true  -> ACT_FN(x) * y
//   act_first == false -> x * ACT_FN(y)
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                            const scalar_t& y) {
  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}

// Packed counterpart of compute(): same operand selection, using the packed
// activation PACKED_ACT_FN and packed_mul.
template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
          bool act_first>
__device__ __forceinline__ packed_t packed_compute(const packed_t& x,
                                                   const packed_t& y) {
  return act_first ? packed_mul(PACKED_ACT_FN(x), y)
                   : packed_mul(x, PACKED_ACT_FN(y));
}

// Activation and gating kernel template.
//
// Grid: one block per token (blockIdx.x). The block's threads stride over the
// row, whose input layout is [x (d elements) | y (d elements)] and whose
// output holds d elements.
//
// use_vec:  take the packed-vector path. The launcher enables it only when d
//           is a multiple of the vector width. NOTE(review): base-pointer
//           alignment of `input`/`out` is assumed here, not checked; confirm
//           that callers guarantee it.
// use_256b: use 256-bit loads/stores instead of 128-bit ones.
template <typename scalar_t, typename packed_t,
          scalar_t (*ACT_FN)(const scalar_t&),
          packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
          bool use_vec, bool use_256b = false>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  // Promote to 64 bits before multiplying. Otherwise blockIdx.x * 2 * d is
  // evaluated in 32-bit unsigned arithmetic and wraps once
  // num_tokens * 2 * d reaches 2^32.
  const int64_t token_idx = blockIdx.x;
  const scalar_t* x_ptr = input + token_idx * 2 * d;
  const scalar_t* y_ptr = x_ptr + d;
  scalar_t* out_ptr = out + token_idx * d;

  if constexpr (use_vec) {
    using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
    using pvec_t = PackedVec<cuda_t, use_256b>;

    const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
    const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
    pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
    // Each vector carries NUM_ELTS packed elements. Packed elements are
    // two-lane (cf. cast_to_float2 in the packed activations below), hence
    // the extra /2.
    const int num_vecs = d / 2 / pvec_t::NUM_ELTS;

    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
      pvec_t x, y;
      if constexpr (use_256b) {
        ld256(x, &x_vec[i]);
        ld256(y, &y_vec[i]);
      } else {
        ld128(x, &x_vec[i]);
        ld128(y, &y_vec[i]);
      }
#pragma unroll
      for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
        x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
            x.elts[j], y.elts[j]);
      }
      if constexpr (use_256b) {
        st256(x, &out_vec[i]);
      } else {
        st128(x, &out_vec[i]);
      }
    }
  } else {
    // Scalar path: used when d is not a multiple of the vector width.
    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
    }
  }
}

// SiLU: x * sigmoid(x) = x / (1 + exp(-x)), evaluated in float and cast back
// to T.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Two-lane SiLU: widens both lanes to float, applies x / (1 + exp(-x)) to
// each lane, then packs back.
template <typename packed_t>
__device__ __forceinline__ packed_t packed_silu_kernel(const packed_t& val) {
  // x * sigmoid(x)
  float2 fval = cast_to_float2(val);
  fval.x = fval.x / (1.0f + expf(-fval.x));
  fval.y = fval.y / (1.0f + expf(-fval.y));
  return cast_to_packed<packed_t>(fval);
}

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;  // 1 / sqrt(2)
  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

// Two-lane exact GELU; same formula as gelu_kernel applied per lane in float.
template <typename packed_t>
__device__ __forceinline__ packed_t packed_gelu_kernel(const packed_t& val) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  constexpr float ALPHA = M_SQRT1_2;  // 1 / sqrt(2)
  float2 fval = cast_to_float2(val);
  fval.x = fval.x * 0.5f * (1.0f + ::erf(fval.x * ALPHA));
  fval.y = fval.y * 0.5f * (1.0f + ::erf(fval.y * ALPHA));
  return cast_to_packed<packed_t>(fval);
}

// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), in float.
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  const float f = (float)x;
  // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

// Two-lane tanh-approximated GELU; same formula as gelu_tanh_kernel per lane.
template <typename packed_t>
__device__ __forceinline__ packed_t
packed_gelu_tanh_kernel(const packed_t& val) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  float2 fval = cast_to_float2(val);
  // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;

  float x_cube = fval.x * fval.x * fval.x;
  float inner = BETA * (fval.x + KAPPA * x_cube);
  fval.x = 0.5f * fval.x * (1.0f + ::tanhf(inner));

  x_cube = fval.y * fval.y * fval.y;
  inner = BETA * (fval.y + KAPPA * x_cube);
  fval.y = 0.5f * fval.y * (1.0f + ::tanhf(inner));
  return cast_to_packed<packed_t>(fval);
}

}  // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
151

152
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
//
// Must be expanded inside a void function that has torch::Tensor variables
// `out` and `input` in scope; it may `return` early.
// - One block is launched per token.
// - The 256-bit vector path is chosen only for CUDA >= 12.9, compute
//   capability >= 10 and more than 128 tokens.
// - Either vector path requires d to be a multiple of the vector width;
//   otherwise the scalar path runs.
// - Empty inputs (no tokens, or d == 0) return without launching. This avoids
//   dividing by a zero-sized last dimension and launching a 0-thread block.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST)        \
  auto dtype = input.scalar_type();                                            \
  int64_t last_dim = input.size(-1);                                           \
  int d = last_dim / 2;                                                        \
  int64_t num_tokens = last_dim == 0 ? 0 : input.numel() / last_dim;           \
  if (num_tokens == 0 || d == 0) {                                             \
    return;                                                                    \
  }                                                                            \
  dim3 grid(num_tokens);                                                       \
  int cc_major = at::cuda::getCurrentDeviceProperties()->major;                \
  const bool use_256b =                                                        \
      CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128;             \
  int support_vec = use_256b ? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE        \
                             : vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE;      \
  int vec_size = support_vec / at::elementSize(dtype);                         \
  const bool use_vec = (d % vec_size == 0);                                    \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  if (use_vec) {                                                               \
    dim3 block(std::min(d / vec_size, 1024));                                  \
    if (use_256b) {                                                            \
      VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] {          \
        vllm::act_and_mul_kernel<                                              \
            scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,      \
            KERNEL<scalar_t>,                                                  \
            PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
            ACT_FIRST, true, true><<<grid, block, 0, stream>>>(                \
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);          \
      });                                                                      \
    } else {                                                                   \
      VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] {          \
        vllm::act_and_mul_kernel<                                              \
            scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,      \
            KERNEL<scalar_t>,                                                  \
            PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
            ACT_FIRST, true, false><<<grid, block, 0, stream>>>(               \
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);          \
      });                                                                      \
    }                                                                          \
  } else {                                                                     \
    dim3 block(std::min(d, 1024));                                             \
    VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] {            \
      vllm::act_and_mul_kernel<                                                \
          scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,        \
          KERNEL<scalar_t>,                                                    \
          PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>,   \
          ACT_FIRST, false><<<grid, block, 0, stream>>>(                       \
          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);            \
    });                                                                        \
  }
204
205
206

// SiLU-and-multiply gate.
//   out:   [..., d]
//   input: [..., 2 * d]
// out[..., i] = silu(input[..., i]) * input[..., d + i]
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
                                /*ACT_FIRST=*/true);
}

// Multiply-and-SiLU gate. Unlike silu_and_mul, the activation is applied to
// the latter half of the last dimension.
//   out:   [..., d]
//   input: [..., 2 * d]
// out[..., i] = input[..., i] * silu(input[..., d + i])
void mul_and_silu(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
                                /*ACT_FIRST=*/false);
}
Woosuk Kwon's avatar
Woosuk Kwon committed
220

221
222
// Exact (erf-based) GELU-and-multiply gate.
//   out:   [..., d]
//   input: [..., 2 * d]
// out[..., i] = gelu(input[..., i]) * input[..., d + i]
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
                                /*ACT_FIRST=*/true);
}
227

228
229
// Tanh-approximated GELU-and-multiply gate.
//   out:   [..., d]
//   input: [..., 2 * d]
// out[..., i] = gelu_tanh(input[..., i]) * input[..., d + i]
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
                                vllm::packed_gelu_tanh_kernel,
                                /*ACT_FIRST=*/true);
}

235
236
namespace vllm {

// FATReLU: keeps x when it is strictly greater than `threshold`, else 0.
// Evaluated in float and cast back to T.
template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
  const float f = (float)x;
  return (T)(f > threshold ? f : 0.0f);
}

// Two-lane FATReLU; same rule as fatrelu_kernel applied per lane in float.
template <typename packed_t>
__device__ __forceinline__ packed_t
packed_fatrelu_kernel(const packed_t& val, const float threshold) {
  float2 fval = cast_to_float2(val);
  fval.x = fval.x > threshold ? fval.x : 0.0f;
  fval.y = fval.y > threshold ? fval.y : 0.0f;
  return cast_to_packed<packed_t>(fval);
}

// Gating kernel whose activation takes one extra float parameter.
// Computes out[i] = ACT_FN(x[i], param) * y[i]; the activation is always
// applied to the first half of the row.
//
// Grid: one block per token (blockIdx.x). Each input row is
// [x (d elements) | y (d elements)] and the output row holds d elements.
// use_vec / use_256b select the packed 128-/256-bit vector path, which the
// launcher enables only when d is a multiple of the vector width.
// NOTE(review): base-pointer alignment is assumed, not checked.
template <typename scalar_t, typename packed_t,
          scalar_t (*ACT_FN)(const scalar_t&, const float),
          packed_t (*PACKED_ACT_FN)(const packed_t&, const float), bool use_vec,
          bool use_256b = false>
__global__ void act_and_mul_kernel_with_param(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
    const float param) {
  // Promote to 64 bits before multiplying. Otherwise blockIdx.x * 2 * d is
  // evaluated in 32-bit unsigned arithmetic and wraps once
  // num_tokens * 2 * d reaches 2^32.
  const int64_t token_idx = blockIdx.x;
  const scalar_t* x_ptr = input + token_idx * 2 * d;
  const scalar_t* y_ptr = x_ptr + d;
  scalar_t* out_ptr = out + token_idx * d;

  if constexpr (use_vec) {
    using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
    using pvec_t = PackedVec<cuda_t, use_256b>;

    const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
    const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
    pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
    // Packed elements are two-lane, hence the extra /2.
    const int num_vecs = d / 2 / pvec_t::NUM_ELTS;

    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
      pvec_t x, y;
      if constexpr (use_256b) {
        ld256(x, &x_vec[i]);
        ld256(y, &y_vec[i]);
      } else {
        ld128(x, &x_vec[i]);
        ld128(y, &y_vec[i]);
      }
#pragma unroll
      for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
        x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]);
      }
      if constexpr (use_256b) {
        st256(x, &out_vec[i]);
      } else {
        st128(x, &out_vec[i]);
      }
    }
  } else {
    // Scalar path: used when d is not a multiple of the vector width.
    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
      out_ptr[idx] = ACT_FN(x, param) * y;
    }
  }
}

// Clamped SwiGLU variant:
//   g = min(gate, limit), u = clamp(up, -limit, limit)
//   result = (u + 1) * g * sigmoid(alpha * g)
// Evaluated in float and cast back to T.
template <typename T>
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
                                               float alpha, float limit) {
  // Clamp gate to (-inf, limit] and up to [-limit, limit]
  const float g = fminf((float)gate, limit);
  const float u = fmaxf(fminf((float)up, limit), -limit);
  // glu = gate * sigmoid(gate * alpha), then return (up + 1) * glu
  return (T)((u + 1.0f) * g / (1.0f + expf(-g * alpha)));
}

// Interleaved gate/up: input has [gate0, up0, gate1, up1, ...].
// Grid: one block per token (blockIdx.x); threads stride over the d outputs.
// A vectorized path (16-byte loads, 8-byte stores) is used when the row
// pointers are suitably aligned and d >= PAIRS. Otherwise every element goes
// through the scalar loop.
template <typename scalar_t,
          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
                             const float)>
__global__ void swigluoai_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2 * d] (interleaved)
    const int d, const float alpha, const float limit) {
  // For interleaved data: input has 2*d elements per token (gate/up pairs)
  // output has d elements per token
  constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
  constexpr int PAIRS = VEC_SIZE / 2;  // Number of gate/up pairs per int4 load
  const int64_t token_idx = blockIdx.x;
  const scalar_t* in_ptr = input + token_idx * 2 * d;
  scalar_t* out_ptr = out + token_idx * d;

  // Check alignment for 128-bit vectorized access on input.
  // For output we use int2 (64-bit) which has 8-byte alignment requirement.
  const bool in_aligned = is_16byte_aligned(in_ptr);
  const bool out_aligned =
      (reinterpret_cast<uintptr_t>(out_ptr) & 7) == 0;  // 8-byte for int2

  if (in_aligned && out_aligned && d >= PAIRS) {
    // Fast path: vectorized loop
    // Each int4 load gives VEC_SIZE elements = PAIRS gate/up pairs
    // Each int2 store writes PAIRS output elements
    const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
    int2* out_vec = reinterpret_cast<int2*>(out_ptr);
    const int num_vecs = d / PAIRS;
    const int vec_end = num_vecs * PAIRS;

    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
      int4 v = VLLM_LDG(&in_vec[i]);
      int2 r;
      auto* vp = reinterpret_cast<scalar_t*>(&v);
      auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
      for (int j = 0; j < PAIRS; j++) {
        rp[j] = ACT_FN(vp[2 * j], vp[2 * j + 1], alpha, limit);
      }
      out_vec[i] = r;
    }
    // Scalar cleanup for remaining elements
    for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
      out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[2 * i]),
                          VLLM_LDG(&in_ptr[2 * i + 1]), alpha, limit);
    }
  } else {
    // Scalar fallback for unaligned data or small d
    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      // gate = x[..., ::2]  (even indices)
      const scalar_t gate = VLLM_LDG(&in_ptr[2 * idx]);
      // up = x[..., 1::2]   (odd indices)
      const scalar_t up = VLLM_LDG(&in_ptr[2 * idx + 1]);
      out_ptr[idx] = ACT_FN(gate, up, alpha, limit);
    }
  }
}

}  // namespace vllm

372
373
374
375
376
377
378
379
380
// Launch a parameterized activation-and-gating kernel
// (vllm::act_and_mul_kernel_with_param). PARAM is forwarded as the kernel's
// float `param` argument.
//
// Must be expanded inside a void function that has torch::Tensor variables
// `out` and `input` in scope; it may `return` early.
// - One block is launched per token.
// - The 256-bit vector path is chosen only for CUDA >= 12.9, compute
//   capability >= 10 and more than 128 tokens.
// - Either vector path requires d to be a multiple of the vector width.
// - Empty inputs (no tokens, or d == 0) return without launching.
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PACKED_KERNEL, PARAM) \
  auto dtype = input.scalar_type();                                            \
  int64_t last_dim = input.size(-1);                                           \
  int d = last_dim / 2;                                                        \
  int64_t num_tokens = last_dim == 0 ? 0 : input.numel() / last_dim;           \
  if (num_tokens == 0 || d == 0) {                                             \
    return;                                                                    \
  }                                                                            \
  dim3 grid(num_tokens);                                                       \
  int cc_major = at::cuda::getCurrentDeviceProperties()->major;                \
  const bool use_256b =                                                        \
      CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128;             \
  int support_vec = use_256b ? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE        \
                             : vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE;      \
  int vec_size = support_vec / at::elementSize(dtype);                         \
  const bool use_vec = (d % vec_size == 0);                                    \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  if (use_vec) {                                                               \
    dim3 block(std::min(d / vec_size, 1024));                                  \
    if (use_256b) {                                                            \
      VLLM_DISPATCH_FLOATING_TYPES(                                            \
          dtype, "act_and_mul_kernel_with_param", [&] {                        \
            vllm::act_and_mul_kernel_with_param<                               \
                scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,  \
                KERNEL<scalar_t>,                                              \
                PACKED_KERNEL<                                                 \
                    typename vllm::PackedTypeConverter<scalar_t>::Type>,       \
                true, true><<<grid, block, 0, stream>>>(                       \
                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d,       \
                PARAM);                                                        \
          });                                                                  \
    } else {                                                                   \
      VLLM_DISPATCH_FLOATING_TYPES(                                            \
          dtype, "act_and_mul_kernel_with_param", [&] {                        \
            vllm::act_and_mul_kernel_with_param<                               \
                scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,  \
                KERNEL<scalar_t>,                                              \
                PACKED_KERNEL<                                                 \
                    typename vllm::PackedTypeConverter<scalar_t>::Type>,       \
                true, false><<<grid, block, 0, stream>>>(                      \
                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d,       \
                PARAM);                                                        \
          });                                                                  \
    }                                                                          \
  } else {                                                                     \
    dim3 block(std::min(d, 1024));                                             \
    VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
      vllm::act_and_mul_kernel_with_param<                                     \
          scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type,        \
          KERNEL<scalar_t>,                                                    \
          PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>,   \
          false><<<grid, block, 0, stream>>>(                                  \
          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM);     \
    });                                                                        \
  }
427

428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
// Launch vllm::swigluoai_and_mul_kernel over an interleaved [..., 2 * d]
// input. ALPHA and LIMIT are forwarded as the kernel's float arguments.
//
// Must be expanded inside a void function that has torch::Tensor variables
// `out` and `input` in scope. One block of min(d, 1024) threads is launched
// per token. Empty inputs (no tokens, or d == 0) return without launching,
// matching the other launchers in this file. This also avoids an
// invalid-configuration launch of grid(0) or block(0).
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                          \
  int64_t last_dim = input.size(-1);                                           \
  int d = last_dim / 2;                                                        \
  int64_t num_tokens = last_dim == 0 ? 0 : input.numel() / last_dim;           \
  if (num_tokens == 0 || d == 0) {                                             \
    return;                                                                    \
  }                                                                            \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {            \
        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),             \
                                         input.data_ptr<scalar_t>(), d, ALPHA, \
                                         LIMIT);                               \
      });

443
444
445
// FATReLU-and-multiply gate.
//   out:   [..., d]
//   input: [..., 2 * d]
// out[..., i] = fatrelu(input[..., i], threshold) * input[..., d + i], where
// fatrelu keeps a value only when it is strictly greater than `threshold`.
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold) {
  // The kernel takes the threshold as float; make the narrowing explicit.
  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel,
                                           vllm::packed_fatrelu_kernel,
                                           static_cast<float>(threshold));
}
449
450
451
452
453
// Clamped SwiGLU gate over interleaved input [gate0, up0, gate1, up1, ...].
//   out:   [..., d]
//   input: [..., 2 * d]
// `alpha` scales the sigmoid argument and `limit` is the clamp bound, both
// consumed by vllm::swigluoai_and_mul.
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input, double alpha,
                       double limit) {
  // The kernel takes alpha/limit as float; make the narrowing explicit.
  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, static_cast<float>(alpha),
                          static_cast<float>(limit));
}
454
455
namespace vllm {

// Element-wise activation kernel template: out[i] = ACT_FN(input[i]).
//
// Grid: one block per token (blockIdx.x); the block's threads stride over the
// d elements of that token's row.
// use_vec:  take the 128-/256-bit vector path. The launcher enables it only
//           when d is a multiple of the vector width. NOTE(review):
//           base-pointer alignment is assumed, not checked.
// use_256b: use 256-bit loads/stores (ld256/st256) instead of 128-bit ones.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), bool use_vec,
          bool use_256b = false>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  // Promote to 64 bits before multiplying. Otherwise blockIdx.x * d is
  // evaluated in 32-bit unsigned arithmetic and wraps once num_tokens * d
  // reaches 2^32.
  const int64_t token_idx = blockIdx.x;
  const scalar_t* in_ptr = input + token_idx * d;
  scalar_t* out_ptr = out + token_idx * d;

  if constexpr (use_vec) {
    // Fast path: 128-bit/256-bit vectorized loop
    using vec_t = typename VecTraits<use_256b>::vec_t;
    constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
    constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(scalar_t);
    const vec_t* in_vec = reinterpret_cast<const vec_t*>(in_ptr);
    vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
    const int num_vecs = d / VEC_SIZE;

    for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
      vec_t v;
      if constexpr (use_256b) {
        ld256(v, &in_vec[i]);
      } else {
        v = VLLM_LDG(&in_vec[i]);
      }
      auto* vp = reinterpret_cast<scalar_t*>(&v);
#pragma unroll
      for (int j = 0; j < VEC_SIZE; j++) {
        vp[j] = ACT_FN(vp[j]);
      }
      if constexpr (use_256b) {
        st256(v, &out_vec[i]);
      } else {
        out_vec[i] = v;
      }
    }
  } else {
    // Scalar path: used when d is not a multiple of the vector width.
    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      const scalar_t x = VLLM_LDG(&in_ptr[idx]);
      out_ptr[idx] = ACT_FN(x);
    }
  }
}

}  // namespace vllm
503
504

// Launch element-wise activation kernel.
//
// Must be expanded inside a void function that has torch::Tensor variables
// `out` and `input` in scope; it may `return` early.
// - One block is launched per token.
// - The 256-bit vector path is chosen only for CUDA >= 12.9, compute
//   capability >= 10 and more than 128 tokens.
// - Either vector path requires d to be a multiple of the vector width.
// - Empty inputs return without launching. This also avoids dividing by a
//   zero-sized last dimension.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                 \
  auto dtype = input.scalar_type();                                      \
  int64_t last_dim = input.size(-1);                                     \
  int d = last_dim;                                                      \
  int64_t num_tokens = last_dim == 0 ? 0 : input.numel() / last_dim;     \
  if (num_tokens == 0) {                                                 \
    return;                                                              \
  }                                                                      \
  dim3 grid(num_tokens);                                                 \
  int cc_major = at::cuda::getCurrentDeviceProperties()->major;          \
  const bool use_256b =                                                  \
      CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128;       \
  int support_vec = use_256b ? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE  \
                             : vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
  int vec_size = support_vec / at::elementSize(dtype);                   \
  const bool use_vec = (d % vec_size == 0);                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
  if (use_vec) {                                                         \
    dim3 block(std::min(d / vec_size, 1024));                            \
    if (use_256b) {                                                      \
      VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] {     \
        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
                                         input.data_ptr<scalar_t>(), d); \
      });                                                                \
    } else {                                                             \
      VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] {     \
        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
                                         input.data_ptr<scalar_t>(), d); \
      });                                                                \
    }                                                                    \
  } else {                                                               \
    dim3 block(std::min(d, 1024));                                       \
    VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] {       \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false>         \
          <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),         \
                                       input.data_ptr<scalar_t>(), d);   \
    });                                                                  \
  }
545
546
547

namespace vllm {

// "New" GELU (tanh approximation):
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// with 0.79788456 ~= sqrt(2/pi).
// All intermediates are kept in float and only the result is cast to T. This
// matches gelu_tanh_kernel / gelu_quick_kernel and avoids repeated rounding to
// fp16/bf16 (e.g. computing x*x*x in T).
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float f = (float)x;
  const float inner = 0.79788456f * (f + 0.044715f * f * f * f);
  return (T)(0.5f * f * (1.0f + tanhf(inner)));
}

// "Fast" GELU, the same tanh approximation written as
//   0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
// evaluated entirely in float; only the result is cast back to T.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const float inner = 0.79788456f * f * (1.0f + 0.044715f * f * f);
  return (T)(0.5f * f * (1.0f + tanhf(inner)));
}

// Quick GELU: x * sigmoid(1.702 * x), evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  // x * sigmoid(1.702 * x)
  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

}  // namespace vllm
570

571
572
// Element-wise "new" GELU (tanh approximation).
//   out, input: [..., d]
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

577
578
// Element-wise "fast" GELU (tanh approximation, factored form).
//   out, input: [..., d]
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
582
583
584
585
586
587

// Element-wise quick GELU: x * sigmoid(1.702 * x).
//   out, input: [..., d]
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}