norm / vllm · Commits

Unverified commit 2b1c116b, authored Sep 18, 2023 by Woosuk Kwon, committed by GitHub on Sep 18, 2023
Add minimum capability requirement for AWQ (#1064)
Parent: cc796b13

Showing 5 changed files with 47 additions and 2 deletions (+47 -2)
csrc/quantization/awq/dequantize.cuh            +8   -0
csrc/quantization/awq/gemm_kernels.cu           +16  -2
vllm/model_executor/model_loader.py             +8   -0
vllm/model_executor/quantization_utils/awq.py   +5   -0
vllm/model_executor/quantization_utils/base.py  +10  -0
csrc/quantization/awq/dequantize.cuh

@@ -11,9 +11,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
 #pragma once

+namespace vllm {
+namespace awq {
+
 __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    assert(false);
+#else
     uint4 result;

     uint32_t* h = reinterpret_cast<uint32_t*>(&result);
@@ -75,5 +80,8 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

     return result;
+#endif
 }

+} // namespace awq
+} // namespace vllm
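The new guard compiles the kernel body only when __CUDA_ARCH__ is at least 800, i.e. compute capability 8.0 (Ampere); on older architectures the function reduces to assert(false). As a host-side sketch (not part of this change), the same threshold can be checked from PyTorch before the kernel is ever launched:

# Host-side view of the 800 threshold used in the guard above (illustrative only).
import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) on an A100
capability = major * 10 + minor                     # 70 = Volta, 75 = Turing, 80 = Ampere
print("AWQ kernels supported:", capability >= 80)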
csrc/quantization/awq/gemm_kernels.cu

@@ -16,6 +16,9 @@ Adapted from https://github.com/mit-han-lab/llm-awq
 #include <cuda_fp16.h>

+namespace vllm {
+namespace awq {
+
 // Pack two half values.
 static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {
@@ -26,6 +29,9 @@ __pack_half2(const half x, const half y) {
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -214,11 +220,15 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
       }
     }
   }
+#endif
 }

 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -412,8 +422,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
       }
     }
   }
+#endif
 }
+
+} // namespace awq
+} // namespace vllm

 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
@@ -459,7 +473,7 @@ torch::Tensor awq_gemm(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+    vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
       group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   else if (num_out_channels % 64 == 0)
@@ -470,7 +484,7 @@ torch::Tensor awq_gemm(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+    vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
       group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   return _out_feats.sum(0);
...
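The layout comments above pin down what awq_gemm expects from the Python side. A minimal shape sketch, with arbitrary sizes and G standing for the quantization group size (the zeros layout is elided here, as in the excerpt above):

# Illustrative shapes matching the layout comments above; sizes are arbitrary.
import torch

M, IC, OC, G = 16, 4096, 4096, 128
in_feats = torch.empty(M, IC, dtype=torch.float16, device="cuda")
# Each int32 packs eight 4-bit weights, so the quantized kernel is IC x (OC // 8).
kernel = torch.empty(IC, OC // 8, dtype=torch.int32, device="cuda")
scaling_factors = torch.empty(IC // G, OC, dtype=torch.float16, device="cuda")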
vllm/model_executor/model_loader.py

@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         quant_config = get_quant_config(model_config.quantization,
                                         model_config.model,
                                         model_config.download_dir)
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < quant_config.get_min_capability():
+            raise ValueError(
+                f"The quantization method {model_config.quantization} is not "
+                "supported for the current GPU. "
+                f"Minimum capability: {quant_config.get_min_capability()}. "
+                f"Current capability: {capability}.")
         supported_dtypes = quant_config.get_supported_act_dtypes()
         if model_config.dtype not in supported_dtypes:
             raise ValueError(
...
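The effect of the new check: on a pre-Ampere GPU, loading an AWQ-quantized model now fails early with a clear error instead of hitting the device-side assert inside the kernels. A minimal standalone sketch of the same logic (the helper name is hypothetical, not part of this commit):

# Standalone sketch of the capability guard added above; the function name is hypothetical.
import torch

def check_min_capability(min_capability: int, method: str = "awq") -> None:
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor  # e.g. 75 on a Turing T4, 80 on an Ampere A100
    if capability < min_capability:
        raise ValueError(
            f"The quantization method {method} is not supported for the current GPU. "
            f"Minimum capability: {min_capability}. Current capability: {capability}.")

check_min_capability(80)  # raises ValueError on Volta (70) or Turing (75) GPUs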
vllm/model_executor/quantization_utils/awq.py

@@ -40,6 +40,11 @@ class AWQConfig(QuantizationConfig):
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half]

+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Ampere or newer GPUs.
+        return 80
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         return [
...
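A quick interactive check of the new hook, assuming a CUDA build of vLLM at this commit (expected output shown as comments):

# Illustrative usage of the classmethods shown above.
from vllm.model_executor.quantization_utils.awq import AWQConfig

print(AWQConfig.get_min_capability())        # 80, i.e. Ampere (SM 8.0) or newer
print(AWQConfig.get_supported_act_dtypes())  # [torch.float16]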
vllm/model_executor/quantization_utils/base.py

@@ -15,6 +15,16 @@ class QuantizationConfig:
         """List of supported activation dtypes."""
         raise NotImplementedError

+    @classmethod
+    def get_min_capability(cls) -> int:
+        """Minimum GPU capability to support the quantization method.
+
+        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+        This requirement is due to the custom CUDA kernels used by the
+        quantization method.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         """List of filenames to search for in the model directory."""
...
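For illustration, a hypothetical subclass sketch showing how a quantization method would implement the interface extended by this commit (the class name and values are made up, not part of the change):

# Hypothetical config sketch against the QuantizationConfig interface above.
from typing import List

import torch

from vllm.model_executor.quantization_utils.base import QuantizationConfig


class FakeInt8Config(QuantizationConfig):  # hypothetical, for illustration only
    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # 75 = Turing; see the docstring added above.
        return 75

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]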