Commit b8c88ed3 authored by bianch's avatar bianch
Browse files

feat:optimize act_and_mul_kernel

parent dd0c8f49
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/all.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath> #include <cmath>
...@@ -23,6 +24,64 @@ __global__ void act_and_mul_kernel( ...@@ -23,6 +24,64 @@ __global__ void act_and_mul_kernel(
} }
} }
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int x_index = token_idx * 2 * d + idx;
const int y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
const scalar_t t_x1 = VLLM_LDG(&r_x1[i]);
const scalar_t t_x2 = VLLM_LDG(&r_x2[i]);
r_y[i] = ACT_FN(t_x1) * t_x2;
}
*y = *(VecType*)r_y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int x_index = token_idx * 2 * d + idx;
const int y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
const scalar_t t_x1 = VLLM_LDG(&r_x1[i]);
const scalar_t t_x2 = VLLM_LDG(&r_x2[i]);
r_y[i] = ACT_FN(t_x1) * t_x2;
}
*y = *(VecType*)r_y;
}
}
template <typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
...@@ -54,19 +113,36 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -54,19 +113,36 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ int d = input.size(-1) / 2; \
int d = input.size(-1) / 2; \ int64_t num_tokens = input.numel() / input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \
dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \
dim3 block(std::min(d, 1024)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \ input.scalar_type(), "act_and_mul_kernel", [&] { \
input.scalar_type(), "act_and_mul_kernel", [&] { \ if (d <= 512) { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \ vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
}); });
void silu_and_mul(torch::Tensor& out, // [..., d] void silu_and_mul(torch::Tensor& out, // [..., d]
...@@ -147,4 +223,4 @@ void gelu_fast(torch::Tensor& out, // [..., d] ...@@ -147,4 +223,4 @@ void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment