Unverified Commit 1b7c791d authored by kliuae, committed by GitHub

[ROCm] Fixes for GPTQ on ROCm (#2180)

parent bbe4466f
@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
@@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel(
       zeros_tmp[tmp_k] = zero;
     }
     for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
       res2 = {};
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+#endif
       res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
       res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+      res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
     }
     i += width;
     k += 4;
......
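The new #ifndef USE_ROCM branches above are needed because, when this kernel is built against ROCm's HIP headers, the x and y members of a half2 are unsigned-short bit patterns rather than __half values, so brace zero-initialization and direct half arithmetic on the lanes do not compile. A minimal standalone sketch of the same portability pattern, assuming the usual hipify header mapping; the helper names zero_half2 and half2_lane_sum are illustrative only and not part of this patch:

// Sketch only: zero-initialize and read back a half2 accumulator in a way
// that works on both CUDA and ROCm. On CUDA the half2 lanes are __half; on
// ROCm they are unsigned shorts holding the half bit pattern, hence the
// __half_as_ushort / __ushort_as_half round-trips.
#include <cuda_fp16.h>  // assumed to be rewritten to the HIP fp16 header by hipify on ROCm

__device__ __forceinline__ half2 zero_half2() {
#ifndef USE_ROCM
  half2 r = {};                                // lanes are __half, value-init is fine
  return r;
#else
  half2 r;
  r.x = __half_as_ushort(__float2half(0.0f));  // lanes are ushort bit patterns
  r.y = __half_as_ushort(__float2half(0.0f));
  return r;
#endif
}

__device__ __forceinline__ half half2_lane_sum(half2 v) {
#ifndef USE_ROCM
  return __hadd(v.x, v.y);
#else
  return __hadd(__ushort_as_half(v.x), __ushort_as_half(v.y));
#endif
}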
@@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
 - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
 - `Pytorch <https://pytorch.org/>`_
+- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
 1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
......
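The added prerequisite mirrors the new #include <hipblas/hipblas.h> in q_gemm.cu above: hipBLAS headers and libraries must be installed before building the ROCm extension. A hypothetical smoke test, not part of this change, to confirm hipBLAS is usable on the target machine (compile with something like hipcc hipblas_check.cpp -lhipblas):

// Hypothetical check that hipBLAS is installed and usable: create and
// destroy a handle. A failure here means the prerequisite above is missing.
#include <hipblas/hipblas.h>
#include <cstdio>

int main() {
  hipblasHandle_t handle = nullptr;
  hipblasStatus_t status = hipblasCreate(&handle);
  if (status != HIPBLAS_STATUS_SUCCESS) {
    std::printf("hipblasCreate failed with status %d\n", static_cast<int>(status));
    return 1;
  }
  std::printf("hipBLAS handle created; library is available.\n");
  hipblasDestroy(handle);
  return 0;
}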
@@ -219,13 +219,13 @@ vllm_extension_sources = [
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]
 if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
-    vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu")
 vllm_extension = CUDAExtension(
     name="vllm._C",
......
@@ -112,13 +112,12 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = ["safetensors"]
+        rocm_not_supported_load_format = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
                 "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip():
-            if load_format in ["safetensors"]:
+        if is_hip() and load_format in rocm_not_supported_load_format:
             rocm_supported_load_format = [
                 f for f in supported_load_format
                 if (f not in rocm_not_supported_load_format)
@@ -127,9 +126,6 @@ class ModelConfig:
                 f"load format \'{load_format}\' is not supported in ROCm. "
                 f"Supported load format are "
                 f"{rocm_supported_load_format}")
-            # Force ROCm to load from pt weights if nothing specific is set
-            if load_format == "auto":
-                load_format = "pt"

         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
@@ -149,7 +145,7 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = ["awq", "gptq", "squeezellm"]
-        rocm_not_supported_quantization = ["awq", "gptq"]
+        rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
......