Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2dd1743
Commit
b2dd1743
authored
Jul 23, 2024
by
zhuwenwen
Browse files
refactoring the transpose kernel and update supported model
parent
795ce518
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
31 deletions
+51
-31
CMakeLists.txt
CMakeLists.txt
+1
-0
README.md
README.md
+12
-2
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+0
-29
csrc/transpose_kernels.cu
csrc/transpose_kernels.cu
+38
-0
No files found.
CMakeLists.txt
View file @
b2dd1743
...
@@ -156,6 +156,7 @@ set(VLLM_EXT_SRC
...
@@ -156,6 +156,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
...
...
README.md
View file @
b2dd1743
...
@@ -15,12 +15,18 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
...
@@ -15,12 +15,18 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| LlamaForCausalLM | LLaMA-3 | Yes | Yes |
| LlamaForCausalLM | LLaMA-3 | Yes | Yes |
| LlamaForCausalLM | Codellama | Yes | Yes |
| LlamaForCausalLM | Codellama | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes |
| QWenLMHeadModel | QWen | Yes | Yes |
| Qwen2ForCausalLM | QWen1.5 | Yes | Yes |
| Qwen2ForCausalLM | CodeQwen1.5 | Yes | Yes |
| Qwen2ForCausalLM | QWen2 | Yes | Yes |
| ChatGLMModel | chatglm2 | Yes | Yes |
| ChatGLMModel | chatglm3 | Yes | Yes |
| ChatGLMModel | glm-4 | Yes | Yes |
| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes |
| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes |
| ChatGLMModel | chatglm2-6b | Yes | Yes |
| ChatGLMModel | chatglm3-6b | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes |
| InternLMForCausalLM | InternLM | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | Yes |
| LlamaForCausalLM | deepseek | Yes | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes |
| LlamaForCausalLM | Yi | Yes | Yes |
| LlamaForCausalLM | Yi | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes |
...
@@ -56,6 +62,10 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
...
@@ -56,6 +62,10 @@ git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
cd dist
cd dist
pip install vllm*
pip install vllm*
cd csrc/quantization/gptq
python setup.py bdist_wheel
cd dist
pip install gptq_kernel
2. 源码编译安装
2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
...
...
csrc/quantization/gptq/q_gemm.cu
View file @
b2dd1743
...
@@ -1542,25 +1542,6 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
...
@@ -1542,25 +1542,6 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
}
}
}
}
template
<
typename
T
>
__global__
void
trans_w16_gemm_cudakernel
(
int64_t
num_kernels
,
T
*
dst
,
const
T
*
src
,
int64_t
row
,
int64_t
col
)
{
int64_t
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
id
>=
num_kernels
)
return
;
int64_t
j
=
id
%
row
;
int64_t
i
=
id
/
row
;
dst
[
i
*
row
+
j
]
=
src
[
j
*
col
+
i
];
}
void
trans_w16_gemm_cuda
(
half
*
dst
,
const
half
*
src
,
int64_t
row
,
int64_t
col
){
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
int64_t
num_kernels
=
row
*
col
;
int
block_size
=
256
;
trans_w16_gemm_cudakernel
<<<
(
num_kernels
+
block_size
-
1
)
/
block_size
,
block_size
,
0
,
stream
>>>
(
num_kernels
,
dst
,
src
,
row
,
col
);
}
__global__
void
shuffle_4bit_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
__global__
void
shuffle_4bit_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
)
{
const
int
size_k
,
const
int
size_n
)
{
...
@@ -1867,16 +1848,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
...
@@ -1867,16 +1848,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
return
c
;
return
c
;
}
}
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
){
//row是原矩阵的行,col是原矩阵的列
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
src
));
vllm
::
gptq
::
trans_w16_gemm_cuda
(
(
half
*
)
dst
.
data_ptr
(),
(
const
half
*
)
src
.
data_ptr
(),
row
,
col
);
}
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
)
{
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
...
...
csrc/transpose_kernels.cu
0 → 100644
View file @
b2dd1743
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace
vllm
{
template
<
typename
T
>
__global__
void
trans_w16_gemm_cudakernel
(
int64_t
num_kernels
,
T
*
dst
,
const
T
*
src
,
int64_t
row
,
int64_t
col
)
{
int64_t
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
id
>=
num_kernels
)
return
;
int64_t
j
=
id
%
row
;
int64_t
i
=
id
/
row
;
dst
[
i
*
row
+
j
]
=
src
[
j
*
col
+
i
];
}
void
trans_w16_gemm_cuda
(
half
*
dst
,
const
half
*
src
,
int64_t
row
,
int64_t
col
){
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
int64_t
num_kernels
=
row
*
col
;
int
block_size
=
256
;
trans_w16_gemm_cudakernel
<<<
(
num_kernels
+
block_size
-
1
)
/
block_size
,
block_size
,
0
,
stream
>>>
(
num_kernels
,
dst
,
src
,
row
,
col
);
}
}
// namespace vllm
void
trans_w16_gemm
(
torch
::
Tensor
dst
,
torch
::
Tensor
src
,
int64_t
row
,
int64_t
col
){
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
src
));
vllm
::
trans_w16_gemm_cuda
(
(
half
*
)
dst
.
data_ptr
(),
(
const
half
*
)
src
.
data_ptr
(),
row
,
col
);
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment