activation_kernels.cu 6.43 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
#include <ATen/cuda/CUDAContext.h>
2
3
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
6
#include <cmath>

7
#include "cuda_compat.h"
8
9
#include "dispatch_utils.h"

Woosuk Kwon's avatar
Woosuk Kwon committed
10
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
13
14
// Activation and gating kernel template.
//
// For every token t and every i in [0, d) computes
//   out[t, i] = ACT_FN(input[t, 0, i]) * input[t, 1, i]
// where `input` is viewed as [num_tokens, 2, d] and `out` as [num_tokens, d].
//
// Launch layout: blockIdx.x selects the token; the threads of that block
// step over i starting at threadIdx.x with a stride of blockDim.x, so any
// block size covers any d.
//
// Template parameters:
//   scalar_t - element type of `out` and `input`.
//   ACT_FN   - activation applied to the first half of each row.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., 2, d]
  const int d) {
  // 64-bit token index so the row offsets below are computed in int64_t.
  const int64_t token_idx = blockIdx.x;
  // Row base pointers do not depend on the loop variable; form them once.
  const scalar_t* x_row = input + token_idx * 2 * d;
  const scalar_t* y_row = x_row + d;
  scalar_t* out_row = out + token_idx * d;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&x_row[idx]);
    const scalar_t y = VLLM_LDG(&y_row[idx]);
    out_row[idx] = ACT_FN(x) * y;
  }
}

26
27
28
29
30
31
32
33
34
35
// SiLU activation: x * sigmoid(x), written as x / (1 + exp(-x)).
// The input is widened to float for the arithmetic and the result is
// converted back to T.
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float xf = static_cast<float>(x);
  const float denom = 1.0f + expf(-xf);
  return static_cast<T>(xf / denom);
}

// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
// The input is widened to float for the arithmetic and the result is
// converted back to T.
template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  const float f = (float) x;
  constexpr float ALPHA = M_SQRT1_2;  // 1 / sqrt(2)
  // erff is the single-precision entry point, so the evaluation stays in
  // float regardless of which erf overloads happen to be visible.
  return (T) (f * 0.5f * (1.0f + ::erff(f * ALPHA)));
}

42
43
44
45
46
47
48
49
50
51
52
53
54
// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
// The input is widened to float for the arithmetic and the result is
// converted back to T.
template<typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  constexpr float kSqrt2OverPi = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float kCubicCoeff = 0.044715;
  const float v = static_cast<float>(x);
  const float v_cubed = v * v * v;
  const float tanh_arg = kSqrt2OverPi * (v + kCubicCoeff * v_cubed);
  return static_cast<T>(0.5f * v * (1.0f + ::tanhf(tanh_arg)));
}

Woosuk Kwon's avatar
Woosuk Kwon committed
55
} // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Launch activation and gating kernel.
//
// Expands to statements that launch vllm::act_and_mul_kernel with
// KERNEL<scalar_t> as the activation. Requirements on the expansion site:
//   - `input` and `out` (torch::Tensor) must be in scope; `input`'s last
//     dimension is 2 * d.
//   - The enclosing function must return void: when `input` has no elements
//     or its last dimension is smaller than 2 the macro returns early.
//     Without this guard the division below would divide by zero or the
//     launch would use a zero-sized grid/block, which is an invalid launch
//     configuration.
// The macro declares `d`, `num_tokens`, `grid`, `block`, `device_guard` and
// `stream` in the enclosing scope. One block is launched per token with
// min(d, 1024) threads, on the current CUDA stream of `input`'s device.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                             \
  if (input.numel() == 0 || input.size(-1) < 2) {                                         \
    return;                                                                               \
  }                                                                                       \
  int d = input.size(-1) / 2;                                                             \
  int64_t num_tokens = input.numel() / input.size(-1);                                    \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
    input.scalar_type(),                                                                  \
    "act_and_mul_kernel",                                                                 \
    [&] {                                                                                 \
      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(   \
        out.data_ptr<scalar_t>(),                                                         \
        input.data_ptr<scalar_t>(),                                                       \
        d);                                                                               \
    });

Woosuk Kwon's avatar
Woosuk Kwon committed
75
// SiLU-gated multiply.
//
// With d = input.size(-1) / 2, writes
//   out[..., i] = silu(input[..., i]) * input[..., d + i]   for i in [0, d)
// by launching vllm::act_and_mul_kernel with vllm::silu_kernel through
// LAUNCH_ACTIVATION_GATE_KERNEL (current CUDA stream of input's device).
void silu_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
Woosuk Kwon's avatar
Woosuk Kwon committed
81

82
83
84
85
86
// GELU-gated multiply (erf-based GELU).
//
// With d = input.size(-1) / 2, writes
//   out[..., i] = gelu(input[..., i]) * input[..., d + i]   for i in [0, d)
// by launching vllm::act_and_mul_kernel with vllm::gelu_kernel through
// LAUNCH_ACTIVATION_GATE_KERNEL (current CUDA stream of input's device).
void gelu_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
88

89
90
91
92
93
94
95
// GELU-gated multiply using the tanh approximation of GELU.
//
// With d = input.size(-1) / 2, writes
//   out[..., i] = gelu_tanh(input[..., i]) * input[..., d + i]   for i in [0, d)
// by launching vllm::act_and_mul_kernel with vllm::gelu_tanh_kernel through
// LAUNCH_ACTIVATION_GATE_KERNEL (current CUDA stream of input's device).
void gelu_tanh_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}

96
97
98
99
100
namespace vllm {

// Element-wise activation kernel template.
//
// For every token t and every i in [0, d) computes
//   out[t, i] = ACT_FN(input[t, i])
// with `input` and `out` both viewed as [num_tokens, d].
//
// Launch layout: blockIdx.x selects the token; the threads of that block
// step over i starting at threadIdx.x with a stride of blockDim.x, so any
// block size covers any d.
//
// Template parameters:
//   scalar_t - element type of `out` and `input`.
//   ACT_FN   - activation applied to each element.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., d]
  const int d) {
  // 64-bit token index so the row offsets below are computed in int64_t.
  const int64_t token_idx = blockIdx.x;
  // Row base pointers do not depend on the loop variable; form them once.
  const scalar_t* in_row = input + token_idx * d;
  scalar_t* out_row = out + token_idx * d;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&in_row[idx]);
    out_row[idx] = ACT_FN(x);
  }
}

} // namespace vllm

// Launch element-wise activation kernel.
//
// Expands to statements that launch vllm::activation_kernel with
// KERNEL<scalar_t> as the activation. Requirements on the expansion site:
//   - `input` and `out` (torch::Tensor) must be in scope.
//   - The enclosing function must return void: when `input` has no elements
//     the macro returns early. Without this guard `numel() / d` would divide
//     by zero for a zero-sized last dimension, or the launch would use a
//     zero-sized grid, which is an invalid launch configuration.
// The macro declares `d`, `num_tokens`, `grid`, `block`, `device_guard` and
// `stream` in the enclosing scope. One block is launched per token with
// min(d, 1024) threads, on the current CUDA stream of `input`'s device.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
  if (input.numel() == 0) {                                                               \
    return;                                                                               \
  }                                                                                       \
  int d = input.size(-1);                                                                 \
  int64_t num_tokens = input.numel() / d;                                                 \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
    input.scalar_type(),                                                                  \
    "activation_kernel",                                                                  \
    [&] {                                                                                 \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
        out.data_ptr<scalar_t>(),                                                         \
        input.data_ptr<scalar_t>(),                                                       \
        d);                                                                               \
    });

namespace vllm {

// Tanh-based GELU variant:
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// The cube, the inner sum and the final product are evaluated in T; only the
// two constant multiplications and tanhf are carried out in float.
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float cube = (float) (x * x * x);
  const T cubic_term = (T) (0.044715f * cube);
  const T tanh_arg = (T) (0.79788456f * (float) (x + cubic_term));
  const T tanh_val = (T) tanhf(tanh_arg);
  const T half_x = ((T) 0.5) * x;
  return half_x * (((T) 1.0) + tanh_val);
}

// Tanh-based GELU variant in factored form:
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
// The two constant multiplications are done in float and rounded to T; the
// remaining products and sums are evaluated in T.
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float xf = (float) x;
  const T scaled_x = (T) (xf * 0.79788456f);
  const T quad_coeff = (T) (0.044715f * xf);
  const T poly = ((T) 1.0) + quad_coeff * x;
  const T tanh_val = (T) tanhf(scaled_x * poly);
  const T half_x = ((T) 0.5) * x;
  return half_x * (((T) 1.0) + tanh_val);
}

} // namespace vllm

// Element-wise GELU (vllm::gelu_new_kernel variant).
//
// Writes out[..., i] = gelu_new(input[..., i]) for every element by launching
// vllm::activation_kernel with vllm::gelu_new_kernel through
// LAUNCH_ACTIVATION_KERNEL (current CUDA stream of input's device).
void gelu_new(
  torch::Tensor& out,     // [..., d]
  torch::Tensor& input)   // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

// Element-wise GELU (vllm::gelu_fast_kernel variant).
//
// Writes out[..., i] = gelu_fast(input[..., i]) for every element by launching
// vllm::activation_kernel with vllm::gelu_fast_kernel through
// LAUNCH_ACTIVATION_KERNEL (current CUDA stream of input's device).
void gelu_fast(
  torch::Tensor& out,     // [..., d]
  torch::Tensor& input)   // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}