#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

// Activation and gating kernel template.
//
// Launch layout: one thread block per token (blockIdx.x selects the row); the
// block's threads step across the d output columns in increments of
// blockDim.x. Each input row holds the activation half followed by the gate
// half, so for every column i in [0, d):
//   out[t, i] = ACT_FN(input[t, i]) * input[t, d + i]
// The indexing uses fixed row strides of 2 * d (input) and d (out), i.e. it
// assumes both buffers are densely packed row by row.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  // 64-bit row index so token_idx * 2 * d cannot overflow 32-bit arithmetic.
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(x) * y;
  }
}

// SiLU, evaluated in float and rounded to T once at the end.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Exact (erf-based) GELU, evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;  // 1 / sqrt(2)
  // erff is the single-precision erf; calling it explicitly keeps the whole
  // computation in float, like expf / tanhf in the sibling kernels.
  return (T)(f * 0.5f * (1.0f + ::erff(f * ALPHA)));
}

// Tanh-approximated GELU, evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;  // sqrt(2 / pi)
  constexpr float KAPPA = 0.044715f;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

}  // namespace vllm

// Launch activation and gating kernel.
//
// Must be expanded inside a host function that has `torch::Tensor& out`
// ([..., d]) and `torch::Tensor& input` ([..., 2 * d]) in scope. Launches one
// thread block per token with min(d, 1024) threads on the current CUDA stream
// of input's device. Empty inputs are a no-op (no kernel launch).
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                 \
  /* The kernel strides rows by 2 * d; an odd width would misalign gates. */  \
  TORCH_CHECK(input.size(-1) % 2 == 0,                                        \
              "act_and_mul: last dimension of input must be even, got ",      \
              input.size(-1));                                                \
  int d = input.size(-1) / 2;                                                 \
  /* Avoid dividing by a zero-width last dimension (empty input). */          \
  int64_t num_tokens = d == 0 ? 0 : input.numel() / input.size(-1);           \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
  /* A zero-sized grid or block is an invalid launch configuration. */        \
  if (num_tokens > 0 && d > 0) {                                              \
    VLLM_DISPATCH_FLOATING_TYPES(                                             \
        input.scalar_type(), "act_and_mul_kernel", [&] {                      \
          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>                \
              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),          \
                                           input.data_ptr<scalar_t>(), d);    \
        });                                                                   \
  }

// Gated SiLU: out[..., i] = silu(input[..., i]) * input[..., d + i].
// Shapes: out is [..., d], input is [..., 2 * d].
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

// Gated exact (erf-based) GELU:
//   out[..., i] = gelu(input[..., i]) * input[..., d + i].
// Shapes: out is [..., d], input is [..., 2 * d].
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}

// Gated tanh-approximated GELU:
//   out[..., i] = gelu_tanh(input[..., i]) * input[..., d + i].
// Shapes: out is [..., d], input is [..., 2 * d].
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}

namespace vllm {

// Element-wise activation kernel template: out[t, i] = ACT_FN(input[t, i]).
//
// Launch layout: one thread block per token (blockIdx.x selects the row); the
// block's threads step across the row's d columns in increments of
// blockDim.x. Both buffers are indexed with a row stride of d.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  // Row offset in 64-bit arithmetic, then work relative to the row start.
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * d;
  const scalar_t* in_row = input + row_start;
  scalar_t* out_row = out + row_start;
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    out_row[col] = ACT_FN(VLLM_LDG(&in_row[col]));
  }
}

}  // namespace vllm

// Launch element-wise activation kernel.
//
// Must be expanded inside a host function that has `torch::Tensor& out` and
// `torch::Tensor& input` (both [..., d]) in scope. Launches one thread block
// per token with min(d, 1024) threads on the current CUDA stream of input's
// device. Empty inputs are a no-op (no kernel launch).
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  int d = input.size(-1);                                                      \
  /* Avoid dividing by a zero-width last dimension (empty input). */           \
  int64_t num_tokens = d == 0 ? 0 : input.numel() / d;                         \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  /* A zero-sized grid or block is an invalid launch configuration. */         \
  if (num_tokens > 0 && d > 0) {                                               \
    VLLM_DISPATCH_FLOATING_TYPES(                                              \
        input.scalar_type(), "activation_kernel", [&] {                        \
          vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                  \
              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),           \
                                           input.data_ptr<scalar_t>(), d);     \
        });                                                                    \
  }

namespace vllm {

// GELU ("new" tanh formulation):
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// with 0.79788456 ~= sqrt(2 / pi). All intermediates are kept in float and
// rounded to T once: forming x^3 in half precision loses accuracy and
// overflows fp16 for |x| >= 41. For T = float the operation order matches the
// plain formula term by term.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float f = (float)x;
  const float x3 = f * f * f;
  const float t = tanhf(0.79788456f * (f + 0.044715f * x3));
  return (T)(0.5f * f * (1.0f + t));
}

// GELU ("fast" formulation), algebraically equal to gelu_new_kernel with the
// cubic term factored as x * (1 + 0.044715 * x^2):
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
// Evaluated entirely in float and rounded to T once.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const float t = tanhf((f * 0.79788456f) * (1.0f + (0.044715f * f) * f));
  return (T)(0.5f * f * (1.0f + t));
}

// Quick GELU, evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  // x * sigmoid(1.702 * x)
  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

}  // namespace vllm

// Element-wise "new" GELU: out = gelu_new(input).
// Shapes: out and input are both [..., d].
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

// Element-wise "fast" GELU: out = gelu_fast(input).
// Shapes: out and input are both [..., d].
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}

// Element-wise quick GELU: out = input * sigmoid(1.702 * input).
// Shapes: out and input are both [..., d].
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}