norm / vllm · Commits

Commit 2b1c116b (unverified), authored Sep 18, 2023 by Woosuk Kwon, committed by GitHub on Sep 18, 2023

Add minimum capability requirement for AWQ (#1064)
Parent: cc796b13

Showing 5 changed files with 47 additions and 2 deletions (+47 -2)
csrc/quantization/awq/dequantize.cuh  +8 -0
csrc/quantization/awq/gemm_kernels.cu  +16 -2
vllm/model_executor/model_loader.py  +8 -0
vllm/model_executor/quantization_utils/awq.py  +5 -0
vllm/model_executor/quantization_utils/base.py  +10 -0
csrc/quantization/awq/dequantize.cuh

@@ -11,9 +11,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
 #pragma once
 
+namespace vllm {
+namespace awq {
 
 __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    assert(false);
+#else
     uint4 result;
 
     uint32_t* h = reinterpret_cast<uint32_t*>(&result);

@@ -75,5 +80,8 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
 
     return result;
+#endif
 }
+}  // namespace awq
+}  // namespace vllm
csrc/quantization/awq/gemm_kernels.cu

@@ -16,6 +16,9 @@ Adapted from https://github.com/mit-han-lab/llm-awq
 #include <cuda_fp16.h>
 
+namespace vllm {
+namespace awq {
+
 // Pack two half values.
 static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {

@@ -26,6 +29,9 @@ __pack_half2(const half x, const half y) {
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];

@@ -214,11 +220,15 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
       }
     }
   }
+#endif
 }
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];

@@ -412,8 +422,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
       }
     }
   }
+#endif
 }
+
+}  // namespace awq
+}  // namespace vllm
 
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]

@@ -459,7 +473,7 @@ torch::Tensor awq_gemm(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+    vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
       group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   else if (num_out_channels % 64 == 0)

@@ -470,7 +484,7 @@ torch::Tensor awq_gemm(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+    vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
      group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   return _out_feats.sum(0);
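The layout comments above describe the quantized weight as IC x (OC // 8) int32 values that are reinterpreted as IC x OC 4-bit integers, i.e. eight nibbles per 32-bit word. As a rough, hypothetical illustration of that shape arithmetic only (the real AWQ kernels also use an interleaved nibble ordering that this sketch ignores), the packing could be written in Python as:

import torch

def pack_int4_reference(q: torch.Tensor) -> torch.Tensor:
    # Illustrative only: fold eight 4-bit values (0..15) along the last
    # dimension into one int32, so an (IC, OC) tensor becomes (IC, OC // 8).
    # The actual AWQ kernels additionally reorder nibbles within each word.
    assert q.shape[-1] % 8 == 0
    q = q.to(torch.int32).reshape(*q.shape[:-1], -1, 8)
    packed = torch.zeros(q.shape[:-1], dtype=torch.int32)
    for i in range(8):
        packed |= q[..., i] << (4 * i)
    return packed

# Example: an (IC=4, OC=16) int4 weight packs into shape (4, 2).
w = torch.randint(0, 16, (4, 16))
print(pack_int4_reference(w).shape)  # torch.Size([4, 2])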
vllm/model_executor/model_loader.py

@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         quant_config = get_quant_config(model_config.quantization,
                                         model_config.model,
                                         model_config.download_dir)
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < quant_config.get_min_capability():
+            raise ValueError(
+                f"The quantization method {model_config.quantization} is not "
+                "supported for the current GPU. "
+                f"Minimum capability: {quant_config.get_min_capability()}. "
+                f"Current capability: {capability}.")
         supported_dtypes = quant_config.get_supported_act_dtypes()
         if model_config.dtype not in supported_dtypes:
             raise ValueError(
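For context, torch.cuda.get_device_capability() returns a (major, minor) tuple, and the check above folds it into a single integer (e.g. an A100 at compute capability 8.0 becomes 80, a T4 at 7.5 becomes 75) before comparing it against the config's minimum. A minimal standalone sketch of the same comparison, assuming a CUDA device is available:

import torch

def check_min_capability(min_capability: int) -> None:
    # Mirrors the check in get_model(): fold (major, minor) into major*10+minor
    # and raise if the current GPU is older than the required architecture.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < min_capability:
        raise ValueError(
            f"GPU compute capability {capability} is below the required "
            f"minimum of {min_capability}.")

# AWQ requires 80 (Ampere): this passes on an A100 (8, 0) but raises on a T4 (7, 5).
check_min_capability(80)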
vllm/model_executor/quantization_utils/awq.py

@@ -40,6 +40,11 @@ class AWQConfig(QuantizationConfig):
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half]
 
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Ampere or newer GPUs.
+        return 80
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         return [
vllm/model_executor/quantization_utils/base.py

@@ -15,6 +15,16 @@ class QuantizationConfig:
         """List of supported activation dtypes."""
         raise NotImplementedError
 
+    @classmethod
+    def get_min_capability(cls) -> int:
+        """Minimum GPU capability to support the quantization method.
+
+        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+        This requirement is due to the custom CUDA kernels used by the
+        quantization method.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         """List of filenames to search for in the model directory."""
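To show how a concrete config satisfies this interface, here is a hedged sketch of a hypothetical subclass; the class name, its config filename, and its 75 (Turing) threshold are invented for illustration, and it only covers the methods visible in this diff:

from typing import List

import torch

from vllm.model_executor.quantization_utils.base import QuantizationConfig


class ExampleQuantConfig(QuantizationConfig):
    # Hypothetical config used only to illustrate the interface above.

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # Pretend these kernels need Turing (compute capability 75) or newer.
        return 75

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]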