Commit 1e77d04e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1' into v0.5.2-dtk24.04.1

# Conflicts:
#	csrc/attention/attention_kernels.cu
#	csrc/attention/attention_utils.cuh
#	csrc/layernorm_kernels.cu
#	vllm/model_executor/layers/linear.py
#	vllm/model_executor/models/baichuan.py
#	vllm/model_executor/models/llama.py
parents 6fa22430 c62f8e9a
...@@ -5,11 +5,13 @@ project(vllm_extensions LANGUAGES CXX) ...@@ -5,11 +5,13 @@ project(vllm_extensions LANGUAGES CXX)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
# Default to an optimized build, but only when no build type was requested.
# An unconditional set() here would silently override an explicit
# -DCMAKE_BUILD_TYPE=... passed on the command line (e.g. by setup.py).
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE "Release")
endif()
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Pass `-w` to the compiler for every target defined after this point in this
# directory scope. NOTE(review): with GCC/Clang-style drivers `-w` disables all
# warnings, which also hides diagnostics from newly added kernels.
add_compile_options(-w)
# #
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
...@@ -116,7 +118,7 @@ endif() ...@@ -116,7 +118,7 @@ endif()
override_gpu_arches(VLLM_GPU_ARCHES override_gpu_arches(VLLM_GPU_ARCHES
${VLLM_GPU_LANG} ${VLLM_GPU_LANG}
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
# #
# Query torch for additional GPU compilation flags for the given # Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`. # `VLLM_GPU_LANG`.
...@@ -143,8 +145,10 @@ set(VLLM_EXT_SRC ...@@ -143,8 +145,10 @@ set(VLLM_EXT_SRC
"csrc/cache_kernels.cu" "csrc/cache_kernels.cu"
"csrc/attention/attention_kernels.cu" "csrc/attention/attention_kernels.cu"
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
......
...@@ -15,12 +15,17 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -15,12 +15,17 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| LlamaForCausalLM | LLaMA-3 | Yes | Yes | | LlamaForCausalLM | LLaMA-3 | Yes | Yes |
| LlamaForCausalLM | Codellama | Yes | Yes | | LlamaForCausalLM | Codellama | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes | | QWenLMHeadModel | QWen | Yes | Yes |
| Qwen2ForCausalLM | QWen1.5 | Yes | Yes |
| Qwen2ForCausalLM | CodeQwen1.5 | Yes | Yes |
| Qwen2ForCausalLM | QWen2 | Yes | Yes |
| ChatGLMModel | chatglm2 | Yes | Yes |
| ChatGLMModel | chatglm3 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes | | BaiChuanForCausalLM | Baichuan-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes | | BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes |
| ChatGLMModel | chatglm2-6b | Yes | Yes |
| ChatGLMModel | chatglm3-6b | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes | | InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes | | InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| LlamaForCausalLM | deepseek | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| LlamaForCausalLM | Yi | Yes | Yes | | LlamaForCausalLM | Yi | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes | | MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
...@@ -56,9 +61,15 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的 ...@@ -56,9 +61,15 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
cd dist cd dist
pip install vllm* pip install vllm*
cd ../csrc/quantization/gptq
python setup.py bdist_wheel
cd dist
pip install gptq_kernels*
2. 源码编译安装 2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
cd csrc/quantization/gptq
python setup.py install
``` ```
#### 运行基础环境准备 #### 运行基础环境准备
......
...@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))" "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags") "Failed to determine torch nvcc compiler flags")
list(REMOVE_ITEM GPU_FLAGS
"-DUSE_ROCM=1"
)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
# "-DENABLE_FP8" # "-DENABLE_FP8"
...@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc" "-fno-gpu-rdc"
"--gpu-max-threads-per-block=1024") "--gpu-max-threads-per-block=1024")
message(STATUS "${GPU_FLAGS}")
endif() endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction() endfunction()
......
This diff is collapsed.
...@@ -26,17 +26,104 @@ ...@@ -26,17 +26,104 @@
namespace vllm { namespace vllm {
// Q*K^T operation. inline __device__ void v_dot2_f32_f16(float& a, const uint32_t & b,const uint32_t & c) {
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0;": "=v"(a): "v"(b), "v"(c), "0"(a));
}
// Emits a single `v_pk_fma_f16` instruction with source operands (b, c, a)
// and writes the result back into `a`.
// NOTE(review): presumably a packed 2 x fp16 fused multiply-add, i.e.
// a = b * c + a per 16-bit lane -- confirm against the target ISA manual.
inline __device__ void v_pk_fma_f16(uint32_t& a, const uint32_t & b,const uint32_t & c){
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;": "=v"(a) : "v"(b), "v"(c), "v"(a));
}
// Emits a `ds_read_b128` instruction that loads into `a`, using `offset` as
// the address operand. No `s_waitcnt` is issued here.
// NOTE(review): presumably the caller must wait (e.g. via lgkmcnt0()) before
// consuming `a` -- confirm against the target ISA manual.
inline __device__ void ds_read_b128(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1;": "=v" (a): "v" (offset));
}
// Same load as ds_read_b128(), immediately followed by `s_waitcnt lgkmcnt(1)`
// in the same asm statement.
// NOTE(review): lgkmcnt(1) presumably still allows one outstanding operation,
// so this read itself may not have completed on return -- confirm.
inline __device__ void ds_read_b128_sync(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1\ns_waitcnt lgkmcnt(1);": "=v" (a): "v" (offset));
}
// Emits `s_waitcnt lgkmcnt(0)`.
// NOTE(review): presumably blocks until all outstanding LDS/GDS/constant/
// message operations of this wavefront have completed -- confirm against ISA.
inline __device__ void lgkmcnt0(){
asm volatile("s_waitcnt lgkmcnt(0);");
}
// Casts a generic pointer to an `address_space(3)` pointer and returns the
// resulting address value as a size_t.
// NOTE(review): the name mirrors the CUDA generic-to-shared conversion helper;
// address space 3 is presumably the local/shared (LDS) space -- confirm.
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void *__ptr) {
return (size_t)(void __attribute__((address_space(3))) *)__ptr;
}
// uint2 overload: applies the 32-bit v_dot2_f32_f16 overload to the .x words
// and then to the .y words of `b` and `c`, accumulating into `a` in that order.
inline __device__ void v_dot2_f32_f16(float& a,const uint2 & b,const uint2 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
}
// uint4 overload: applies the 32-bit v_dot2_f32_f16 overload to the .x, .y,
// .z and .w words of `b` and `c`, accumulating into `a` in that order.
inline __device__ void v_dot2_f32_f16(float& a,const uint4 & b,const uint4 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
v_dot2_f32_f16(a, b.z, c.z);
v_dot2_f32_f16(a, b.w, c.w);
}
// Reinterprets the 32-bit word `a` as two packed `half` values, adds them
// (the addition itself is carried out on `half` operands) and returns the
// sum converted to float.
inline __device__ float add_half2(uint32_t a){
  union {
    uint32_t bits;
    half lanes[2];
  } packed;
  packed.bits = a;
  return static_cast<float>(packed.lanes[0] + packed.lanes[1]);
}
// Multiplies the .x words of `b` and `c` with mul<uint32_t,...>, folds the
// .y, .z and .w words in with v_pk_fma_f16, then adds the two packed halves
// of the running value (via add_half2) onto `a`.
// NOTE(review): the intermediate accumulation is kept in packed fp16, so it
// is presumably less precise than the v_dot2_f32_f16 path -- confirm.
inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c) {
uint32_t tmp = mul<uint32_t, uint32_t, uint32_t>(b.x,c.x);
v_pk_fma_f16(tmp,b.y,c.y);
v_pk_fma_f16(tmp,b.z,c.z);
v_pk_fma_f16(tmp,b.w,c.w);
a+=add_half2(tmp);
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk =0;
// uint32_t offset = __nv_cvta_generic_to_shared_impl(q);
// const uint4 *k_ptr= reinterpret_cast<const uint4 *>(k);
// // Compute the parallel products for Q*K^T (treat vector lanes separately).
// constexpr int loop=N*sizeof(Vec)/16/2;
// uint4 qt[2];
// #pragma unroll
// for (int ii = 0; ii < loop; ++ii) {
// ds_read_b128(qt[0],offset+16*ii*2);
// ds_read_b128_sync(qt[1],offset+16*(ii*2+1));
// v_dot2_f32_f16(qk,qt[0],k_ptr[ii*2]);
// // v_pk_fma_f16x8(qk,qt[0],k_ptr[ii*2]);
// lgkmcnt0();
// v_dot2_f32_f16(qk,qt[1],k_ptr[ii*2+1]);
// // v_pk_fma_f16x8(qk,qt[1],k_ptr[ii*2+1]);
// }
#pragma unroll
for (int ii = 0; ii < N; ++ii) {
v_dot2_f32_f16(qk,q[ii],k[ii]);
}
// Finalize the reduction across lanes.
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]); A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll #pragma unroll
for (int ii = 1; ii < N; ++ii) { for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec); qk_vec = fma(q[ii], k[ii], qk_vec);
} }
// Finalize the reduction across lanes. // Finalize the reduction across lanes.
float qk = sum(qk_vec); float qk = sum(qk_vec);
#pragma unroll #pragma unroll
...@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template <typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template <typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
// return qk_dot_vpack_<THREAD_GROUP_SIZE>(q, k);
// }
}; };
} // namespace vllm } // namespace vllm
\ No newline at end of file
// Lifts a runtime boolean into a compile-time constant.
// Expands to an immediately-invoked lambda that declares
// `constexpr static bool CONST_NAME` (true or false depending on COND) and
// then invokes the callable given as the trailing argument, returning its result.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&] {                                           \
    if (!(COND)) {                                \
      constexpr static bool CONST_NAME = false;   \
      return __VA_ARGS__();                       \
    }                                             \
    constexpr static bool CONST_NAME = true;      \
    return __VA_ARGS__();                         \
  }()
// Selects between two compile-time variants at runtime.
// Expands to an immediately-invoked lambda that declares
// `constexpr static int opt` (1 when COND holds, 2 otherwise) and then
// invokes the callable given as the trailing argument, returning its result.
#define OPT_SWITCH(COND, ...)             \
  [&] {                                   \
    if (!(COND)) {                        \
      constexpr static int opt = 2;       \
      return __VA_ARGS__();               \
    }                                     \
    constexpr static int opt = 1;         \
    return __VA_ARGS__();                 \
  }()
// Lifts a runtime thread count into the compile-time constant NUM_THREADS.
// Expands to an immediately-invoked lambda that declares
// `constexpr static int NUM_THREADS` as 256 when NUM_THREAD equals 256 and as
// 128 for every other value, then invokes the trailing callable.
// The argument is parenthesized so that expressions containing operators of
// lower precedence than `==` (e.g. `x & mask`) are compared as a whole.
#define NUM_THREADS_SWITCH(NUM_THREAD, ...)     \
  [&] {                                         \
    if ((NUM_THREAD) == 256) {                  \
      constexpr static int NUM_THREADS = 256;   \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static int NUM_THREADS = 128;   \
      return __VA_ARGS__();                     \
    }                                           \
  }()
// #define HEADSIZE_SWITCH(HEADDIM, ...) \
// [&] { \
// if (HEADDIM == 64) { \
// constexpr static int HEAD_SIZE = 64; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 80) { \
// constexpr static int HEAD_SIZE = 80; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 96) { \
// constexpr static int HEAD_SIZE = 96; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 112) { \
// constexpr static int HEAD_SIZE = 112; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 128) { \
// constexpr static int HEAD_SIZE = 128; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 256) { \
// constexpr static int HEAD_SIZE = 256; \
// return __VA_ARGS__(); \
// } \
// else { \
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
// } \
// }()
// Lifts a runtime head dimension into the compile-time constant HEAD_SIZE.
// Only a head size of 128 is accepted; any other value fails TORCH_CHECK with
// "Unsupported head size". On success the trailing callable is invoked with
// `constexpr static int HEAD_SIZE = 128` in scope.
// The argument is parenthesized so that compound expressions are compared as
// a whole rather than being re-associated by operator precedence.
#define HEADSIZE_SWITCH(HEADDIM, ...)                           \
  [&] {                                                         \
    if ((HEADDIM) == 128) {                                     \
      constexpr static int HEAD_SIZE = 128;                     \
      return __VA_ARGS__();                                     \
    } else {                                                    \
      TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);   \
    }                                                           \
  }()
