Commit d4c0015a authored by huangwb's avatar huangwb
Browse files

Merge branch 'gptq_fix' into 'v0.5.0-dtk24.04.1'

Fix GPTQ performance degradation when batch size > 4

See merge request dcutoolkit/deeplearing/vllm!5
parents 4caf1539 f423ad60
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch
# Build script for the standalone ``gptq_kernels`` extension.

# Mirror torch's _GLIBCXX_USE_CXX11_ABI setting so that both compilers
# receive the same -D_GLIBCXX_USE_CXX11_ABI value.
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0

# Host (C++) compiler flags.
CXX_FLAGS = ["-g", "-O3", "-std=c++17"] + [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

# Device compiler flags.
# NOTE: compiling with --gpu-max-threads-per-block=1024 degrades GPTQ
# performance with multiple batches, so that flag is deliberately left out.
NVCC_FLAGS = [
    "-O3",
    "-std=c++17",
    "-DUSE_ROCM",
    "-U__HIP_NO_HALF_CONVERSIONS__",
    "-U__HIP_NO_HALF_OPERATORS__",
] + [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

extra_compile_args = {"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}

# Single extension module built from the binding file and the kernel source.
gptq_extension = CUDAExtension(
    name="gptq_kernels",
    sources=["./torch_bindings.cpp", "./q_gemm.cu"],
    extra_compile_args=extra_compile_args,
)

setup(
    name="gptq_kernels",
    ext_modules=[gptq_extension],
    cmdclass={"build_ext": BuildExtension},
)
#include <torch/extension.h>
// Forward declarations of the two GPTQ entry points that are bound to Python
// in the PYBIND11_MODULE block. Only the prototypes live in this file; the
// definitions are presumably provided by the accompanying kernel source
// (TODO confirm).

// gptq_gemm: takes the tensors a, b_q_weight, b_gptq_qzeros, b_gptq_scales and
// b_g_idx plus a use_exllama flag and a bit width, and returns a new tensor.
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);
// gptq_shuffle: takes q_weight, q_perm and a bit width and returns nothing;
// presumably it rewrites q_weight in place (TODO confirm against the kernel).
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// Python bindings: expose gptq_gemm and gptq_shuffle on the extension module.
// The third argument of m.def is the docstring shown by Python's help(); each
// one describes the function it is attached to.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  m.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
}
......@@ -2,6 +2,10 @@ import contextlib
from typing import List, Optional, Tuple, Type
import torch
# gptq_kernels is a separately built extension; fail fast with an actionable
# message if it is missing, and keep the original ImportError as the cause so
# the underlying reason (e.g. a missing shared library) stays visible.
try:
    import gptq_kernels
except ImportError as e:
    raise RuntimeError(
        "Failed to import gptq_kernels. Please install gptq_kernels "
        "from csrc/quantization/gptq.") from e
try:
import vllm._C
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    """Run the GPTQ GEMM through the standalone ``gptq_kernels`` extension.

    All arguments are forwarded unchanged, in the same order, to
    ``gptq_kernels.gptq_gemm`` and its result is returned as-is. This
    replaces the earlier dispatch through ``torch.ops._C.gptq_gemm``.

    Args:
        a: Input tensor (presumably the activations -- confirm with kernel).
        b_q_weight: Quantized weight tensor.
        b_gptq_qzeros: GPTQ zero-point tensor.
        b_gptq_scales: GPTQ scale tensor.
        b_g_idx: Group-index tensor.
        use_exllama: Flag passed through to the kernel.
        bit: Bit width passed through to the kernel.

    Returns:
        The tensor returned by the kernel.
    """
    return gptq_kernels.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    """Run the GPTQ shuffle through the standalone ``gptq_kernels`` extension.

    The arguments are forwarded unchanged to ``gptq_kernels.gptq_shuffle``.
    Only this backend is invoked; the earlier ``torch.ops._C.gptq_shuffle``
    call is no longer made, so the operation runs exactly once.

    Args:
        q_weight: Quantized weight tensor (presumably modified in place,
            since nothing is returned -- confirm with kernel).
        q_perm: Permutation tensor passed through to the kernel.
        bit: Bit width passed through to the kernel.
    """
    gptq_kernels.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment