activation_kernels.cu 8.48 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
#include <ATen/cuda/CUDAContext.h>
2
#include <torch/all.h>
3
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
6
#include <cmath>

7
#include "cuda_compat.h"
8
9
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
10
namespace vllm {

// Combines the two operands of a gated activation. With act_first == true the
// activation is applied to the first operand (ACT_FN(x) * y); otherwise it is
// applied to the second operand (x * ACT_FN(y)).
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                            const scalar_t& y) {
  if (act_first) {
    return ACT_FN(x) * y;
  }
  return x * ACT_FN(y);
}

// Activation and gating kernel template.
// Each input row holds 2 * d values: the first d form the x operand and the
// last d the y operand; the row's d results are written to `out`. Block
// blockIdx.x processes row blockIdx.x, and its threads stride over the d
// columns by blockDim.x, so any block size covers the whole row.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t row = blockIdx.x;
  const scalar_t* x_row = input + row * 2 * d;
  const scalar_t* y_row = x_row + d;
  scalar_t* out_row = out + row * d;
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t lhs = VLLM_LDG(&x_row[col]);
    const scalar_t rhs = VLLM_LDG(&y_row[col]);
    out_row[col] = compute<scalar_t, ACT_FN, act_first>(lhs, rhs);
  }
}

// SiLU: x * sigmoid(x) == x / (1 + exp(-x)), evaluated in float and cast back
// to T once.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float denom = 1.0f + expf((float)-x);
  return (T)(((float)x) / denom);
}

// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in float.
// Equivalent to PyTorch GELU with 'none' approximation. Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float v = (float)x;
  return (T)(v * 0.5f * (1.0f + ::erf(v * kInvSqrt2)));
}

// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))),
// evaluated in float. Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5f;  // == sqrt(2 / pi)
  constexpr float kKappa = 0.044715;
  const float v = (float)x;
  const float v_cubed = v * v * v;
  const float tanh_arg = kBeta * (v + kKappa * v_cubed);
  return (T)(0.5f * v * (1.0f + ::tanhf(tanh_arg)));
}

}  // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
64

zhuwenwen's avatar
zhuwenwen committed
65
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
//
// Expects tensors named `out` and `input` in the enclosing scope; the last
// dimension of `input` holds 2 * d values. Launches one block per token with
// min(d, 1024) threads on the current CUDA stream.
//
// Empty inputs return early without launching anything:
//  - numel() == 0: when the last dimension is 0,
//    input.numel() / input.size(-1) would divide by zero on the host, and
//    zero tokens would produce a zero-sized grid, which is an invalid launch
//    configuration.
//  - d == 0 (last dimension of 1): a zero-sized block is also an invalid
//    launch configuration, and there is nothing to write.
// Because of the early `return`, this macro may only be expanded inside a
// function returning void.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                 \
  if (input.numel() == 0) {                                              \
    return;                                                              \
  }                                                                      \
  int d = input.size(-1) / 2;                                            \
  if (d == 0) {                                                          \
    return;                                                              \
  }                                                                      \
  int64_t num_tokens = input.numel() / input.size(-1);                   \
  dim3 grid(num_tokens);                                                 \
  dim3 block(std::min(d, 1024));                                         \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
  VLLM_DISPATCH_FLOATING_TYPES(                                          \
      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
                                         input.data_ptr<scalar_t>(), d); \
      });

// SiLU-and-multiply, with the activation on the first half of the input:
//   out[..., i] = silu(input[..., i]) * input[..., d + i]
//   out:   [..., d]
//   input: [..., 2 * d]
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}

// Multiply-and-SiLU. It differs from silu_and_mul in that the SiLU is applied
// to the latter half of the input rather than the first:
//   out[..., i] = input[..., i] * silu(input[..., d + i])
//   out:   [..., d]
//   input: [..., 2 * d]
void mul_and_silu(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
}
Woosuk Kwon's avatar
Woosuk Kwon committed
95

96
97
// Exact (erf-based) GELU-and-multiply, activation on the first half:
//   out[..., i] = gelu(input[..., i]) * input[..., d + i]
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
}
101

102
103
// Tanh-approximated GELU-and-multiply, activation on the first half:
//   out[..., i] = gelu_tanh(input[..., i]) * input[..., d + i]
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
}

108
109
namespace vllm {

// FATReLU: returns x when it is strictly greater than `threshold`, otherwise
// zero. The comparison is done in float; NaN inputs therefore map to zero.
template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
  const float v = (float)x;
  if (v > threshold) {
    return (T)v;
  }
  return (T)0.0f;
}

// Gated activation kernel whose activation takes one extra float argument:
//   out[row, col] = ACT_FN(input[row, col], param) * input[row, d + col]
// Each input row holds 2 * d values. Block blockIdx.x processes row
// blockIdx.x; its threads stride over the d columns by blockDim.x.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
    const float param) {
  const int64_t row = blockIdx.x;
  const int64_t in_base = row * 2 * d;
  const int64_t out_base = row * d;
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t lhs = VLLM_LDG(&input[in_base + col]);
    const scalar_t rhs = VLLM_LDG(&input[in_base + d + col]);
    out[out_base + col] = ACT_FN(lhs, param) * rhs;
  }
}

}  // namespace vllm

// Launch a gated activation kernel whose activation takes one extra float
// argument (PARAM).
//
// Expects tensors named `out` and `input` in the enclosing scope; the last
// dimension of `input` holds 2 * d values. Launches one block per token with
// min(d, 1024) threads on the current CUDA stream.
//
// Empty inputs return early without launching anything:
//  - numel() == 0: a zero last dimension would make
//    input.numel() / input.size(-1) divide by zero, and zero tokens would
//    produce a zero-sized grid, which is an invalid launch configuration.
//  - d == 0 (last dimension of 1): a zero-sized block is an invalid launch
//    configuration.
// Because of the early `return`, this macro may only be expanded inside a
// function returning void.
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)         \
  if (input.numel() == 0) {                                             \
    return;                                                             \
  }                                                                     \
  int d = input.size(-1) / 2;                                           \
  if (d == 0) {                                                         \
    return;                                                             \
  }                                                                     \
  int64_t num_tokens = input.numel() / input.size(-1);                  \
  dim3 grid(num_tokens);                                                \
  dim3 block(std::min(d, 1024));                                        \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
  VLLM_DISPATCH_FLOATING_TYPES(                                         \
      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {       \
        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),      \
                                         input.data_ptr<scalar_t>(), d, \
                                         PARAM);                        \
      });

// FATReLU-and-multiply, activation on the first half of the input:
//   out[..., i] = fatrelu(input[..., i], threshold) * input[..., d + i]
//   out:   [..., d]
//   input: [..., 2 * d]
// The kernel takes its parameter as float, so the threshold is narrowed here.
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold) {
  const float threshold_f = static_cast<float>(threshold);
  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold_f);
}
namespace vllm {

// Element-wise activation kernel template: out[i] = ACT_FN(input[i]).
// The data is treated as rows of d values. Block blockIdx.x processes row
// blockIdx.x; its threads stride over the d columns by blockDim.x.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * d;
  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const int64_t i = row_offset + col;
    out[i] = ACT_FN(VLLM_LDG(&input[i]));
  }
}

}  // namespace vllm
166
167

// Launch element-wise activation kernel.
//
// Expects tensors named `out` and `input` in the enclosing scope. Treats
// `input` as rows of d = input.size(-1) elements and launches one block per
// row with min(d, 1024) threads on the current CUDA stream.
//
// An empty input returns early without launching. If the last dimension is 0,
// input.numel() / d would divide by zero on the host. Zero rows would produce
// a zero-sized grid, which is an invalid launch configuration. Because of the
// early `return`, this macro may only be expanded inside a function returning
// void.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  if (input.numel() == 0) {                                                    \
    return;                                                                    \
  }                                                                            \
  int d = input.size(-1);                                                      \
  int64_t num_tokens = input.numel() / d;                                      \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \
        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
                                     input.data_ptr<scalar_t>(), d);           \
  });
180
181
182

namespace vllm {

// "New" GELU (tanh approximation):
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// where 0.79788456 ~= sqrt(2 / pi).
// The whole expression is evaluated in float and rounded to T exactly once.
// For half/bfloat16 inputs this avoids rounding x^3, the tanh argument and
// the final product to T at each step, and matches how the other activation
// helpers compute. For T == float the arithmetic is unchanged.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float f = (float)x;
  const float x3 = f * f * f;
  const float t = tanhf(0.79788456f * (f + 0.044715f * x3));
  return (T)(0.5f * f * (1.0f + t));
}

// "Fast" GELU (tanh approximation in factored form):
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
// Evaluated entirely in float and rounded to T once, for the same reason as
// gelu_new_kernel. For T == float the arithmetic is unchanged.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const float t = tanhf((f * 0.79788456f) * (1.0f + (0.044715f * f) * f));
  return (T)(0.5f * f * (1.0f + t));
}

// Quick GELU, evaluated in float and rounded to T once.
template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  // x * sigmoid(1.702 * x)
  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

}  // namespace vllm
205

206
207
// Element-wise "new" GELU (tanh approximation): out = gelu_new(input).
//   out:   [..., d]
//   input: [..., d]
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

212
213
// Element-wise "fast" GELU (factored tanh approximation):
// out = gelu_fast(input).
//   out:   [..., d]
//   input: [..., d]
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
217
218
219
220
221

// Element-wise quick GELU, x * sigmoid(1.702 * x): out = gelu_quick(input).
//   out:   [..., d]
//   input: [..., d]
void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
}