// Chooses the compile-time constant REUSE_KV_TIMES (4, 2 or 1) at runtime and
// invokes the trailing callable with it in scope.
//   4: num_heads is even, num_heads / num_kv_heads >= 4 and num_blocks >= 1200
//   2: num_heads / num_kv_heads >= 2 and num_blocks >= 1200
//   1: otherwise
// NOTE: besides its `num_blocks` argument, this macro reads the variables
// `num_heads` and `num_kv_heads` from the scope in which it is expanded.
// The `num_blocks` argument is parenthesized so that expressions with
// operators of lower precedence than `>=` are compared as a whole.
#define REUSEKV_SWITCH(num_blocks , ...)                                                  \
  [&] {                                                                                   \
    if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && (num_blocks) >= 1200) {    \
      constexpr static int REUSE_KV_TIMES = 4;                                            \
      return __VA_ARGS__();                                                               \
    } else if (num_heads / num_kv_heads >= 2 && (num_blocks) >= 1200) {                   \
      constexpr static int REUSE_KV_TIMES = 2;                                            \
      return __VA_ARGS__();                                                               \
    } else {                                                                              \
      constexpr static int REUSE_KV_TIMES = 1;                                            \
      return __VA_ARGS__();                                                               \
    }                                                                                     \
  }()
// Chooses the compile-time constant REUSE_KV_TIMES (2 or 1) at runtime and
// invokes the trailing callable with it in scope.
//   2: num_heads > num_kv_heads and num_blocks >= 1200
//   1: otherwise
// NOTE: besides its `num_blocks` argument, this macro reads the variables
// `num_heads` and `num_kv_heads` from the scope in which it is expanded.
// The `num_blocks` argument is parenthesized so that expressions with
// operators of lower precedence than `>=` are compared as a whole.
#define REUSEKV_SWITCH_V1(num_blocks , ...)                         \
  [&] {                                                             \
    if (num_heads > num_kv_heads && (num_blocks) >= 1200) {         \
      constexpr static int REUSE_KV_TIMES = 2;                      \
      return __VA_ARGS__();                                         \
    } else {                                                        \
      constexpr static int REUSE_KV_TIMES = 1;                      \
      return __VA_ARGS__();                                         \
    }                                                               \
  }()
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "reduction_utils.cuh" #include "reduction_utils.cuh"
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -288,22 +291,150 @@ fused_add_rms_norm_kernel( ...@@ -288,22 +291,150 @@ fused_add_rms_norm_kernel(
} // namespace vllm } // namespace vllm
// Sums `val` across lanes with a shuffle-down tree: at each step every lane
// adds the value held `stride` lanes above it, with the stride halving from
// reducesize/2 down to 1. Returns the per-lane running value.
// NOTE(review): the complete sum of `reducesize` consecutive lanes presumably
// ends up only in the first lane of that group -- confirm WARP_SHFL_DOWN semantics.
template <typename T,int reducesize=C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
  #pragma unroll
  for (int stride = reducesize / 2; stride >= 1; stride /= 2) {
    val += WARP_SHFL_DOWN(val, stride);
  }
  return val;
}
// Block-wide sum built from two warp-level reductions.
//   val    - this thread's contribution.
//   shared - scratch buffer with at least block_size / C10_WARP_SIZE elements
//            (one slot per warp).
// When block_size equals the warp size, only the first warp reduction runs and
// its result is returned directly. Otherwise lane 0 of every warp stores its
// warp sum to `shared`, and the first share_size lanes of warp 0 reduce those
// partial sums.
// This function contains __syncthreads(); NOTE(review): it therefore
// presumably has to be reached by every thread of the block, and the complete
// sum is presumably only meaningful in thread 0 -- confirm against callers.
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
// Number of per-warp partial sums held in `shared`.
constexpr int share_size=block_size/C10_WARP_SIZE;
val = WarpReduceSum_NEW<T>(val);
if constexpr(block_size==C10_WARP_SIZE)
{
return val;
}
else{
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
// Barrier before `shared` is overwritten.
__syncthreads();
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
// Barrier so that all per-warp partial sums are visible before they are read.
__syncthreads();
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
// Fused residual-add + RMS normalization; one thread block handles one row.
//   input    - in: values to add to the residual; out: normalized result
//              (residual + input) * rsqrt(mean((residual + input)^2) + eps) * gamma.
//   residual - in/out: overwritten with residual + input.
//   gamma    - per-column scale, indexed by column.
//   cols     - row length in elements; each thread handles Vec consecutive
//              elements, so cols / Vec threads of the block do useful work.
//   eps      - added to the mean of squares before rsqrt.
// Rows are addressed as blockIdx.x * (cols / Vec) * Vec, and data is moved
// with Vec-wide aligned_vector loads/stores.
// NOTE(review): this presumably requires cols % Vec == 0, cols / Vec <=
// blockDim.x, blockDim.x == block_size and Vec-element-aligned pointers --
// confirm at the launch site.
//
// Threads with no column to process must NOT return before the block
// reduction: BlockReduceSum_NEW contains __syncthreads() and warp shuffles,
// and it sums one shared-memory slot per warp. If whole warps exited early
// their slots would be read uninitialized and corrupt the sum. Idle threads
// therefore contribute 0 and only skip the loads and stores.
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_eval(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
  constexpr int share_size=block_size/C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int i=blockIdx.x;
  const int j=threadIdx.x;
  const int tcol=cols/Vec;
  const bool active = j < tcol;
  // 64-bit element offset: rows * cols can exceed the range of int.
  const int64_t idx = (static_cast<int64_t>(i) * tcol + j) * Vec;

  scalar_t intput_vec[Vec];
  scalar_t residual_vec[Vec];
  T_ACC val=0;
  if (active) {
    *(LoadT*)intput_vec = *(LoadT*)(input+idx);
    *(LoadT*)residual_vec = *(LoadT*)(residual+idx);
    #pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      residual_vec[ii]+=intput_vec[ii];
      val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
    }
  }

  // Every thread of the block takes part in the reduction.
  val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
  if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
  __syncthreads();
  if (!active) return;

  const T_ACC trstd=s_rstd;
  #pragma unroll
  for(int ii=0;ii<Vec;ii++){
    const int jj=j*Vec+ii;
    intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
  }
  *(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
  *(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
// RMS normalization; one thread block handles one row.
//   input  - values to normalize (read only by this kernel).
//   output - receives input * rsqrt(mean(input^2) + eps) * gamma.
//   gamma  - per-column scale, indexed by column.
//   cols   - row length in elements; each thread handles Vec consecutive
//            elements, so cols / Vec threads of the block do useful work.
//   eps    - added to the mean of squares before rsqrt.
// Rows are addressed as blockIdx.x * (cols / Vec) * Vec, and data is moved
// with Vec-wide aligned_vector loads/stores.
// NOTE(review): this presumably requires cols % Vec == 0, cols / Vec <=
// blockDim.x, blockDim.x == block_size and Vec-element-aligned pointers --
// confirm at the launch site.
//
// Threads with no column to process must NOT return before the block
// reduction: BlockReduceSum_NEW contains __syncthreads() and warp shuffles,
// and it sums one shared-memory slot per warp. If whole warps exited early
// their slots would be read uninitialized and corrupt the sum. Idle threads
// therefore contribute 0 and only skip the loads and stores.
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_eval(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
  constexpr int share_size=block_size/C10_WARP_SIZE;
  __shared__ T_ACC val_shared[share_size];
  __shared__ T_ACC s_rstd;
  using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;

  const int i=blockIdx.x;
  const int j=threadIdx.x;
  const int tcol=cols/Vec;
  const bool active = j < tcol;
  // 64-bit element offset: rows * cols can exceed the range of int.
  const int64_t idx = (static_cast<int64_t>(i) * tcol + j) * Vec;

  scalar_t intput_vec[Vec];
  T_ACC val=0;
  if (active) {
    *(LoadT*)intput_vec = *(LoadT*)(input+idx);
    #pragma unroll
    for (int ii = 0; ii < Vec; ii++) {
      val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
    }
  }

  // Every thread of the block takes part in the reduction.
  val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
  if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
  __syncthreads();
  if (!active) return;

  const T_ACC trstd=s_rstd;
  #pragma unroll
  for(int ii=0;ii<Vec;ii++){
    const int jj=j*Vec+ii;
    intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) * trstd * static_cast<T_ACC>(gamma[jj]);
  }
  *(LoadT*)(output+idx)=*(LoadT*)intput_vec;
}
void rms_norm(torch::Tensor& out, // [..., hidden_size] void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192){
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), AT_DISPATCH_FLOATING_TYPES_AND2(
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size); at::ScalarType::Half,
}); at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* out_data =out.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if(hidden_size==2048){
fused_rms_kernel_eval<scalar_t,T_ACC,2,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
fused_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
...@@ -316,37 +447,64 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] ...@@ -316,37 +447,64 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
num_tokens, hidden_size); \ num_tokens, hidden_size); \
}); });
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops. if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192){
Max optimization is achieved with a width-8 vector of FP16/BF16s AT_DISPATCH_FLOATING_TYPES_AND2(
since we can load at most 128 bits at once in a global memory op. at::ScalarType::Half,
However, this requires each tensor's data to be aligned to 16 at::ScalarType::BFloat16,
bytes. input.scalar_type(),
*/ "fused_add_rms_norm_kernel",
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); [&] {
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); using T_ACC = at::acc_type<scalar_t, true>;
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); T_ACC eps = epsilon;
bool ptrs_are_aligned = scalar_t* self_data = input.data_ptr<scalar_t>();
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; scalar_t* other_data =residual.data_ptr<scalar_t>();
if (ptrs_are_aligned && hidden_size % 8 == 0) { scalar_t* weight_data=weight.data_ptr<scalar_t>();
LAUNCH_FUSED_ADD_RMS_NORM(8); if(hidden_size==2048){
} else { fused_add_rms_kernel_eval<scalar_t,T_ACC,2,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
LAUNCH_FUSED_ADD_RMS_NORM(0); }
else if(hidden_size<=4096){
fused_add_rms_kernel_eval<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_eval<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
});
} }
} else{
\ No newline at end of file dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
}
...@@ -39,6 +39,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ...@@ -39,6 +39,13 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& cos_sin_cache, bool is_neox, torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rot_dim, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets); torch::Tensor& cos_sin_cache_offsets);
void rotary_embedding_tgi(
torch::Tensor& query,
torch::Tensor& key,
int64_t head_size,
torch::Tensor& cos_cache,
torch::Tensor& sin_cache,
bool is_neox);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
...@@ -123,6 +130,8 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -123,6 +130,8 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale); // torch::Tensor& scale);
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
namespace vllm {
// Rotates one (x, y) element pair of a single head in place.
//   arr        - start of the head's data.
//   cos_ptr    - cosine values for the current token.
//   sin_ptr    - sine values for the current token.
//   rot_offset - index of the pair to rotate; also the index into cos/sin.
//   embed_dim  - distance between x and y in the NeoX layout.
// Pair layout: NeoX uses (rot_offset, embed_dim + rot_offset); GPT-J uses the
// interleaved (2 * rot_offset, 2 * rot_offset + 1).
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding_tgi(
  scalar_t* __restrict__ arr,
  const float* __restrict__ cos_ptr,
  const float* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  const int first = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int second = IS_NEOX ? embed_dim + rot_offset : first + 1;

  // Both layouts read the cache at rot_offset (GPT-J: (2 * rot_offset) / 2).
  const float cos = VLLM_LDG(cos_ptr + rot_offset);
  const float sin = VLLM_LDG(sin_ptr + rot_offset);

  const scalar_t x = arr[first];
  const scalar_t y = arr[second];
  arr[first] = x * cos - y * sin;
  arr[second] = y * cos + x * sin;
}
// Applies the rotary embedding to all query heads and then all key heads of
// one token. The threads of the block stride over the (head, pair) work items:
// item i maps to head i / rot_dim and pair offset i % rot_dim.
//   query, key     - base pointers of the full tensors; the token's data
//                    starts at token_idx * {query,key}_stride.
//   cos_ptr/sin_ptr- cos/sin values already offset to this token.
//   rot_dim        - number of pairs rotated per head; also passed on as the
//                    x/y distance used by the NeoX layout.
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding_tgi(
  scalar_t* __restrict__ query,
  scalar_t* __restrict__ key,
  const float* __restrict__ cos_ptr,
  const float* __restrict__ sin_ptr,
  const int head_size,
  const int num_heads,
  const int num_kv_heads,
  const int rot_dim,
  const int token_idx,
  const int64_t query_stride,
  const int64_t key_stride)
{
  // Query heads.
  const int query_items = num_heads * rot_dim;
  for (int item = threadIdx.x; item < query_items; item += blockDim.x) {
    const int head = item / rot_dim;
    const int pair = item % rot_dim;
    const int64_t head_start = token_idx * query_stride + head * head_size;
    apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(
        query + head_start, cos_ptr, sin_ptr, pair, rot_dim);
  }

  // Key heads.
  const int key_items = num_kv_heads * rot_dim;
  for (int item = threadIdx.x; item < key_items; item += blockDim.x) {
    const int head = item / rot_dim;
    const int pair = item % rot_dim;
    const int64_t head_start = token_idx * key_stride + head * head_size;
    apply_token_rotary_embedding_tgi<scalar_t, IS_NEOX>(
        key + head_start, cos_ptr, sin_ptr, pair, rot_dim);
  }
}
// Kernel entry point: block `blockIdx.x` rotates the query and key heads of
// token `blockIdx.x` in place. The cos/sin caches are indexed by the token
// index, with rot_dim floats per token.
// NOTE(review): the caches are therefore presumably pre-gathered per token
// rather than indexed by position id -- confirm with the caller.
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_tgi_kernel(
  scalar_t* __restrict__ query,
  scalar_t* __restrict__ key,
  const float* __restrict__ cos_cache,
  const float* __restrict__ sin_cache,
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  const int token_idx = blockIdx.x;
  const int cache_offset = token_idx * rot_dim;

  apply_rotary_embedding_tgi<scalar_t, IS_NEOX>(
      query, key,
      cos_cache + cache_offset, sin_cache + cache_offset,
      head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}
} // namespace vllm
// Host launcher: rotates `query` and `key` in place using separate cos/sin
// caches, one thread block per token.
//   query/key  - size(0) is taken as the number of tokens and size(1) as the
//                number of (kv) heads; stride(0) is the per-token stride.
//                NOTE(review): presumably [num_tokens, num_heads, head_size]
//                -- confirm with callers.
//   head_size  - elements per head.
//   cos_cache, - float tensors; size(2) of cos_cache is the number of rotated
//   sin_cache    pairs per head, and the kernel reads them per token index.
//   is_neox    - true for the GPT-NeoX pair layout, false for GPT-J.
void rotary_embedding_tgi(
  torch::Tensor& query,
  torch::Tensor& key,
  int64_t head_size,
  torch::Tensor& cos_cache,
  torch::Tensor& sin_cache,
  bool is_neox) {
  const int num_tokens = query.size(0);
  const int rot_dim = cos_cache.size(2);
  const int num_heads = query.size(1);
  const int num_kv_heads = key.size(1);
  // Strides stay 64-bit to match the kernel parameters; narrowing them to int
  // would truncate for large tensors.
  const int64_t query_stride = query.stride(0);
  const int64_t key_stride = key.stride(0);

  // Nothing to do for an empty batch; a zero-sized grid is not a valid launch.
  if (num_tokens == 0) {
    return;
  }

  dim3 grid(num_tokens);
  // At least one thread per block; the kernel strides over the work items.
  dim3 block(std::max(1, std::min(num_heads * rot_dim / 2, 512)));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "rotary_embedding_tgi",
    [&] {
      if (is_neox) {
        vllm::rotary_embedding_tgi_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_cache.data_ptr<float>(),
          sin_cache.data_ptr<float>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::rotary_embedding_tgi_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_cache.data_ptr<float>(),
          sin_cache.data_ptr<float>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
}
...@@ -1542,6 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, ...@@ -1542,6 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
} }
} }
__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) { const int size_k, const int size_n) {
int n = blockIdx.x * THREADS_X + threadIdx.x; int n = blockIdx.x * THREADS_X + threadIdx.x;
...@@ -1847,6 +1848,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -1847,6 +1848,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c; return c;
} }
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight( vllm::gptq::shuffle_exllama_weight(
......
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
# Build script for the standalone GPTQ kernel extension.

# Match the libstdc++ ABI setting that the installed torch was built with.
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
_abi_define = f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"

# Host compiler flags.
CXX_FLAGS = ["-g", "-O3", "-std=c++17", _abi_define]

# Device compiler flags.
# "--gpu-max-threads-per-block=1024" is deliberately left out: compiling with
# it degrades GPTQ multi-batch performance.
NVCC_FLAGS = [
    "-O3",
    "-std=c++17",
    "-DUSE_ROCM",
    "-U__HIP_NO_HALF_CONVERSIONS__",
    "-U__HIP_NO_HALF_OPERATORS__",
    _abi_define,
]

extra_compile_args = {"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}

gptq_extension = CUDAExtension(
    name="gptq_kernels",
    sources=["./torch_bindings.cpp", "./q_gemm.cu"],
    extra_compile_args=extra_compile_args,
)

setup(
    name="gptq_kernels",
    ext_modules=[gptq_extension],
    cmdclass={"build_ext": BuildExtension},
)
#include <torch/extension.h>
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// Bindings
// Python bindings for the standalone GPTQ extension module.
// The docstrings name the underlying operation: gptq_gemm is the quantized
// GEMM, gptq_shuffle prepares (shuffles) the quantized weight matrix. They
// were previously attached to the wrong functions.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("gptq_gemm", &gptq_gemm, "gemm_half_q_half");
  m.def("gptq_shuffle", &gptq_shuffle, "make_q_matrix");
}
...@@ -93,6 +93,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -93,6 +93,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Rotary embedding variant for TGI, taking separate cos and sin caches.
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding_tgi(Tensor! query, Tensor! key,"
" int head_size, Tensor cos_cache,"
" Tensor sin_cache, bool is_neox) -> ()");
// ops.def("rotary_embedding_tgi",&rotary_embedding_tgi);
ops.impl("rotary_embedding_tgi", torch::kCUDA, &rotary_embedding_tgi);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key // Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras). // (supports multiple loras).
ops.def( ops.def(
...@@ -164,6 +173,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -164,6 +173,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Transpose a 16-bit weight matrix: src is (row x col), dst is (col x row).
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM. // Quantized GEMM for SqueezeLLM.
ops.def( ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
......
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace vllm {

// Out-of-place matrix transpose, one thread per element.
//
// `src` is read as a row-major (row x col) matrix and `dst` is written as a
// row-major (col x row) matrix, i.e. dst[i][j] = src[j][i].
//
//   num_kernels  total number of elements to process (row * col)
//   dst          output buffer, at least row * col elements
//   src          input buffer, at least row * col elements
//   row, col     dimensions of `src`
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels, T* dst,
                                          const T* src, int64_t row,
                                          int64_t col) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated
  // in 32-bit unsigned arithmetic and wraps for very large matrices.
  const int64_t id =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (id >= num_kernels) return;

  // `id` is the linear index into dst: j walks a dst row, i selects it.
  const int64_t j = id % row;
  const int64_t i = id / row;
  dst[i * row + j] = src[j * col + i];
}

// Launches the transpose kernel on the current CUDA stream with 256 threads
// per block. `dst` and `src` must each hold at least row * col elements.
void trans_w16_gemm_cuda(half* dst, const half* src, int64_t row,
                         int64_t col) {
  const int64_t num_kernels = row * col;
  // A grid dimension of 0 is an invalid launch configuration, so an empty
  // matrix is handled here instead of being passed to the launch.
  if (num_kernels <= 0) return;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int block_size = 256;
  const int64_t num_blocks = (num_kernels + block_size - 1) / block_size;
  trans_w16_gemm_cudakernel<<<static_cast<unsigned int>(num_blocks),
                              block_size, 0, stream>>>(num_kernels, dst, src,
                                                       row, col);
}

}  // namespace vllm
// Host entry point for the 16-bit weight transpose.
//
// Reads `src` as a (row x col) matrix and writes its transpose into `dst`.
// The tensors are accessed through their raw data pointers as 16-bit
// elements, so both must have a 2-byte element type and hold at least
// row * col elements; these preconditions are checked up front because a
// violation would otherwise read or write out of bounds on the device.
//
//   dst       output tensor, overwritten with the transposed data
//   src       input tensor; its device is made current for the launch
//   row, col  dimensions of `src`
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row,
                    int64_t col) {
  TORCH_CHECK(row >= 0 && col >= 0,
              "trans_w16_gemm: row and col must be non-negative, got ", row,
              " x ", col);
  TORCH_CHECK(src.element_size() == 2 && dst.element_size() == 2,
              "trans_w16_gemm: src and dst must have 16-bit elements");
  const int64_t numel = row * col;
  TORCH_CHECK(src.numel() >= numel && dst.numel() >= numel,
              "trans_w16_gemm: src and dst must each hold at least row * col (",
              numel, ") elements");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
  vllm::trans_w16_gemm_cuda(static_cast<half*>(dst.data_ptr()),
                            static_cast<const half*>(src.data_ptr()), row,
                            col);
}
\ No newline at end of file
...@@ -3,6 +3,10 @@ import functools ...@@ -3,6 +3,10 @@ import functools
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
# The GPTQ kernels are built as a separate extension module; fail early with
# an actionable message if it is missing. The exception type stays
# RuntimeError; the original ImportError is chained as the cause.
try:
    import gptq_kernels
except ImportError as e:
    raise RuntimeError(
        "Failed to import gptq_kernels. Please install gptq_kernels from "
        "csrc/quantization/gptq") from e
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -182,14 +186,21 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -182,14 +186,21 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, return gptq_kernels.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) gptq_kernels.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, row: int,
                   col: int) -> None:
    """Forward to the compiled ``trans_w16_gemm`` op.

    Args:
        dst: tensor passed to the op as its output argument.
        src: tensor passed to the op as its input argument.
        row: forwarded unchanged as the op's ``row`` argument.
        col: forwarded unchanged as the op's ``col`` argument.
    """
    compiled_op = torch.ops._C.trans_w16_gemm
    compiled_op(dst, src, row, col)
