Commit b2dd1743 authored by zhuwenwen's avatar zhuwenwen
Browse files

refactoring the transpose kernel and update supported model

parent 795ce518
......@@ -156,6 +156,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
......
......@@ -15,12 +15,18 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| LlamaForCausalLM | LLaMA-3 | Yes | Yes |
| LlamaForCausalLM | Codellama | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes |
| Qwen2ForCausalLM | QWen1.5 | Yes | Yes |
| Qwen2ForCausalLM | CodeQwen1.5 | Yes | Yes |
| Qwen2ForCausalLM | QWen2 | Yes | Yes |
| ChatGLMModel | chatglm2 | Yes | Yes |
| ChatGLMModel | chatglm3 | Yes | Yes |
| ChatGLMModel | glm-4 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes |
| ChatGLMModel | chatglm2-6b | Yes | Yes |
| ChatGLMModel | chatglm3-6b | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| LlamaForCausalLM | deepseek | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| LlamaForCausalLM | Yi | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
......@@ -56,6 +62,10 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
cd dist
pip install vllm*
cd csrc/quantization/gptq
python setup.py bdist_wheel
cd dist
pip install gptq_kernel
2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
......
......@@ -1542,25 +1542,6 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
}
}
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* src,int64_t row,int64_t col)
{
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row;
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
}
void trans_w16_gemm_cuda(half* dst,const half* src,int64_t row,int64_t col){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t num_kernels=row*col;
int block_size=256;
trans_w16_gemm_cudakernel<<<(num_kernels+block_size-1)/block_size,block_size, 0, stream>>>(num_kernels,dst,src,row,col);
}
__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
......@@ -1867,16 +1848,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return c;
}
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col){
//row是原矩阵的行,col是原矩阵的列
const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
vllm::gptq::trans_w16_gemm_cuda(
(half*)dst.data_ptr(),
(const half*)src.data_ptr(),
row,
col
);
}
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
......
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace vllm {
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* src,int64_t row,int64_t col)
{
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row;
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
}
void trans_w16_gemm_cuda(half* dst,const half* src,int64_t row,int64_t col){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t num_kernels=row*col;
int block_size=256;
trans_w16_gemm_cudakernel<<<(num_kernels+block_size-1)/block_size,block_size, 0, stream>>>(num_kernels,dst,src,row,col);
}
} // namespace vllm
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col){
const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
vllm::trans_w16_gemm_cuda(
(half*)dst.data_ptr(),
(const half*)src.data_ptr(),
row,
col
);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment