#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

#include "cuda_compat.h"
#include "../dispatch_utils.h"

namespace vllm {

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
                                            const scalar_t& y) {
  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
          bool act_first>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
  }
}

// Vectorized variant: each thread handles one aligned vector of VEC elements.
// Valid when blockDim.x * VEC >= d, so one pass covers the hidden dimension.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
          bool act_first>
__global__ void act_and_mul_kernel_opt1(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
  const int64_t token_idx = blockIdx.x;
  const int idx = threadIdx.x * VEC;
  if (idx < d) {
    const int64_t x_index = token_idx * 2 * d + idx;
    const int64_t y_index = token_idx * d + idx;
    const VecType* x1 = (const VecType*)(input + x_index);
    const VecType* x2 = (const VecType*)(input + x_index + d);
    VecType* y = (VecType*)(out + y_index);
    scalar_t r_x1[VEC];
    scalar_t r_x2[VEC];
    scalar_t r_y[VEC];
    *(VecType*)r_x1 = *x1;
    *(VecType*)r_x2 = *x2;
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      r_y[i] = compute<scalar_t, ACT_FN, act_first>(r_x1[i], r_x2[i]);
    }
    *y = *(VecType*)r_y;
  }
}

// Vectorized variant with a strided loop over the hidden dimension, used when
// d is too large for a single vectorized pass per block.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC,
          bool act_first>
__global__ void act_and_mul_kernel_opt2(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
  const int64_t token_idx = blockIdx.x;
  for (int idx = threadIdx.x * VEC; idx < d; idx += blockDim.x * VEC) {
    const int64_t x_index = token_idx * 2 * d + idx;
    const int64_t y_index = token_idx * d + idx;
    const VecType* x1 = (const VecType*)(input + x_index);
    const VecType* x2 = (const VecType*)(input + x_index + d);
    VecType* y = (VecType*)(out + y_index);
    scalar_t r_x1[VEC];
    scalar_t r_x2[VEC];
    scalar_t r_y[VEC];
    *(VecType*)r_x1 = *x1;
    *(VecType*)r_x2 = *x2;
#pragma unroll
    for (int i = 0; i < VEC; i++) {
      r_y[i] = compute<scalar_t, ACT_FN, act_first>(r_x1[i], r_x2[i]);
    }
    *y = *(VecType*)r_y;
  }
}

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;
  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

}  // namespace vllm
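// Host-side reference for the gated activation computed by the kernels above.
// This helper is an illustrative sketch added for clarity (the name
// `silu_and_mul_reference` is not part of the original entry points): with
// act_first == true the result is ACT(x1) * x2, where the input is split into
// [x1, x2] along its last dimension.
inline torch::Tensor silu_and_mul_reference(const torch::Tensor& input) {
  const auto chunks = input.chunk(2, /*dim=*/-1);
  return at::silu(chunks[0]) * chunks[1];
}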
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                      \
  int d = input.size(-1) / 2;                                                 \
  int64_t num_tokens = input.numel() / input.size(-1);                        \
  dim3 grid(num_tokens);                                                      \
  dim3 block(std::min(d, 1024));                                              \
  if (num_tokens == 0) {                                                      \
    return;                                                                   \
  }                                                                           \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();               \
  VLLM_DISPATCH_FLOATING_TYPES(                                               \
      input.scalar_type(), "act_and_mul_kernel", [&] {                        \
        if (0 == d % 8 && d <= 16384) {                                       \
          /* Hidden dim is 8-aligned and small enough for vectorized I/O. */  \
          if (d <= 512) {                                                     \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 2,      \
                                          ACT_FIRST>                          \
                <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                             input.data_ptr<scalar_t>(), d);  \
          } else if (d <= 1024) {                                             \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8,      \
                                          ACT_FIRST>                          \
                <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                             input.data_ptr<scalar_t>(), d);  \
          } else if (d <= 2048) {                                             \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8,      \
                                          ACT_FIRST>                          \
                <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                             input.data_ptr<scalar_t>(), d);  \
          } else if (d <= 4096) {                                             \
            vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8,      \
                                          ACT_FIRST>                          \
                <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                             input.data_ptr<scalar_t>(), d);  \
          } else {                                                            \
            vllm::act_and_mul_kernel_opt2<scalar_t, KERNEL<scalar_t>, 8,      \
                                          ACT_FIRST>                          \
                <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                             input.data_ptr<scalar_t>(), d);  \
          }                                                                   \
        } else {                                                              \
          /* Fallback: scalar kernel for unaligned or very large d. */        \
          vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>     \
              <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),          \
                                           input.data_ptr<scalar_t>(), d);    \
        }                                                                     \
      });

void silu_and_mul_opt(torch::Tensor& out,    // [..., d]
                      torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
}

// void mul_and_silu_opt(torch::Tensor& out,    // [..., d]
//                       torch::Tensor& input)  // [..., 2 * d]
// {
//   // The difference between mul_and_silu and silu_and_mul is that
//   // mul_and_silu applies the silu to the latter half of the input.
//   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
// }

void gelu_and_mul_opt(torch::Tensor& out,    // [..., d]
                      torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
}

void gelu_tanh_and_mul_opt(torch::Tensor& out,    // [..., d]
                           torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
}
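// ---------------------------------------------------------------------------
// Illustrative usage sketch. This wrapper is an assumption added for
// documentation (it is not one of the original entry points); it shows the
// expected shapes: `input` is [..., 2 * d] and the result is [..., d].
// ---------------------------------------------------------------------------
inline torch::Tensor silu_and_mul_opt_example(torch::Tensor& input) {
  auto out_sizes = input.sizes().vec();
  out_sizes.back() /= 2;  // keep leading dims, halve the last dim
  torch::Tensor out = torch::empty(out_sizes, input.options());
  silu_and_mul_opt(out, input);
  return out;
}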