// activation_kernels.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

// SiLU ("swish"): x * sigmoid(x), i.e. x / (1 + exp(-x)).
// The negation happens in T; the exp and the division are done in float,
// and the result is converted back to T.
template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  const float x_f = static_cast<float>(x);
  const float neg_x_f = static_cast<float>(-x);
  return static_cast<T>(x_f / (1.0f + expf(neg_x_f)));
}

// Fused SiLU-and-multiply.
//
// Launch layout: one block per token (gridDim.x == number of tokens); the
// threads of a block walk the d output columns in steps of blockDim.x.
// Each input row holds 2 * d contiguous elements: the first d are the
// SiLU operand and the next d are the multiplier, so
//   out[t, i] = silu(input[t, i]) * input[t, d + i].
// No shared memory is used.
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., 2, d]
  const int d) {
  const int64_t token = blockIdx.x;
  const scalar_t* act_row = input + token * 2 * d;  // first half of the row
  const scalar_t* mul_row = act_row + d;            // second half of the row
  scalar_t* out_row = out + token * d;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    const scalar_t a = VLLM_LDG(&act_row[col]);
    const scalar_t m = VLLM_LDG(&mul_row[col]);
    out_row[col] = silu(a) * m;
  }
}

} // namespace vllm

// Host launcher for vllm::silu_and_mul_kernel:
//   out[..., i] = silu(input[..., i]) * input[..., d + i],  0 <= i < d.
//
//   out:   [..., d]      written by the kernel
//   input: [..., 2 * d]  last dimension must be even
//
// Launches one block per token (all leading dimensions flattened) with
// min(d, 1024) threads on the current stream of the input's device.
// Empty inputs are a no-op.
// NOTE(review): the kernel indexes both tensors as contiguous rows; dtype
// agreement and contiguity are not checked here.
void silu_and_mul(
  torch::Tensor& out,      // [..., d]
  torch::Tensor& input)    // [..., 2 * d]
{
  // Nothing to compute. Returning here also avoids dividing by a zero last
  // dimension below and launching with a zero-sized grid/block, which CUDA
  // rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }
  TORCH_CHECK(input.size(-1) % 2 == 0,
              "silu_and_mul: last dimension of input must be even, got ",
              input.size(-1));

  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  // The kernel writes num_tokens * d elements; a smaller `out` would be
  // written out of bounds.
  TORCH_CHECK(out.numel() == num_tokens * d,
              "silu_and_mul: out must have ", num_tokens * d,
              " elements, got ", out.numel());

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  // Make the input's device current so the stream below and the launch
  // target the device that owns the data.
  const c10::cuda::OptionalCUDAGuard device_guard(at::device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}

namespace vllm {

// Element-wise activation kernel template.
//
// Applies ACT_FN to every element of a buffer made of contiguous rows of
// length d. Launch layout: one block per row (row index == blockIdx.x); the
// threads of a block walk the row's columns in steps of blockDim.x.
// No shared memory is used.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,               // [..., d]
  const scalar_t* __restrict__ input,       // [..., d]
  const int d) {
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * d;
  const scalar_t* src = input + row_start;
  scalar_t* dst = out + row_start;

  for (int64_t col = threadIdx.x; col < d; col += blockDim.x) {
    dst[col] = ACT_FN(VLLM_LDG(&src[col]));
  }
}

} // namespace vllm

// Launch element-wise activation kernel.
//
// Must be expanded inside a host function that has `torch::Tensor& out` and
// `torch::Tensor& input` in scope. It launches vllm::activation_kernel with
// KERNEL<scalar_t> as the per-element function: one block per token (all
// leading dimensions flattened), min(d, 1024) threads per block, on the
// current stream of the input's device.
// Behaviour added on top of the plain launch:
//   - returns from the enclosing function when `input` is empty, which
//     avoids a divide-by-zero on d and a zero-sized (invalid) launch;
//   - rejects an `out` whose element count differs from `input`, which
//     would otherwise be read/written out of bounds.
// NOTE(review): the kernel indexes both tensors as contiguous rows of
// length d; contiguity is not checked here.
// (No // comments inside the body: a line comment would swallow the
// trailing backslash and cut the macro short.)
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
  if (input.numel() == 0) {                                                               \
    return;                                                                               \
  }                                                                                       \
  TORCH_CHECK(out.numel() == input.numel(),                                               \
              "activation: out and input must have the same number of elements");         \
  int d = input.size(-1);                                                                 \
  int64_t num_tokens = input.numel() / d;                                                 \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const c10::cuda::OptionalCUDAGuard device_guard(at::device_of(input));                  \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
    input.scalar_type(),                                                                  \
    "activation_kernel",                                                                  \
    [&] {                                                                                 \
      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
        out.data_ptr<scalar_t>(),                                                         \
        input.data_ptr<scalar_t>(),                                                       \
        d);                                                                               \
    });

namespace vllm {

// Tanh approximation of GELU:
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// where 0.79788456 ~= sqrt(2 / pi).
// The (T) casts are deliberate. For a reduced-precision T, the cube, the
// inner sum, and the tanh argument are each rounded to T at the marked
// points, while tanhf itself runs in float. Reordering these casts changes
// the rounding of the result.
// NOTE(review): presumably mirrors a reference "gelu_new" computed in the
// tensor dtype -- confirm before changing the cast order.
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  // The cube is evaluated in T, then widened to float.
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

// The same tanh-GELU approximation, factored as
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
// so that x^3 is never formed explicitly. As above, intermediates are
// rounded to T wherever a (T) cast appears, and tanhf runs in float.
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  // (T)(0.044715f * f) * x is the 0.044715 * x^2 term, multiplied in T.
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace vllm

// out = gelu_new(input), applied element-wise via vllm::gelu_new_kernel.
//   out:   [..., d]
//   input: [..., d]
// All launch configuration comes from LAUNCH_ACTIVATION_KERNEL.
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

// out = gelu_fast(input), applied element-wise via vllm::gelu_fast_kernel.
//   out:   [..., d]
//   input: [..., d]
// All launch configuration comes from LAUNCH_ACTIVATION_KERNEL.
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}