Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8d604ca
Unverified
Commit
a8d604ca
authored
Aug 02, 2024
by
Lucas Wilkinson
Committed by
GitHub
Aug 02, 2024
Browse files
[Misc] Disambiguate quantized types via a new ScalarType (#6396)
parent
b482b9a5
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
269 additions
and
212 deletions
+269
-212
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+14
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+20
-9
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+27
-16
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+19
-10
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+57
-63
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+19
-10
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
...xecutor/layers/quantization/utils/marlin_utils_test_24.py
+14
-16
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+64
-84
vllm/scalar_type.py
vllm/scalar_type.py
+35
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
a8d604ca
...
...
@@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsW4A16Sparse24"
]
W4A16SPARSE24_SUPPORTED_BITS
=
[
4
]
W4A16SPARSE24_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
}
W4A16SPARSE24_SUPPORTED_BITS
=
list
(
W4A16SPARSE24_SUPPORTED_TYPES_MAP
.
keys
())
class
CompressedTensorsW4A16Sparse24
(
CompressedTensorsScheme
):
...
...
@@ -22,9 +26,15 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
group_size
:
Optional
[
int
]
=
None
):
self
.
strategy
=
strategy
self
.
group_size
=
group_size
self
.
num_bits
=
num_bits
self
.
tile_size
=
16
if
num_bits
not
in
W4A16SPARSE24_SUPPORTED_TYPES_MAP
:
raise
ValueError
(
f
"Unsupported num_bits =
{
num_bits
}
. "
f
"Supported num_bits =
{
W4A16SPARSE24_SUPPORTED_BITS
}
"
)
self
.
quant_type
=
W4A16SPARSE24_SUPPORTED_TYPES_MAP
[
num_bits
]
if
self
.
strategy
==
"group"
and
self
.
group_size
is
None
:
raise
ValueError
(
"group_size must be given when using strategy group"
)
...
...
@@ -43,7 +53,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
pack_factor
=
32
//
self
.
num
_bits
pack_factor
=
32
//
self
.
quant_type
.
size
_bits
output_size_per_partition
=
sum
(
output_partition_sizes
)
qweight
=
Parameter
(
...
...
@@ -138,7 +148,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
size_n
=
scales
.
shape
[
1
]
output_2d
=
ops
.
gptq_marlin_24_gemm
(
x_2d
,
qweight
,
meta
,
scales
,
workspace
,
self
.
num_bits
,
size_m
,
workspace
,
self
.
quant_type
,
size_m
,
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
a8d604ca
...
...
@@ -8,12 +8,17 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_
gptq_
marlin_supported
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsWNA16"
]
WNA16_SUPPORTED_BITS
=
[
4
,
8
]
WNA16_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
}
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
class
CompressedTensorsWNA16
(
CompressedTensorsScheme
):
...
...
@@ -22,8 +27,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy
:
str
,
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
):
self
.
num_bits
=
num_bits
self
.
pack_factor
=
32
//
self
.
num_bits
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
group_size
:
int
...
...
@@ -37,10 +42,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
else
:
self
.
group_size
=
group_size
if
num_bits
not
in
WNA16_SUPPORTED_TYPES_MAP
:
raise
ValueError
(
f
"Unsupported num_bits =
{
num_bits
}
. "
f
"Supported num_bits =
{
WNA16_SUPPORTED_TYPES_MAP
.
keys
()
}
"
)
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
]
# Verify supported on platform.
verify_gptq_marlin_supported
(
num_bits
=
self
.
num_bits
,
group_size
=
self
.
group_size
,
is_sym
=
True
)
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
...
...
@@ -150,7 +161,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
num
_bits
)
num_bits
=
self
.
quant_type
.
size
_bits
)
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
...
...
@@ -172,7 +183,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
num_bits
,
wtype
=
self
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
True
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
a8d604ca
...
...
@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_
gptq_
marlin_supported
,
marlin_is_k_full
,
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -22,6 +23,12 @@ logger = init_logger(__name__)
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
...
...
@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
self
.
lm_head_quantized
=
lm_head_quantized
if
(
weight_bits
,
is_sym
)
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
"Unsupported quantization config: "
f
"bits=
{
weight_bits
}
, sym=
{
is_sym
}
"
)
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
# Verify supported on platform.
verify_gptq_marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
is_sym
=
self
.
is_sym
)
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(
weight_bits=
{
self
.
weight_bits
}
, "
return
(
f
"GPTQMarlinConfig(
quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
...
...
@@ -122,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
or
desc_act
is
None
):
return
False
return
check_gptq_marlin_supported
(
num_bits
=
num_bits
,
group_size
=
group_size
,
is_sym
=
sym
,
min_capability
=
cls
.
get_min_capability
())
if
(
num_bits
,
sym
)
not
in
cls
.
TYPE_MAP
:
return
False
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[(
num_bits
,
sym
)],
group_size
=
group_size
,
min_capability
=
cls
.
get_min_capability
())
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
...
@@ -293,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight
_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size
_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
...
...
@@ -319,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
wtype
=
self
.
quant_config
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
a8d604ca
...
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MAX_PARALLEL
=
64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_SYM
=
[
True
]
class
GPTQMarlin24Config
(
QuantizationConfig
):
...
...
@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits
:
int
,
group_size
:
int
,
)
->
None
:
self
.
weight_bits
=
weight_bits
quant_type
=
{
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
}.
get
(
weight_bits
)
self
.
group_size
=
group_size
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
:
if
quant_type
is
None
or
\
quant_type
not
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
:
raise
ValueError
(
f
"Marlin_24 does not support
weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_
NUM_BIT
S
}
"
f
"Marlin_24 does not support
quant_type =
{
quant_type
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_
QUANT_TYPE
S
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
...
...
@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f
"Only group_sizes =
{
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
self
.
quant_type
=
quant_type
# 4 Bits packed into 32 bit datatype.
self
.
pack_factor
=
32
//
self
.
weight
_bits
self
.
pack_factor
=
32
//
self
.
quant_type
.
size
_bits
# Tile size used by marlin kernels.
self
.
tile_size
=
16
...
...
@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self
.
perm_len
=
1024
def
__repr__
(
self
)
->
str
:
return
"Marlin24Config(
weight_bits
={}, group_size={})"
.
format
(
self
.
weight_bits
,
self
.
group_size
)
return
"Marlin24Config(
quant_type
={}, group_size={})"
.
format
(
self
.
quant_type
,
self
.
group_size
)
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d
=
ops
.
gptq_marlin_24_gemm
(
x_2d
,
qweight
,
meta
,
scales
,
workspace
,
self
.
quant_config
.
weight_bits
,
self
.
quant_config
.
quant_type
,
size_m
,
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
a8d604ca
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
.quant_utils
import
pack_cols
,
unpack_cols
...
...
@@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# In case there is a performance issue with Marlin, the variable below can be
...
...
@@ -22,76 +22,70 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
USE_FP32_REDUCE_DEFAULT
=
True
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
Optional
[
int
],
has_zp
:
bool
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
min_capability
is
not
None
:
# For binary size and compile time, we don't support the same types with and
# without a runtime zero-point. We support the common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def
query_marlin_supported_quant_types
(
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
):
if
min_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
min_capability
:
return
(
False
,
"Marlin does not support device_capability = {}"
", the min_capability required is {}"
.
format
(
device_capability
,
min_capability
))
if
num_bits
not
in
MARLIN_SUPPORTED_NUM_BITS
:
return
(
False
,
"Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported."
.
format
(
num_bits
,
MARLIN_SUPPORTED_NUM_BITS
))
if
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
:
return
(
False
,
"Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported."
.
format
(
group_size
,
MARLIN_SUPPORTED_GROUP_SIZES
))
if
not
has_zp
and
not
is_sym
:
return
(
False
,
"Marlin without zero_points must have symmetric quantization"
)
min_capability
=
major
*
10
+
minor
return
True
,
None
if
min_capability
<
80
:
return
[]
if
has_zp
:
# AWQ style, unsigned + runtime zero-point
return
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
else
:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
def
check_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
int
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
num_bits
,
group_size
,
is_sym
,
min_capability
,
has_zp
=
False
)
return
cond
def
_check_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
Optional
[
int
],
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
def
check_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
min_capability
:
int
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
num_bits
,
group_size
,
False
,
min_capability
,
has_zp
=
has_zp
)
return
cond
if
min_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
min_capability
=
major
*
10
+
minor
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
min_capability
)
def
verify_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
group_size
,
is_sym
,
min_capability
=
None
,
has_zp
=
False
)
if
not
cond
:
assert
err_msg
is
not
None
raise
ValueError
(
"GPTQ"
+
err_msg
)
if
quant_type
not
in
supported_types
:
return
(
False
,
f
"Marlin does not support weight_bits =
{
quant_type
}
. "
f
"Only types =
{
supported_types
}
"
f
"are supported (for group_size =
{
group_size
}
, "
f
"min_capability =
{
min_capability
}
, zp =
{
has_zp
}
)."
)
if
(
group_size
is
None
or
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
):
return
(
False
,
f
"Marlin does not support group_size =
{
group_size
}
. "
f
"Only group_sizes =
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
return
True
,
None
def check_marlin_supported(quant_type: ScalarType,
                           group_size: int,
                           has_zp: bool = False,
                           min_capability: Optional[int] = None) -> bool:
    """Report whether Marlin supports the given quantization setup.

    Boolean-only wrapper around ``_check_marlin_supported``: the arguments
    are forwarded unchanged and only the first element of the returned
    ``(flag, message)`` pair is handed back; the message is discarded.

    Args:
        quant_type: Quantized weight type to check.
        group_size: Quantization group size to check.
        has_zp: Forwarded as the ``has_zp`` argument of the underlying check.
        min_capability: Forwarded as the ``min_capability`` argument of the
            underlying check.

    Returns:
        The boolean flag produced by ``_check_marlin_supported``.
    """
    is_supported, _reason = _check_marlin_supported(
        quant_type, group_size, has_zp, min_capability)
    return is_supported
def
verify_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
group_size
,
False
,
min_capability
=
None
,
has_zp
=
has_zp
)
def
verify_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
int
,
has_zp
:
bool
=
False
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
quant_type
,
group_size
,
has_zp
)
if
not
cond
:
assert
err_msg
is
not
None
raise
ValueError
(
"AWQ"
+
err_msg
)
raise
ValueError
(
err_msg
)
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
...
...
@@ -245,7 +239,7 @@ def apply_gptq_marlin_linear(
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
wtype
:
ScalarType
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
is_k_full
:
bool
,
...
...
@@ -261,7 +255,7 @@ def apply_gptq_marlin_linear(
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
wtype
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
...
...
@@ -283,7 +277,7 @@ def apply_awq_marlin_linear(
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
quant_type
:
ScalarType
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -298,7 +292,7 @@ def apply_awq_marlin_linear(
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
quant_type
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
a8d604ca
...
...
@@ -5,10 +5,12 @@ from typing import List
import
numpy
as
np
import
torch
from
vllm.scalar_type
import
ScalarType
from
.marlin_utils
import
(
GPTQ_MARLIN_TILE
,
marlin_permute_scales
,
marlin_zero_points
)
from
.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
quantize_weights
_with_zp
,
sort_weights
)
from
.quant_utils
import
(
get_pack_factor
,
gptq_
quantize_weights
,
quantize_weights
,
sort_weights
)
class
MarlinWorkspace
:
...
...
@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
size_k
,
size_n
=
w
.
shape
num_bits
=
quant_type
.
size_bits
# Normalize group_size
if
group_size
==
-
1
:
...
...
@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
act_order
)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_
quantize_weights
(
w
,
quant_type
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
...
...
@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
return
res_list
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
...
...
@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
num_groups
=
size_k
//
group_size
# Quantize with zp
w_ref
,
q_w
,
s
,
zp
=
quantize_weights_with_zp
(
w
,
num_bits
,
group_size
)
w_ref
,
q_w
,
s
,
zp
=
quantize_weights
(
w
,
quant_type
,
group_size
,
zero_points
=
True
)
# Reformat to marlin
weight_perm
=
get_weight_perm
(
num_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
)
weight_perm
=
get_weight_perm
(
quant_type
.
size_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
num_bits
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
quant_type
.
size_bits
)
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
]
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
View file @
a8d604ca
...
...
@@ -6,8 +6,10 @@ from typing import List
import
numpy
import
torch
from
vllm.scalar_type
import
ScalarType
from
.marlin_utils_test
import
marlin_weights
from
.quant_utils
import
quantize_weights
from
.quant_utils
import
gptq_
quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
...
...
@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
wtype
:
ScalarType
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Remove bias to normalize over 0
q_24_no_zp
=
q_24
-
wtype
.
bias
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
...
...
@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore
zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Restore
bias
q_24_comp
=
q_24_no_zp_comp
+
wtype
.
bias
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
...
...
@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
quant_type
:
ScalarType
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
...
...
@@ -441,20 +441,18 @@ def marlin_24_quantize(
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
num_bits
,
group_size
,
act_order
=
False
)
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w_24
,
quant_type
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
quant_type
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
weight_perm
=
get_weight_perm_24
(
num
_bits
)
weight_perm
=
get_weight_perm_24
(
quant_type
.
size
_bits
)
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num
_bits
,
weight_perm
)
quant_type
.
size
_bits
,
weight_perm
)
marlin_24_s
=
marlin_permute_scales_24
(
s
,
size_k
,
size_n
,
group_size
)
# Create result
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
a8d604ca
...
...
@@ -4,7 +4,11 @@ from typing import List
import
numpy
import
torch
SUPPORTED_NUM_BITS
=
[
4
,
8
]
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
SUPPORTED_GPTQ_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# Note: this is a hack. We should update each model to register the
...
...
@@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
32
%
num_bits
==
0
,
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
...
...
@@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
)
def
quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
):
def
quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
zero_points
:
bool
=
False
):
assert
quant_type
.
is_integer
(),
\
"Floating point quantization may work but has not been tested"
orig_device
=
w
.
device
orig_type
=
w
.
dtype
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
half_q_val
=
(
max_q_val
+
1
)
//
2
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
...
...
@@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
s
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s
*=
2
/
max_q_val
# 2 => symmetric
max_val
=
torch
.
max
(
w
,
0
,
keepdim
=
True
).
values
min_val
=
torch
.
min
(
w
,
0
,
keepdim
=
True
).
values
max_q_val
=
quant_type
.
max
()
min_q_val
=
quant_type
.
min
()
if
zero_points
:
assert
not
quant_type
.
is_signed
()
and
quant_type
.
max
()
>
0
w_s
=
(
max_val
-
min_val
).
clamp
(
min
=
1e-5
)
/
quant_type
.
max
()
maybe_w_zp
=
torch
.
round
(
torch
.
abs
(
min_val
/
w_s
))
\
.
clamp
(
min_q_val
,
max_q_val
).
int
()
else
:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s
=
torch
.
max
(
abs
(
max_val
/
(
max_q_val
if
max_q_val
!=
0
else
torch
.
inf
)),
abs
(
min_val
/
(
min_q_val
if
min_q_val
!=
0
else
torch
.
inf
)))
maybe_w_zp
=
None
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
q_w
+=
half_q_val
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
w_q
=
torch
.
round
(
w
/
w_s
).
int
()
+
(
maybe_w_zp
if
zero_points
else
0
)
w_q
=
torch
.
clamp
(
w_q
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
if
quant_type
.
has_bias
():
w_q
+=
quant_type
.
bias
# Restore original shapes
if
group_size
<
size_k
:
...
...
@@ -119,90 +140,48 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
q_w
=
reshape_w
(
q_w
)
w_q
=
reshape_w
(
w_q
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
w_
s
=
w_
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
# Apply act_order
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
rand_perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
assert
(
group_size
<
size_k
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
w_ref
,
q_w
,
g_idx
,
rand_perm
=
permute_rows
(
q_w
,
w_ref
,
group_size
)
if
zero_points
:
maybe_w_zp
=
maybe_w_zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
maybe_w_zp
=
maybe_w_zp
.
to
(
device
=
orig_device
)
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
w_q
.
to
(
device
=
orig_device
),
w_s
.
to
(
device
=
orig_device
),
maybe_w_zp
,
)
def
quantize_weights
_with_zp
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
orig_device
=
w
.
device
size_k
,
size_n
=
w
.
shape
def
gptq_
quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
size_k
,
_
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
quant_type
in
SUPPORTED_GPTQ_QUANT_TYPES
,
\
f
"Unsupported gptq type =
{
quant_type
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
min_q_val
=
0
w_ref
,
w_q
,
w_s
,
_
=
quantize_weights
(
w
,
quant_type
,
group_size
)
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
max
=
torch
.
max
(
w
,
0
,
keepdim
=
True
)[
0
]
min
=
torch
.
min
(
w
,
0
,
keepdim
=
True
)[
0
]
s
=
(
max
-
min
).
clamp
(
min
=
1e-5
)
/
max_q_val
# Compute zero-point for each group
zp
=
(
-
torch
.
round
(
min
/
s
)).
clamp
(
min_q_val
,
max_q_val
).
int
()
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
+
zp
q_w
=
torch
.
clamp
(
q_w
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
zp
).
half
()
*
s
# Restore original shapes
if
group_size
<
size_k
:
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
q_w
=
reshape_w
(
q_w
)
w_ref
=
reshape_w
(
w_ref
)
# Apply act_order
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
rand_perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
assert
(
group_size
<
size_k
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
)
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
zp
.
to
(
device
=
orig_device
),
)
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
# QQQ employs different quant schemes for per-group and
...
...
@@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
num_bits
in
MARLIN_QQQ_SUPPORTED_NUM_BITS
,
\
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
...
...
vllm/scalar_type.py
0 → 100644
View file @
a8d604ca
from
._core_ext
import
NanRepr
,
ScalarType
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
#  - if bias is not present it means the bias is zero
class scalar_types:
    """Namespace of predefined ``ScalarType`` instances.

    Each attribute is a ``ScalarType`` built through one of its factory
    methods (``int_``, ``uint``, ``float_``, ``float_IEEE754``). The class
    is used purely as a container of named constants.
    """

    # Integer types. The attribute names suggest the first argument is the
    # size in bits; the second argument is None here and a number for the
    # "gptq" types below, so it is presumably the bias -- TODO confirm
    # against the ScalarType definition.
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)

    # Floating point types. Going by the attribute names, the first two
    # arguments are the exponent and mantissa bit counts. The meaning of
    # the third argument (True) of float_ is not visible here; presumably
    # "finite values only" -- TODO confirm.
    float8_e4m3fn = ScalarType.float_(4, 3, True,
                                      NanRepr.EXTD_RANGE_MAX_MIN.value)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)

    # "gptq" types
    # Unsigned integers created with a non-None second argument (8 and 128),
    # matching the "b8" / "b128" suffix in their names.
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    # These are aliases bound to the same objects defined above, not new types.
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment