norm / vllm / commit 1b7c791d

[ROCm] Fixes for GPTQ on ROCm (#2180)

Unverified commit 1b7c791d, authored Dec 19, 2023 by kliuae, committed by GitHub Dec 18, 2023. Parent: bbe4466f.

Showing 4 changed files with 23 additions and 16 deletions (+23 -16):
- csrc/quantization/gptq/q_gemm.cu (+10 -0)
- docs/source/getting_started/amd-installation.rst (+1 -0)
- setup.py (+1 -1)
- vllm/config.py (+11 -15)
csrc/quantization/gptq/q_gemm.cu (view file @ 1b7c791d)

@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))
 
 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,

@@ -520,12 +521,21 @@ __global__ void gemm_half_q_half_alt_kernel(
       zeros_tmp[tmp_k] = zero;
     }
     for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
       res2 = {};
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+#endif
       res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
       res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
       res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+      res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
     }
     i += width;
     k += 4;
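The second hunk works around half2 handling in the HIP build, where this kernel's half2 lanes are represented as unsigned shorts: the ROCm path zeroes each lane explicitly via __half_as_ushort(__float2half(0)) instead of res2 = {}, and converts back with __ushort_as_half before the final __hadd reduction. For orientation, a plain-float Python sketch of the per-byte dequantize-multiply-accumulate this loop performs; names mirror the kernel, but the shapes and types are illustrative assumptions, not the kernel's packed-half2 layout:

    # `tmp` packs four 8-bit lookup indices; each selects a dequantized
    # entry from deq2, rescaled with its group's (scale, zero) and
    # accumulated against the activation vector.
    def dequant_mac(tmp, deq2, off, scales_tmp, zeros_tmp, blockvec_m, k, res=0.0):
        for j in range(4):                     # byte lanes: bits 0, 8, 16, 24
            idx = (tmp >> (8 * j)) & 0xFF      # 8-bit table index
            w = deq2[idx][off] * scales_tmp[j] + zeros_tmp[j]  # dequantize
            res += w * blockvec_m[k + j]       # multiply-accumulate
        return res

    # Tiny smoke test with a constant lookup table: prints 4.0.
    deq2 = [[0.5] for _ in range(256)]
    print(dequant_mac(0x03020100, deq2, 0, [1.0] * 4, [0.0] * 4, [2.0] * 8, 0))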
docs/source/getting_started/amd-installation.rst (view file @ 1b7c791d)

@@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
 - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
 - `Pytorch <https://pytorch.org/>`_
+- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
 
 1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

hipBLAS is listed as a new prerequisite to match the #include <hipblas/hipblas.h> added to q_gemm.cu above.
setup.py (view file @ 1b7c791d)

@@ -219,13 +219,13 @@ vllm_extension_sources = [
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]
 
 if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
-    vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu")
 
 vllm_extension = CUDAExtension(
     name="vllm._C",

q_gemm.cu moves out of the CUDA-only branch into the common source list, so the GPTQ kernel is now compiled for ROCm builds as well.
vllm/config.py (view file @ 1b7c791d)

@@ -112,24 +112,20 @@ class ModelConfig:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = ["safetensors"]
+        rocm_not_supported_load_format = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
                 "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
-        if is_hip():
-            if load_format in ["safetensors"]:
-                rocm_supported_load_format = [
-                    f for f in supported_load_format
-                    if (f not in rocm_not_supported_load_format)
-                ]
-                raise ValueError(
-                    f"load format \'{load_format}\' is not supported in ROCm. "
-                    f"Supported load format are "
-                    f"{rocm_supported_load_format}")
-            # Force ROCm to load from pt weights if nothing specific is set
-            if load_format == "auto":
-                load_format = "pt"
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f for f in supported_load_format
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format \'{load_format}\' is not supported in ROCm. "
+                f"Supported load format are "
+                f"{rocm_supported_load_format}")
 
         # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])

@@ -149,7 +145,7 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = ["awq", "gptq", "squeezellm"]
-        rocm_not_supported_quantization = ["awq", "gptq"]
+        rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
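Net effect: on ROCm, safetensors checkpoints are no longer rejected, load_format="auto" is no longer coerced to "pt", and "gptq" leaves the ROCm-unsupported quantization list. A standalone sketch of the new load-format check; verify_load_format and the on_rocm flag are hypothetical stand-ins for the ModelConfig logic and is_hip(), not vLLM's actual API:

    def verify_load_format(load_format: str, on_rocm: bool) -> str:
        # Mirrors the ModelConfig check after this commit.
        supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"]
        rocm_not_supported_load_format = []   # was ["safetensors"] before
        if load_format not in supported_load_format:
            raise ValueError(f"Unknown load format: {load_format}.")
        if on_rocm and load_format in rocm_not_supported_load_format:
            rocm_supported = [f for f in supported_load_format
                              if f not in rocm_not_supported_load_format]
            raise ValueError(f"load format '{load_format}' is not supported "
                             f"in ROCm. Supported load formats are "
                             f"{rocm_supported}")
        return load_format   # "auto" is no longer forced to "pt" on ROCm

    # safetensors now passes validation on ROCm builds.
    assert verify_load_format("safetensors", on_rocm=True) == "safetensors"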