# squeezellm # squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
...@@ -142,7 +142,7 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -142,7 +142,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
......
...@@ -14,6 +14,7 @@ from vllm.logger import init_logger ...@@ -14,6 +14,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -65,34 +66,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): ...@@ -65,34 +66,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight return param[shard_id], loaded_weight
def pad_weight(weight: torch.Tensor, num_pad: int,
               pad_dim: int = 0) -> torch.Tensor:
    """Return a copy of ``weight`` with ``num_pad`` zeros appended.

    Args:
        weight: 1D or 2D tensor to pad.
        num_pad: number of zero entries (1D) or zero rows/columns (2D) to add.
        pad_dim: for 2D input, 0 appends rows and 1 appends columns.
            Ignored for 1D input.

    Returns:
        A new tensor with the same dtype and device as ``weight``.

    Raises:
        ValueError: if ``weight`` is not 1D or 2D, or if ``weight`` is 2D and
            ``pad_dim`` is neither 0 nor 1.
    """
    if weight.dim() == 1:
        pad_shape = (num_pad, )
        cat_dim = 0
    elif weight.dim() == 2:
        if pad_dim == 0:
            pad_shape = (num_pad, weight.shape[1])
        elif pad_dim == 1:
            pad_shape = (weight.shape[0], num_pad)
        else:
            raise ValueError("pad_dim must be 0 or 1")
        cat_dim = pad_dim
    else:
        raise ValueError("Weight tensor must be 1D or 2D")

    # new_zeros inherits dtype and device from `weight`.
    padding = weight.new_zeros(pad_shape)
    return torch.cat([weight, padding], dim=cat_dim)
def gemm_bank_conf(weight):
    """Return True when ``weight`` is a non-zero power of two divisible by 2048.

    Args:
        weight: an integer size (despite the name, not a tensor).
            NOTE(review): presumably a GEMM dimension checked for
            bank-conflict-prone sizes -- confirm against callers.

    Returns:
        bool: True only if ``weight`` is a multiple of 2048 and a power of two.
    """
    is_mul_of_2048 = weight % 2048 == 0
    # x & (x - 1) clears the lowest set bit, so it is 0 only for powers of
    # two (and for 0, which is excluded explicitly).
    is_power_of_two = weight != 0 and (weight & (weight - 1)) == 0
    return bool(is_mul_of_2048 and is_power_of_two)
class LinearMethodBase(QuantizeMethodBase): class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
...@@ -133,7 +106,6 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -133,7 +106,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self): def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: List[int], input_size: int,
...@@ -151,20 +123,19 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -151,20 +123,19 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn:
layer.weight = layer.weight.reshape(layer.weight.shape[1], -1)
if bias is not None: if bias is not None:
return torch.matmul(x, layer.weight) + bias if len(x.shape) == 2:
else: return torch.addmm(bias, x, layer.weight)
if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
return torch.matmul(x, layer.weight[:,:-32])
else: else:
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight) + bias
else:
return torch.matmul(x, layer.weight)
else: else:
return F.linear(x, layer.weight, bias) return F.linear(x, layer.weight, bias)
class LinearBase(torch.nn.Module): class LinearBase(torch.nn.Module):
"""Base linear layer. """Base linear layer.
...@@ -321,7 +292,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -321,7 +292,6 @@ class ColumnParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -339,9 +309,6 @@ class ColumnParallelLinear(LinearBase): ...@@ -339,9 +309,6 @@ class ColumnParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
...@@ -406,8 +373,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -406,8 +373,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -469,21 +434,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -469,21 +434,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin. # Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False) use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes: if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim] shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if self.use_llama_nn: param_data = param_data.narrow(output_dim, shard_offset,
param_data_ = param_data.narrow(output_dim, shard_offset, shard_size)
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -506,16 +465,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -506,16 +465,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is " "MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.") "the same for all partitions.")
if self.use_llama_nn: assert param_data.shape == loaded_weight.shape
assert param_data_.shape == loaded_weight.shape param_data.copy_(loaded_weight)
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
...@@ -575,6 +527,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -575,6 +527,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_heads * self.head_size * tp_size, # k_proj self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj self.num_kv_heads * self.head_size * tp_size, # v_proj
] ]
super().__init__(input_size=input_size, super().__init__(input_size=input_size,
output_size=output_size, output_size=output_size,
bias=bias, bias=bias,
...@@ -582,8 +535,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -582,8 +535,6 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -675,14 +626,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -675,14 +626,9 @@ class QKVParallelLinear(ColumnParallelLinear):
} }
shard_size, shard_offset = adjust_bitsandbytes_shard( shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if self.use_llama_nn: param_data = param_data.narrow(output_dim, shard_offset,
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_id = tp_rank shard_id = tp_rank
else: else:
...@@ -708,20 +654,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -708,20 +654,9 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
if self.use_llama_nn: assert param_data.shape == loaded_weight.shape
assert param_data_.shape == loaded_weight.shape param_data.copy_(loaded_weight)
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
if self.use_fa_pad and param_data.shape[0]== 12288:
param_data = pad_weight(param.data, 32)
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
if self.use_fa_pad and param_data.shape[0]== 12288 and loaded_shard_id == "v" and len(param_data.shape) == 1:
param.data = pad_weight(param.data, 32)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase): class RowParallelLinear(LinearBase):
...@@ -790,8 +725,6 @@ class RowParallelLinear(LinearBase): ...@@ -790,8 +725,6 @@ class RowParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -809,18 +742,7 @@ class RowParallelLinear(LinearBase): ...@@ -809,18 +742,7 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
# if self.use_llama_nn:
# loaded_weight = loaded_weight.transpose(0, 1)
# loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# param_data.copy_(loaded_weight)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
if self.use_llama_nn:
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
def forward(self, input_): def forward(self, input_):
if self.input_is_parallel: if self.input_is_parallel:
...@@ -853,4 +775,4 @@ class RowParallelLinear(LinearBase): ...@@ -853,4 +775,4 @@ class RowParallelLinear(LinearBase):
s += f", bias={self.bias is not None}" s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}" s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}" s += f", reduce_results={self.reduce_results}"
return s return s
\ No newline at end of file
...@@ -22,13 +22,11 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,13 +22,11 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']: if architectures == ['LlamaForCausalLM'] or architectures == ['QWenLMHeadModel'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '0': else:
os.environ['GEMM_PAD'] = '1' os.environ['LLAMA_NN'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
......
...@@ -25,6 +25,7 @@ import torch ...@@ -25,6 +25,7 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import os import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -48,6 +49,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -48,6 +49,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from vllm import _custom_ops as ops
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
...@@ -181,8 +185,6 @@ class BaiChuanAttention(nn.Module): ...@@ -181,8 +185,6 @@ class BaiChuanAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states) qkv, _ = self.W_pack(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI": if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
...@@ -336,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -336,6 +338,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
quant_config=quant_config) quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -404,6 +407,26 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -404,6 +407,26 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attn.W_pack.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
import os import os
import re
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -28,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from vllm import _custom_ops as ops
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -104,8 +106,6 @@ class GLMAttention(nn.Module): ...@@ -104,8 +106,6 @@ class GLMAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
qkv = qkv[...,:-32]
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn( context_layer = self.attn(
...@@ -362,6 +362,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -362,6 +362,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def forward( def forward(
self, self,
...@@ -403,3 +404,23 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -403,3 +404,23 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn:
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment