activation_kernels.cu 8.88 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
#include <ATen/cuda/CUDAContext.h>
2
#include <torch/all.h>
3
#include <c10/cuda/CUDAGuard.h>
bianch's avatar
bianch committed
4
#include <ATen/native/cuda/MemoryAccess.cuh>
Woosuk Kwon's avatar
Woosuk Kwon committed
5

6
7
#include <cmath>

8
#include "cuda_compat.h"
9
10
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
11
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
// Activation and gating kernel template.
14
// Scalar gated activation: out[t, i] = ACT_FN(input[t, 0, i]) * input[t, 1, i].
//
// Launch layout: blockIdx.x selects the token (row); the threads of the block
// walk the d columns of that row in steps of blockDim.x, so any block size
// covers the whole row. No shared memory is used.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t row = blockIdx.x;
  // First half of the input row is activated, second half is the multiplier.
  const scalar_t* act_row = input + row * 2 * d;
  const scalar_t* mul_row = act_row + d;
  scalar_t* out_row = out + row * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t a = VLLM_LDG(&act_row[col]);
    const scalar_t m = VLLM_LDG(&mul_row[col]);
    out_row[col] = ACT_FN(a) * m;
  }
}

bianch's avatar
bianch committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Vectorized, single-pass gated activation:
//   out[t, i] = ACT_FN(input[t, 0, i]) * input[t, 1, i].
//
// Launch layout: blockIdx.x selects the token (row). Each thread handles at
// most ONE group of VEC consecutive elements starting at threadIdx.x * VEC,
// so the launch must satisfy blockDim.x * VEC >= d for the whole row to be
// processed. No shared memory is used.
//
// Vector loads/stores are only issued when d is a multiple of VEC and both
// base pointers are aligned to the vector size; otherwise the thread falls
// back to bounds-checked scalar accesses for its VEC elements, so the kernel
// is correct for any d and any element-aligned pointers.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
  // 64-bit indexing: token_idx * 2 * d does not fit in int for large batches.
  const int64_t token_idx = blockIdx.x;
  const int64_t idx = static_cast<int64_t>(threadIdx.x) * VEC;
  if (idx >= d) {
    return;
  }

  const scalar_t* x1_ptr = input + token_idx * 2 * d + idx;
  const scalar_t* x2_ptr = x1_ptr + d;
  scalar_t* y_ptr = out + token_idx * d + idx;

  // With d % VEC == 0 every row start (and idx) is a multiple of VEC elements,
  // so aligned base pointers imply aligned vector accesses that stay in-row.
  // The condition is uniform across the grid, so it causes no divergence.
  const bool can_vectorize =
      (d % VEC == 0) &&
      (reinterpret_cast<size_t>(input) % sizeof(VecType) == 0) &&
      (reinterpret_cast<size_t>(out) % sizeof(VecType) == 0);

  if (can_vectorize) {
    // Keep the staged values in properly aligned vector objects and read the
    // lanes directly; they are thread-local, not global memory.
    const VecType v_x1 = *reinterpret_cast<const VecType*>(x1_ptr);
    const VecType v_x2 = *reinterpret_cast<const VecType*>(x2_ptr);
    VecType v_y;
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      v_y.val[i] = ACT_FN(v_x1.val[i]) * v_x2.val[i];
    }
    *reinterpret_cast<VecType*>(y_ptr) = v_y;
  } else {
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      // Guard the tail so a partial group never crosses the row boundary.
      if (idx + i < d) {
        const scalar_t t_x1 = VLLM_LDG(&x1_ptr[i]);
        const scalar_t t_x2 = VLLM_LDG(&x2_ptr[i]);
        y_ptr[i] = ACT_FN(t_x1) * t_x2;
      }
    }
  }
}

// Vectorized, looping gated activation:
//   out[t, i] = ACT_FN(input[t, 0, i]) * input[t, 1, i].
//
// Launch layout: blockIdx.x selects the token (row). Each thread processes
// groups of VEC consecutive elements, starting at threadIdx.x * VEC and
// advancing by blockDim.x * VEC, so any block size covers the whole row.
// No shared memory is used.
//
// Vector loads/stores are only issued when d is a multiple of VEC and both
// base pointers are aligned to the vector size; otherwise each group is
// handled with bounds-checked scalar accesses, so the kernel is correct for
// any d and any element-aligned pointers.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
  // 64-bit indexing: token_idx * 2 * d does not fit in int for large batches.
  const int64_t token_idx = blockIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * VEC;

  const scalar_t* x1_row = input + token_idx * 2 * d;
  const scalar_t* x2_row = x1_row + d;
  scalar_t* y_row = out + token_idx * d;

  // With d % VEC == 0 every row start (and idx) is a multiple of VEC elements,
  // so aligned base pointers imply aligned vector accesses that stay in-row.
  // The condition is uniform across the grid, so it causes no divergence.
  const bool can_vectorize =
      (d % VEC == 0) &&
      (reinterpret_cast<size_t>(input) % sizeof(VecType) == 0) &&
      (reinterpret_cast<size_t>(out) % sizeof(VecType) == 0);

  for (int64_t idx = static_cast<int64_t>(threadIdx.x) * VEC; idx < d;
       idx += stride) {
    if (can_vectorize) {
      // Keep the staged values in properly aligned vector objects and read
      // the lanes directly; they are thread-local, not global memory.
      const VecType v_x1 = *reinterpret_cast<const VecType*>(x1_row + idx);
      const VecType v_x2 = *reinterpret_cast<const VecType*>(x2_row + idx);
      VecType v_y;
#pragma unroll
      for (int i = 0; i < VEC; i++) {
        v_y.val[i] = ACT_FN(v_x1.val[i]) * v_x2.val[i];
      }
      *reinterpret_cast<VecType*>(y_row + idx) = v_y;
    } else {
#pragma unroll
      for (int i = 0; i < VEC; i++) {
        // Guard the tail so a partial group never crosses the row boundary.
        if (idx + i < d) {
          const scalar_t t_x1 = VLLM_LDG(&x1_row[idx + i]);
          const scalar_t t_x2 = VLLM_LDG(&x2_row[idx + i]);
          y_row[idx + i] = ACT_FN(t_x1) * t_x2;
        }
      }
    }
  }
}

85
// SiLU activation, x * sigmoid(x), evaluated as x / (1 + exp(-x)).
// The negation happens in T; the exponential and division are done in float
// and the quotient is converted back to T.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float numerator = static_cast<float>(x);
  const float denominator = 1.0f + expf(static_cast<float>(-x));
  return static_cast<T>(numerator / denominator);
}

91
// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))), computed in float
// and converted back to T.
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float v = static_cast<float>(x);
  const float one_plus_erf = 1.0f + ::erf(v * kInvSqrt2);
  return static_cast<T>(v * 0.5f * one_plus_erf);
}

101
// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// computed in float and converted back to T.
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // sqrt(2) * (2 / sqrt(pi)) * 0.5 == sqrt(2 / pi)
  constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float kKappa = 0.044715;
  const float v = static_cast<float>(x);
  const float v_cubed = v * v * v;
  const float tanh_arg = kBeta * (v + kKappa * v_cubed);
  return static_cast<T>(0.5f * v * (1.0f + ::tanhf(tanh_arg)));
}

114
}  // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
115

bianch's avatar
bianch committed
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Launch a gated-activation kernel computing
//   out[..., i] = KERNEL(input[..., i]) * input[..., d + i],  d = size(-1) / 2.
//
// Expects `out` and `input` (torch::Tensor) to be in scope and must be
// expanded inside a function returning void: an empty `input` returns early,
// which avoids both dividing by a zero-sized last dimension and launching
// with an empty grid (an invalid launch configuration).
//
// One block is launched per token. The vectorized kernels are chosen by d,
// but only when d is a multiple of the vector width and both data pointers
// are aligned to the vector size; any other shape/alignment is routed to the
// scalar act_and_mul_kernel so no misaligned or out-of-row access is issued.
// Kernels run on the current CUDA stream of input's device.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                  \
  if (input.numel() == 0) {                                                    \
    return;                                                                    \
  }                                                                            \
  int d = input.size(-1) / 2;                                                  \
  int64_t num_tokens = input.numel() / input.size(-1);                         \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::max(std::min(d, 1024), 1));                                  \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "act_and_mul_kernel", [&] {                         \
        scalar_t* out_ptr = out.data_ptr<scalar_t>();                          \
        const scalar_t* in_ptr = input.data_ptr<scalar_t>();                   \
        /* Vector width used by the branch that d selects below. */            \
        const int vec = (d <= 512) ? 2 : 8;                                    \
        const size_t vec_bytes = vec * sizeof(scalar_t);                       \
        const bool vec_ok =                                                    \
            (d % vec == 0) &&                                                  \
            (reinterpret_cast<size_t>(in_ptr) % vec_bytes == 0) &&             \
            (reinterpret_cast<size_t>(out_ptr) % vec_bytes == 0);              \
        if (!vec_ok) {                                                         \
          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>                 \
              <<<grid, block, 0, stream>>>(out_ptr, in_ptr, d);                \
        } else if (d <= 512) {                                                 \
          vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2>   \
              <<<grid, 256, 0, stream>>>(out_ptr, in_ptr, d);                  \
        } else if (d <= 1024) {                                                \
          vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8>   \
              <<<grid, 128, 0, stream>>>(out_ptr, in_ptr, d);                  \
        } else if (d <= 2048) {                                                \
          vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8>   \
              <<<grid, 256, 0, stream>>>(out_ptr, in_ptr, d);                  \
        } else if (d <= 4096) {                                                \
          vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8>   \
              <<<grid, 512, 0, stream>>>(out_ptr, in_ptr, d);                  \
        } else {                                                               \
          vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8>   \
              <<<grid, 1024, 0, stream>>>(out_ptr, in_ptr, d);                 \
        }                                                                      \
      });

// Gated SiLU: writes silu_kernel(first half of input's last dim) multiplied
// element-wise by the second half into `out`.
//   out:   [..., d]
//   input: [..., 2 * d]
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
Woosuk Kwon's avatar
Woosuk Kwon committed
153

154
155
// Gated GELU (erf form): writes gelu_kernel(first half of input's last dim)
// multiplied element-wise by the second half into `out`.
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
159

160
161
// Gated GELU (tanh approximation): writes gelu_tanh_kernel(first half of
// input's last dim) multiplied element-wise by the second half into `out`.
//   out:   [..., d]
//   input: [..., 2 * d]
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}

166
167
168
namespace vllm {

// Element-wise activation kernel template.
169
// Applies ACT_FN to every element: out[t, i] = ACT_FN(input[t, i]).
//
// Launch layout: blockIdx.x selects the token (row); the threads of the block
// walk the d columns of that row in steps of blockDim.x, so any block size
// covers the whole row. No shared memory is used.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  const int64_t row = blockIdx.x;
  const scalar_t* in_row = input + row * d;
  scalar_t* out_row = out + row * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t v = VLLM_LDG(&in_row[col]);
    out_row[col] = ACT_FN(v);
  }
}

181
}  // namespace vllm
182
183

// Launch element-wise activation kernel.
184
185
186
187
188
189
190
191
192
193
194
195
// Launch vllm::activation_kernel with KERNEL applied to every element of
// `input`, writing to `out`. One block is launched per token (row of the last
// dimension) with min(d, 1024) threads, on the current CUDA stream of input's
// device.
//
// Expects `out` and `input` (torch::Tensor) to be in scope and must be
// expanded inside a function returning void: an empty `input` returns early,
// which avoids dividing by a zero-sized last dimension and launching with an
// empty grid or block (an invalid launch configuration).
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  if (input.numel() == 0) {                                                    \
    return;                                                                    \
  }                                                                            \
  int d = input.size(-1);                                                      \
  int64_t num_tokens = input.numel() / d;                                      \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \
        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
                                     input.data_ptr<scalar_t>(), d);           \
  });
196
197
198

namespace vllm {

199
// "New" GELU (tanh form):
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3))),
// where 0.79788456 ~= sqrt(2 / pi).
// The cube, the inner sum and the final product are evaluated in T; only the
// two constant scalings and tanhf run in float.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const T one = (T)1.0;
  const T half = (T)0.5;
  const float cube = (float)(x * x * x);
  const T cube_term = (T)(0.044715f * cube);
  const T inner_sum = x + cube_term;
  const T tanh_arg = (T)(0.79788456f * (float)inner_sum);
  const T t = (T)tanhf(tanh_arg);
  return half * x * (one + t);
}

206
// "Fast" GELU (tanh form):
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x))),
// where 0.79788456 ~= sqrt(2 / pi).
// The two constant scalings are computed in float and rounded to T; the
// remaining products and sums are evaluated in T.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const T one = (T)1.0;
  const T half = (T)0.5;
  const float f = (float)x;
  const T scaled_x = (T)(f * 0.79788456f);
  const T kappa_x = (T)(0.044715f * f);
  const T poly = one + kappa_x * x;
  const T t = (T)tanhf(scaled_x * poly);
  return half * x * (one + t);
}

214
}  // namespace vllm
215

216
217
// Element-wise "new" GELU: applies vllm::gelu_new_kernel to every element of
// `input` and stores the result in `out`.
//   out:   [..., d]
//   input: [..., d]
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

222
223
// Element-wise "fast" GELU: applies vllm::gelu_fast_kernel to every element
// of `input` and stores the result in `out`.
//   out:   [..., d]
//   input: [..., d]
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}