Commit c62f8e9a authored by zhuwenwen's avatar zhuwenwen
Browse files

optimize rmsnorm kernel

parent b2dd1743
...@@ -122,11 +122,10 @@ endif() ...@@ -122,11 +122,10 @@ endif()
# the supported versions for the current language. # the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`. # The final set of arches is stored in `VLLM_GPU_ARCHES`.
# #
#override_gpu_arches(VLLM_GPU_ARCHES override_gpu_arches(VLLM_GPU_ARCHES
# ${VLLM_GPU_LANG} ${VLLM_GPU_LANG}
# "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
set(VLLM_GPU_ARCHES "gfx928")
message(STATUS "${VLLM_GPU_ARCHES}")
# #
# Query torch for additional GPU compilation flags for the given # Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`. # `VLLM_GPU_LANG`.
......
...@@ -20,7 +20,6 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -20,7 +20,6 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| Qwen2ForCausalLM | QWen2 | Yes | Yes | | Qwen2ForCausalLM | QWen2 | Yes | Yes |
| ChatGLMModel | chatglm2 | Yes | Yes | | ChatGLMModel | chatglm2 | Yes | Yes |
| ChatGLMModel | chatglm3 | Yes | Yes | | ChatGLMModel | chatglm3 | Yes | Yes |
| ChatGLMModel | glm-4 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes | | BaiChuanForCausalLM | Baichuan-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes | | BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes | | InternLMForCausalLM | InternLM | Yes | Yes |
...@@ -69,6 +68,8 @@ pip install gptq_kernel ...@@ -69,6 +68,8 @@ pip install gptq_kernel
2. 源码编译安装 2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
cd csrc/quantization/gptq
python setup.py install
``` ```
#### 运行基础环境准备 #### 运行基础环境准备
......
...@@ -124,8 +124,8 @@ inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -124,8 +124,8 @@ inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
for (int ii = 1; ii < N; ++ii) { for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec); qk_vec = fma(q[ii], k[ii], qk_vec);
} }
float qk = sum(qk_vec);
// Finalize the reduction across lanes. // Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll #pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask); qk += VLLM_SHFL_XOR_SYNC(qk, mask);
......
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "reduction_utils.cuh" #include "reduction_utils.cuh"
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -288,22 +291,150 @@ fused_add_rms_norm_kernel( ...@@ -288,22 +291,150 @@ fused_add_rms_norm_kernel(
} // namespace vllm } // namespace vllm
// Sums `val` across a group of `reducesize` lanes using shuffle-down steps.
//
// The step distance starts at reducesize/2 and is halved until it reaches 1,
// so `reducesize` is expected to be a power of two.
// NOTE(review): assuming WARP_SHFL_DOWN has the usual shuffle-down semantics,
// the complete sum ends up in the first lane of the group; the values held by
// the other lanes are partial sums and should not be relied on.
template <typename T, int reducesize = C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
  T acc = val;
#pragma unroll
  for (int stride = reducesize >> 1; stride >= 1; stride /= 2) {
    // Fetch the running sum held `stride` lanes further along and fold it in.
    const T neighbour = WARP_SHFL_DOWN(acc, stride);
    acc += neighbour;
  }
  return acc;
}
// Sums `val` over all threads of a 1-D thread block.
//
// Template parameters:
//   T          - accumulation type.
//   block_size - number of threads in the block (compile-time); the number of
//                per-warp partial sums is block_size / C10_WARP_SIZE.
// Parameters:
//   val    - this thread's contribution.
//   shared - scratch buffer with at least block_size / C10_WARP_SIZE elements
//            (the callers in this file pass a __shared__ array).
// Returns:
//   The block-wide sum in thread 0. Other threads receive an intermediate
//   value that must not be used.
//
// The function contains __syncthreads(), so every thread of the block has to
// call it, including threads that only contribute zero.
template <typename T, int block_size = 512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
  constexpr int share_size = block_size / C10_WARP_SIZE;
  // Stage 1: each warp reduces its own lanes.
  val = WarpReduceSum_NEW<T>(val);
  if constexpr (block_size == C10_WARP_SIZE) {
    // Single-warp block: the warp-level result already is the block result.
    return val;
  } else {
    const int lid = threadIdx.x % C10_WARP_SIZE;
    const int wid = threadIdx.x / C10_WARP_SIZE;
    // Make sure any previous use of `shared` has finished before overwriting.
    __syncthreads();
    if (lid == 0 && wid < share_size) {
      shared[wid] = val;
    }
    __syncthreads();
    // Stage 2: the first warp reduces the per-warp partial sums.
    // Every lane of warp 0 takes part in the shuffles; lanes that have no
    // partial sum to load contribute zero. Running the shuffle inside a
    // `lid < share_size` branch would execute it with only part of the warp
    // active, which is not a valid way to use warp shuffles.
    if (wid == 0) {
      const T partial = (lid < share_size) ? shared[lid] : static_cast<T>(0);
      val = WarpReduceSum_NEW<T, share_size>(partial);
    }
    return val;
  }
}
// Fused residual-add + RMS normalisation, one thread block per row.
//
// For row `blockIdx.x` the kernel computes, in place:
//   residual <- residual + input
//   input    <- residual * rsqrt(mean(residual^2) + eps) * gamma
// where the mean is accumulated in T_ACC over the row's `cols` elements.
//
// Template parameters:
//   scalar_t   - element type of input / residual / gamma.
//   T_ACC      - accumulation type.
//   Vec        - elements handled per thread with one vector load/store.
//   block_size - threads per block; must match the launch's blockDim.x since it
//                sizes the shared scratch buffer used by the block reduction.
// Parameters:
//   input    - [rows, cols], overwritten with the normalised output.
//   residual - [rows, cols], overwritten with input + residual.
//   gamma    - [cols] scale.
//   cols     - row length; must be a multiple of Vec, and cols / Vec must not
//              exceed blockDim.x (columns beyond that are not processed).
//   eps      - added to the mean square before the reciprocal square root.
//
// NOTE(review): the vector accesses reinterpret the pointers as
// aligned_vector<scalar_t, Vec>; this presumably requires the row starts to be
// aligned to Vec * sizeof(scalar_t) - confirm against the callers.
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
  constexpr int share_size = block_size / C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  // Threads without a column chunk must NOT return here: the block reduction
  // below uses block-wide barriers and every warp leader has to publish its
  // partial sum. They simply contribute zero instead.
  const bool active = j < tcol;
  // 64-bit offset so rows * cols may exceed the range of int.
  const int64_t idx = (static_cast<int64_t>(blockIdx.x) * tcol + j) * Vec;

  scalar_t intput_vec[Vec];
  scalar_t residual_vec[Vec];
  T_ACC val = 0;
  if (active) {
    *(LoadT*)intput_vec = *(LoadT*)(input + idx);
    *(LoadT*)residual_vec = *(LoadT*)(residual + idx);
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      residual_vec[ii] += intput_vec[ii];
      val += static_cast<T_ACC>(residual_vec[ii]) * static_cast<T_ACC>(residual_vec[ii]);
    }
  }

  // Sum of squares over the whole row; the result is valid in thread 0.
  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  if (j == 0) s_rstd = c10::cuda::compat::rsqrt(val / cols + eps);
  __syncthreads();
  if (!active) return;  // no barriers below, safe to leave now

  const T_ACC trstd = s_rstd;
#pragma unroll
  for (int ii = 0; ii < Vec; ii++) {
    const int jj = j * Vec + ii;
    intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
  }
  *(LoadT*)(residual + idx) = *(LoadT*)residual_vec;
  *(LoadT*)(input + idx) = *(LoadT*)intput_vec;
}
// RMS normalisation, one thread block per row.
//
// For row `blockIdx.x` the kernel computes
//   output <- input * rsqrt(mean(input^2) + eps) * gamma
// where the mean is accumulated in T_ACC over the row's `cols` elements.
//
// Template parameters:
//   scalar_t   - element type of input / output / gamma.
//   T_ACC      - accumulation type.
//   Vec        - elements handled per thread with one vector load/store.
//   block_size - threads per block; must match the launch's blockDim.x since it
//                sizes the shared scratch buffer used by the block reduction.
// Parameters:
//   input  - [rows, cols] source.
//   output - [rows, cols] destination.
//   gamma  - [cols] scale.
//   cols   - row length; must be a multiple of Vec, and cols / Vec must not
//            exceed blockDim.x (columns beyond that are not processed).
//   eps    - added to the mean square before the reciprocal square root.
//
// NOTE(review): the vector accesses reinterpret the pointers as
// aligned_vector<scalar_t, Vec>; this presumably requires the row starts to be
// aligned to Vec * sizeof(scalar_t) - confirm against the callers.
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
  constexpr int share_size = block_size / C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int j = threadIdx.x;
  const int tcol = cols / Vec;
  // Threads without a column chunk must NOT return here: the block reduction
  // below uses block-wide barriers and every warp leader has to publish its
  // partial sum. They simply contribute zero instead.
  const bool active = j < tcol;
  // 64-bit offset so rows * cols may exceed the range of int.
  const int64_t idx = (static_cast<int64_t>(blockIdx.x) * tcol + j) * Vec;

  scalar_t intput_vec[Vec];
  T_ACC val = 0;
  if (active) {
    *(LoadT*)intput_vec = *(LoadT*)(input + idx);
#pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      val += static_cast<T_ACC>(intput_vec[ii]) * static_cast<T_ACC>(intput_vec[ii]);
    }
  }

  // Sum of squares over the whole row; the result is valid in thread 0.
  val = BlockReduceSum_NEW<T_ACC, block_size>(val, val_shared);
  if (j == 0) s_rstd = c10::cuda::compat::rsqrt(val / cols + eps);
  __syncthreads();
  if (!active) return;  // no barriers below, safe to leave now

  const T_ACC trstd = s_rstd;
#pragma unroll
  for (int ii = 0; ii < Vec; ii++) {
    const int jj = j * Vec + ii;
    intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
  }
  *(LoadT*)(output + idx) = *(LoadT*)intput_vec;
}
// Host entry point: RMS-normalises `input` along its last dimension, scales by
// `weight`, and writes the result to `out`.
//
// Parameters:
//   out     - [..., hidden_size] destination.
//   input   - [..., hidden_size] source.
//   weight  - [hidden_size] scale.
//   epsilon - added to the mean square before the reciprocal square root.
//
// Kernels are launched on the current CUDA stream of `input`'s device with one
// block per token (row). Rows whose length is a multiple of 16 in
// [2048, 8192] take the vectorised fused_rms_kernel_eval path; everything else
// goes through vllm::rms_norm_kernel.
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  int hidden_size = input.size(-1);
  // Nothing to do for an empty tensor; this also avoids dividing by a zero
  // hidden_size and launching a grid with zero blocks.
  if (input.numel() == 0) {
    return;
  }
  int num_tokens = input.numel() / hidden_size;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (hidden_size % 16 == 0 && hidden_size >= 2048 && hidden_size <= 8192) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        input.scalar_type(),
        "rms_norm_kernel",
        [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          T_ACC eps = epsilon;
          scalar_t* self_data = input.data_ptr<scalar_t>();
          scalar_t* out_data = out.data_ptr<scalar_t>();
          scalar_t* weight_data = weight.data_ptr<scalar_t>();
          // The vector width is picked so that hidden_size / Vec never exceeds
          // the 1024 threads of the block.
          if (hidden_size == 2048) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 2, 1024><<<num_tokens, 1024, 0, stream>>>(self_data, out_data, weight_data, hidden_size, eps);
          } else if (hidden_size <= 4096) {
            fused_rms_kernel_eval<scalar_t, T_ACC, 4, 1024><<<num_tokens, 1024, 0, stream>>>(self_data, out_data, weight_data, hidden_size, eps);
          } else {
            fused_rms_kernel_eval<scalar_t, T_ACC, 8, 1024><<<num_tokens, 1024, 0, stream>>>(self_data, out_data, weight_data, hidden_size, eps);
          }
        });
  } else {
    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
      vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
          out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
          weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
    });
  }
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
...@@ -316,13 +447,41 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] ...@@ -316,13 +447,41 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
num_tokens, hidden_size); \ num_tokens, hidden_size); \
}); });
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* other_data =residual.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if(hidden_size==2048){
fused_add_rms_kernel_eval<scalar_t,T_ACC,2,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
fused_add_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens); dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios. /* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows When num_tokens is large, a smaller block size allows
...@@ -330,8 +489,6 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] ...@@ -330,8 +489,6 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
hiding on global mem ops. */ hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256; const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size)); dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel /*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops. with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s Max optimization is achieved with a width-8 vector of FP16/BF16s
...@@ -349,4 +506,5 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] ...@@ -349,4 +506,5 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
} else { } else {
LAUNCH_FUSED_ADD_RMS_NORM(0); LAUNCH_FUSED_ADD_RMS_NORM(0);
} }
}
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment