Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f423ad60
Commit
f423ad60
authored
Jul 22, 2024
by
huangwb
Browse files
fix gptq performance degradation when batch size>4 issue
parent
4caf1539
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
2 deletions
+58
-2
csrc/quantization/gptq/setup.py
csrc/quantization/gptq/setup.py
+34
-0
csrc/quantization/gptq/torch_bindings.cpp
csrc/quantization/gptq/torch_bindings.cpp
+15
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+9
-2
No files found.
csrc/quantization/gptq/setup.py
0 → 100644
View file @
f423ad60
from
setuptools
import
setup
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
import
torch
# Compiler flags.
CXX_FLAGS
=
[
"-g"
,
"-O3"
,
"-std=c++17"
]
NVCC_FLAGS
=
[
"-O3"
,
"-std=c++17"
,
"-DUSE_ROCM"
,
"-U__HIP_NO_HALF_CONVERSIONS__"
,
"-U__HIP_NO_HALF_OPERATORS__"
]
#--gpu-max-threads-per-block=1024编译会导致GPTQ多batch性能下降。
# NVCC_FLAGS = ["-O3", "-std=c++17","-DUSE_ROCM","--gpu-max-threads-per-block=1024","-U__HIP_NO_HALF_CONVERSIONS__","-U__HIP_NO_HALF_OPERATORS__"]
ABI
=
1
if
torch
.
_C
.
_GLIBCXX_USE_CXX11_ABI
else
0
CXX_FLAGS
+=
[
f
"-D_GLIBCXX_USE_CXX11_ABI=
{
ABI
}
"
]
NVCC_FLAGS
+=
[
f
"-D_GLIBCXX_USE_CXX11_ABI=
{
ABI
}
"
]
extra_compile_args
=
{
"cxx"
:
CXX_FLAGS
,
"nvcc"
:
NVCC_FLAGS
,
}
setup
(
name
=
"gptq_kernels"
,
ext_modules
=
[
CUDAExtension
(
name
=
"gptq_kernels"
,
sources
=
[
"./torch_bindings.cpp"
,
"./q_gemm.cu"
,
],
extra_compile_args
=
extra_compile_args
,
)
],
cmdclass
=
{
"build_ext"
:
BuildExtension
},
)
csrc/quantization/gptq/torch_bindings.cpp
0 → 100644
View file @
f423ad60
#include <torch/extension.h>
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
// Bindings
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"make_q_matrix"
);
m
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"gemm_half_q_half"
);
}
vllm/_custom_ops.py
View file @
f423ad60
...
...
@@ -2,6 +2,10 @@ import contextlib
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
try
:
import
gptq_kernels
except
ImportError
as
e
:
raise
RuntimeError
(
"Failed to import gptq_kernel with, Please install gptq_kernels from csrc/quantization/gptq "
)
try
:
import
vllm._C
...
...
@@ -156,13 +160,16 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
use_exllama
:
bool
,
bit
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_C
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
return
gptq_kernels
.
gptq_gemm
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
use_exllama
,
bit
)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
bit
:
int
)
->
None
:
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
gptq_kernels
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def
trans_w16_gemm
(
dst
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment