// activation_kernels_opt.cu (8.17 KB); committed by raojy.
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>

#include <cmath>
#include <cstdint>

#include "cuda_compat.h"
#include "../dispatch_utils.h"

namespace vllm {

// Gating operation shared by every kernel below: applies ACT_FN to `x` and
// multiplies by `y` when act_first is true, otherwise returns x * ACT_FN(y).
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                            const scalar_t& y) {
  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}

// Scalar activation-and-gating kernel.
// Launch: one block per token (blockIdx.x selects the row); any blockDim.x,
// the block's threads stride over the d output columns of that row.
// Row t of `input` holds x = input[t, 0:d] followed by y = input[t, d:2d];
// out[t, j] = compute(x[j], y[j]).
// NOTE(review): indexing assumes both tensors are contiguous — confirm at
// the call sites.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  const scalar_t* x_row = input + token_idx * 2 * d;
  const scalar_t* y_row = x_row + d;
  scalar_t* out_row = out + token_idx * d;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&x_row[idx]);
    const scalar_t y = VLLM_LDG(&y_row[idx]);
    out_row[idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
  }
}

// Processes VEC consecutive columns [col, col + VEC) of row `token_idx`:
// loads one vector from each half of the input row, applies the gate
// element-wise and stores one vector to `out`.
// Preconditions: col % VEC == 0, d % VEC == 0 and both base pointers
// aligned to VEC * sizeof(scalar_t) bytes (aligned_vector's alignment).
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
          bool act_first>
__device__ __forceinline__ void act_and_mul_vec_opt_step(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input,
    const int64_t token_idx, const int col, const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
  const int64_t in_offset = token_idx * 2 * d + col;
  const int64_t out_offset = token_idx * d + col;
  // Use properly aligned VecType locals instead of reinterpreting plain
  // scalar_t arrays, which only carry alignof(scalar_t).
  const VecType x_vec = *reinterpret_cast<const VecType*>(input + in_offset);
  const VecType y_vec =
      *reinterpret_cast<const VecType*>(input + in_offset + d);
  VecType result;
#pragma unroll
  for (int i = 0; i < VEC; i++) {
    // Honour act_first like the scalar kernel does.
    result.val[i] =
        compute<scalar_t, ACT_FN, act_first>(x_vec.val[i], y_vec.val[i]);
  }
  *reinterpret_cast<VecType*>(out + out_offset) = result;
}

// Vectorized kernel, single pass: one block per token and each thread
// handles exactly one vector of VEC columns starting at threadIdx.x * VEC.
// Requires blockDim.x * VEC >= d (columns beyond that are never written);
// threads whose vector starts at or past d do nothing.
// Also requires the act_and_mul_vec_opt_step preconditions.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
          bool act_first>
__global__ void act_and_mul_kernel_opt1(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  const int col = threadIdx.x * VEC;
  if (col < d) {
    act_and_mul_vec_opt_step<scalar_t, ACT_FN, VEC, act_first>(
        out, input, token_idx, col, d);
  }
}

// Vectorized kernel, looping: one block per token; each thread starts at
// column threadIdx.x * VEC and advances by blockDim.x * VEC, so any
// blockDim.x covers the whole row.
// Requires the act_and_mul_vec_opt_step preconditions.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
          bool act_first>
__global__ void act_and_mul_kernel_opt2(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  const int stride = blockDim.x * VEC;
  for (int col = threadIdx.x * VEC; col < d; col += stride) {
    act_and_mul_vec_opt_step<scalar_t, ACT_FN, VEC, act_first>(
        out, input, token_idx, col, d);
  }
}

// SiLU: x * sigmoid(x) = x / (1 + exp(-x)), evaluated in float.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  const float f = (float)x;
  return (T)(f / (1.0f + expf(-f)));
}

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))), evaluated in float.
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;  // 1 / sqrt(2)
  return (T)(f * 0.5f * (1.0f + ::erff(f * ALPHA)));
}

// Tanh-approximated GELU:
// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))), in float.
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;  // sqrt(2 / pi)
  constexpr float KAPPA = 0.044715f;
  const float x_cube = f * f * f;
  const float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

}  // namespace vllm


// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first, i.e. to the first half of each input row (true) or to the second
// half (false). KERNEL is a device activation template such as
// vllm::silu_kernel.
//
// The expansion reads the torch::Tensor variables `out` ([..., d]) and
// `input` ([..., 2 * d]) from the enclosing scope and may `return` early, so
// it must be expanded inside a function returning void.
//
// Kernel selection for row width d:
//   * d % 8 == 0, d <= 16384 and both data pointers aligned for 8-element
//     vector access: act_and_mul_kernel_opt1 (d <= 4096, one vector per
//     thread) or act_and_mul_kernel_opt2 (threads loop over the row).
//   * otherwise: the scalar act_and_mul_kernel with min(d, 1024) threads.
// NOTE(review): every kernel indexes rows as if both tensors are contiguous;
// nothing here checks that.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                       \
  /* Empty input: nothing to do; also avoids dividing by a zero last dim. */   \
  if (input.numel() == 0) {                                                    \
    return;                                                                    \
  }                                                                            \
  int d = input.size(-1) / 2;                                                  \
  int64_t num_tokens = input.numel() / input.size(-1);                         \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "act_and_mul_kernel", [&] {                         \
        scalar_t* out_ptr = out.data_ptr<scalar_t>();                          \
        const scalar_t* in_ptr = input.data_ptr<scalar_t>();                   \
        /* aligned_vector<scalar_t, VEC> needs VEC * sizeof(scalar_t)-byte */  \
        /* alignment; checking VEC = 8 covers every vectorized branch.    */   \
        constexpr std::uintptr_t kVecBytes = 8 * sizeof(scalar_t);             \
        const bool vec_aligned =                                               \
            reinterpret_cast<std::uintptr_t>(in_ptr) % kVecBytes == 0 &&       \
            reinterpret_cast<std::uintptr_t>(out_ptr) % kVecBytes == 0;        \
        if (vec_aligned && d % 8 == 0 && d <= 16384) {                         \
          /* opt1 needs threads * VEC >= d: each thread owns one vector. */    \
          if (d <= 512) {                                                      \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 2,       \
                                          ACT_FIRST>                           \
                <<<grid, 256, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else if (d <= 1024) {                                              \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8,       \
                                          ACT_FIRST>                           \
                <<<grid, 128, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else if (d <= 2048) {                                              \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8,       \
                                          ACT_FIRST>                           \
                <<<grid, 256, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else if (d <= 4096) {                                              \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8,       \
                                          ACT_FIRST>                           \
                <<<grid, 512, 0, stream>>>(out_ptr, in_ptr, d);                \
          } else {                                                             \
            vllm::act_and_mul_kernel_opt2<scalar_t, KERNEL<scalar_t>, 8,       \
                                          ACT_FIRST>                           \
                <<<grid, 1024, 0, stream>>>(out_ptr, in_ptr, d);               \
          }                                                                    \
        } else {                                                               \
          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>      \
              <<<grid, block, 0, stream>>>(out_ptr, in_ptr, d);                \
        }                                                                      \
      });

void silu_and_mul_opt(torch::Tensor& out,      // [..., d]
                      torch::Tensor& input) {  // [..., 2 * d]
  // Gated SiLU: out = silu(input[..., :d]) * input[..., d:].
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, /*ACT_FIRST=*/true);
}

// void mul_and_silu_opt(torch::Tensor& out,    // [..., d]
//                   torch::Tensor& input)  // [..., 2 * d]
// {
//   // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
//   // applies the silu to the latter half of the input.
//   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
// }

void gelu_and_mul_opt(torch::Tensor& out,      // [..., d]
                      torch::Tensor& input) {  // [..., 2 * d]
  // Gated exact (erf-based) GELU: out = gelu(input[..., :d]) * input[..., d:].
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, /*ACT_FIRST=*/true);
}

void gelu_tanh_and_mul_opt(torch::Tensor& out,      // [..., d]
                           torch::Tensor& input) {  // [..., 2 * d]
  // Gated tanh-approximated GELU:
  // out = gelu_tanh(input[..., :d]) * input[..., d:].
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, /*ACT_FIRST=*/true);
}