activation_kernels.cu 8.7 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
#include <ATen/cuda/CUDAContext.h>
2
#include <torch/all.h>
3
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
6
#include <cmath>

7
#include "cuda_compat.h"
8
9
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
10
namespace vllm {

// Gating combiner shared by the act-and-mul kernels. With act_first == true
// the activation is applied to the first operand (ACT_FN(x) * y); otherwise
// it is applied to the second one (x * ACT_FN(y)).
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                            const scalar_t& y) {
  if (act_first) {
    return ACT_FN(x) * y;
  }
  return x * ACT_FN(y);
}

// Activation-and-gating kernel.
// Each block owns one row, selected by blockIdx.x. A row of `input` holds
// 2 * d values: the first d are the x operands and the next d are the y
// operands. The block's threads walk the d output columns with a stride of
// blockDim.x, so any block size is valid. The kernel writes
// out[row, i] = compute(x[i], y[i]).
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  // Row offsets are computed in 64-bit to avoid int overflow on large inputs.
  const int64_t in_base = static_cast<int64_t>(blockIdx.x) * 2 * d;
  const int64_t out_base = static_cast<int64_t>(blockIdx.x) * d;
  for (int64_t i = threadIdx.x; i < d; i += blockDim.x) {
    const scalar_t first = VLLM_LDG(&input[in_base + i]);
    const scalar_t second = VLLM_LDG(&input[in_base + d + i]);
    out[out_base + i] = compute<scalar_t, ACT_FN, act_first>(first, second);
  }
}

// SiLU activation: x * sigmoid(x), written as x / (1 + exp(-x)) and
// evaluated in fp32 before being cast back to T.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float xf = (float)x;
  return (T)(xf / (1.0f + expf((float)-x)));
}

// Exact (erf-based) GELU, equivalent to PyTorch GELU with the 'none'
// approximation:
//   0.5 * x * (1 + erf(x / sqrt(2)))
// The result is evaluated in fp32. Reference:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float v = (float)x;
  const float one_plus_erf = 1.0f + ::erf(v * kInvSqrt2);
  return (T)(v * 0.5f * one_plus_erf);
}

// tanh-approximated GELU, equivalent to PyTorch GELU with the 'tanh'
// approximation:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
// The result is evaluated in fp32. Reference:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
  constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float kKappa = 0.044715;
  const float v = (float)x;
  const float v_cubed = v * v * v;
  const float arg = kBeta * (v + kKappa * v_cubed);
  return (T)(0.5f * v * (1.0f + ::tanhf(arg)));
}

}  // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
64

65
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
//
// The expansion reads the tensors `input` ([..., 2 * d]) and `out` ([..., d])
// from the enclosing scope and may execute `return;`. It must therefore be
// used inside a void function that has both names in scope. It launches one
// block per token (row) with min(d, 1024) threads on the current CUDA stream.
//
// Nothing is launched when there is no work to do:
//   * An empty input. Checking this first also avoids the integer division
//     by zero that `numel() / size(-1)` would hit when the last dimension
//     is 0, and avoids a zero-block grid launch.
//   * d == 0, i.e. a last dimension of size 1. This would otherwise launch
//     zero-thread blocks, which is an invalid launch configuration.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                 \
  if (input.numel() == 0) {                                              \
    return;                                                              \
  }                                                                      \
  int d = input.size(-1) / 2;                                            \
  int64_t num_tokens = input.numel() / input.size(-1);                   \
  if (d == 0) {                                                          \
    return;                                                              \
  }                                                                      \
  dim3 grid(num_tokens);                                                 \
  dim3 block(std::min(d, 1024));                                         \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
  VLLM_DISPATCH_FLOATING_TYPES(                                          \
      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
                                         input.data_ptr<scalar_t>(), d); \
      });

// Gated SiLU. For each row, out[i] = silu(input[i]) * input[d + i]. The
// activation is applied to the first half of the last dimension
// (ACT_FIRST = true).
void silu_and_mul(torch::Tensor& out,      // [..., d]
                  torch::Tensor& input) {  // [..., 2 * d]
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}

// Gated SiLU with the roles swapped relative to silu_and_mul. For each row,
// out[i] = input[i] * silu(input[d + i]), i.e. silu is applied to the latter
// half of the last dimension (ACT_FIRST = false).
void mul_and_silu(torch::Tensor& out,      // [..., d]
                  torch::Tensor& input) {  // [..., 2 * d]
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
}
Woosuk Kwon's avatar
Woosuk Kwon committed
98

99
100
// Gated exact (erf) GELU. For each row, out[i] = gelu(input[i]) * input[d + i].
void gelu_and_mul(torch::Tensor& out,      // [..., d]
                  torch::Tensor& input) {  // [..., 2 * d]
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
}
104

105
106
// Gated tanh-approximated GELU. For each row,
// out[i] = gelu_tanh(input[i]) * input[d + i].
void gelu_tanh_and_mul(torch::Tensor& out,      // [..., d]
                       torch::Tensor& input) {  // [..., 2 * d]
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
}

111
112
namespace vllm {

// Thresholded ReLU (FATReLU). Values strictly greater than `threshold` pass
// through unchanged; everything else, including NaN, becomes 0. The
// comparison is done in fp32.
template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
  const float v = (float)x;
  if (v > threshold) {
    return (T)v;
  }
  return (T)0.0f;
}

// Activation-and-gating kernel for activations that take one extra float
// parameter. The layout matches act_and_mul_kernel: one row per block
// (blockIdx.x), with 2 * d input values per row. The block's threads stride
// over the d output columns in steps of blockDim.x and write
// out[row, i] = ACT_FN(x[i], param) * y[i].
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
    const float param) {
  // Row offsets are computed in 64-bit to avoid int overflow on large inputs.
  const int64_t in_base = static_cast<int64_t>(blockIdx.x) * 2 * d;
  const int64_t out_base = static_cast<int64_t>(blockIdx.x) * d;
  for (int64_t i = threadIdx.x; i < d; i += blockDim.x) {
    const scalar_t first = VLLM_LDG(&input[in_base + i]);
    const scalar_t second = VLLM_LDG(&input[in_base + d + i]);
    out[out_base + i] = ACT_FN(first, param) * second;
  }
}

}  // namespace vllm

// Launch a gated activation kernel whose activation takes one extra scalar
// parameter, such as a threshold.
//
// The expansion reads `input` ([..., 2 * d]) and `out` ([..., d]) from the
// enclosing scope and may execute `return;`, so it must be used inside a void
// function. PARAM is forwarded to the kernel's `const float` argument. Launch
// shape: one block per token (row) with min(d, 1024) threads on the current
// CUDA stream.
//
// Nothing is launched for an empty input. This also prevents the integer
// division by zero in `numel() / size(-1)` when the last dimension is 0, and
// a zero-block grid launch. Nothing is launched when d == 0 either, which
// would otherwise be a zero-thread block launch.
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)         \
  if (input.numel() == 0) {                                             \
    return;                                                             \
  }                                                                     \
  int d = input.size(-1) / 2;                                           \
  int64_t num_tokens = input.numel() / input.size(-1);                  \
  if (d == 0) {                                                         \
    return;                                                             \
  }                                                                     \
  dim3 grid(num_tokens);                                                \
  dim3 block(std::min(d, 1024));                                        \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
  VLLM_DISPATCH_FLOATING_TYPES(                                         \
      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {       \
        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),      \
                                         input.data_ptr<scalar_t>(), d, \
                                         PARAM);                        \
      });

// Gated FATReLU. For each row, out[i] = fatrelu(input[i], threshold) *
// input[d + i]. `threshold` is narrowed from double to float when it is
// passed to the kernel's float parameter.
void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                     torch::Tensor& input,  // [..., 2 * d]
                     double threshold) {
  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
namespace vllm {

// Element-wise activation kernel.
// Each block processes one row of d values, selected by blockIdx.x. The
// block's threads stride across the row in steps of blockDim.x, so any block
// size is valid. The kernel writes out[row, i] = ACT_FN(input[row, i]).
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  // Row offset is computed in 64-bit to avoid int overflow on large inputs.
  const int64_t base = static_cast<int64_t>(blockIdx.x) * d;
  for (int64_t i = threadIdx.x; i < d; i += blockDim.x) {
    const scalar_t v = VLLM_LDG(&input[base + i]);
    out[base + i] = ACT_FN(v);
  }
}

}  // namespace vllm
169
170

// Launch element-wise activation kernel.
//
// The expansion reads `input` and `out` (both [..., d]) from the enclosing
// scope and may execute `return;`, so it must be used inside a void function.
// Launch shape: one block per row with min(d, 1024) threads on the current
// CUDA stream.
//
// An empty input returns before any arithmetic. Otherwise `numel() / d`
// would divide by zero when the last dimension is 0, and an empty input with
// a non-zero last dimension would launch a zero-block grid, which is an
// invalid launch configuration.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  if (input.numel() == 0) {                                                    \
    return;                                                                    \
  }                                                                            \
  int d = input.size(-1);                                                      \
  int64_t num_tokens = input.numel() / d;                                      \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \
        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
                                     input.data_ptr<scalar_t>(), d);           \
  });
183
184
185

namespace vllm {

// tanh-based GELU:
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// 0.79788456 is approximately sqrt(2 / pi). Intermediates are deliberately
// cast back to T between steps: the cube and the tanh argument are formed
// in T, and only the scalings and tanh run in fp32.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float cube = (float)(x * x * x);
  const auto inner = x + (T)(0.044715f * cube);
  const T scaled = (T)(0.79788456f * (float)inner);
  const T tanh_term = (T)tanhf(scaled);
  return ((T)0.5) * x * (((T)1.0) + tanh_term);
}

// Factored variant of the tanh-based GELU:
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
// The two factors of the tanh argument are formed in T, and their product is
// passed to tanhf.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float xf = (float)x;
  const T scaled_x = (T)(xf * 0.79788456f);
  const auto poly = ((T)1.0) + (T)(0.044715f * xf) * x;
  const T tanh_term = (T)tanhf(scaled_x * poly);
  return ((T)0.5) * x * (((T)1.0) + tanh_term);
}

// Quick GELU approximation x * sigmoid(1.702 * x), evaluated in fp32 as
// x / (1 + exp(-1.702 * x)).
template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  const float xf = (float)x;
  return (T)(xf / (1.0f + expf(-1.702f * xf)));
}

}  // namespace vllm
208

209
210
// Element-wise tanh-based GELU: out = gelu_new(input), same shape.
void gelu_new(torch::Tensor& out,      // [..., d]
              torch::Tensor& input) {  // [..., d]
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

215
216
// Element-wise factored tanh GELU: out = gelu_fast(input), same shape.
void gelu_fast(torch::Tensor& out,      // [..., d]
               torch::Tensor& input) {  // [..., d]
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
220
221
222
223
224
225

// Element-wise quick GELU (x * sigmoid(1.702 * x)): out = gelu_quick(input).
void gelu_quick(torch::Tensor& out,      // [..., d]
                torch::Tensor& input) {  // [..., d]
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}