Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54a66e5f
Unverified
Commit
54a66e5f
authored
Apr 15, 2025
by
Dipika Sikka
Committed by
GitHub
Apr 15, 2025
Browse files
[Misc] Update `compressed-tensors` WNA16 to support zero-points (#14211)
parent
280d62b8
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
85 additions
and
45 deletions
+85
-45
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+15
-6
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+3
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+39
-3
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
...or/layers/quantization/kernels/mixed_precision/machete.py
+2
-5
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
...tor/layers/quantization/kernels/mixed_precision/marlin.py
+24
-27
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+2
-1
No files found.
tests/quantization/test_compressed_tensors.py
View file @
54a66e5f
...
...
@@ -261,16 +261,23 @@ def test_compressed_tensors_w8a8_dynamic_per_token(
@
pytest
.
mark
.
parametrize
(
"wNa16_args"
,
[
(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
),
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
),
],
[(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
,
True
,
False
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
,
True
,
False
),
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
,
True
,
False
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256"
,
"group"
,
128
,
8
,
False
,
False
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel"
,
"channel"
,
None
,
8
,
False
,
False
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder"
,
"group"
,
128
,
8
,
False
,
True
)],
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"The tests are skipped on non-CUDA platform."
)
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
model
,
strategy
,
group
,
pack_factor
,
symmetric
,
has_g_idx
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
def
check_model
(
model
):
...
...
@@ -286,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
if
group
is
None
else
group
)
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
assert
qkv_proj
.
scheme
.
symmetric
==
symmetric
assert
qkv_proj
.
scheme
.
has_g_idx
==
has_g_idx
llm
.
apply_model
(
check_model
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
54a66e5f
...
...
@@ -302,14 +302,12 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant_none
=
input_quant
is
None
is_symmetric
=
weight_quant
.
symmetric
is_channel_group
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
or
weight_quant
.
strategy
==
QuantizationStrategy
.
GROUP
.
value
)
is_static
=
not
weight_quant
.
dynamic
return
(
is_channel_group
and
input_quant_none
and
is_symmetric
and
is_static
)
return
(
is_channel_group
and
input_quant_none
and
is_static
)
def
_get_scheme_from_parts
(
self
,
weight_quant
:
BaseModel
,
...
...
@@ -319,6 +317,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
assert
weight_quant
.
symmetric
return
CompressedTensorsW4A16Sparse24
(
strategy
=
weight_quant
.
strategy
,
num_bits
=
weight_quant
.
num_bits
,
...
...
@@ -328,6 +327,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
symmetric
=
weight_quant
.
symmetric
,
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
54a66e5f
...
...
@@ -12,11 +12,15 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
marlin_repeat_scales_on_all_ranks
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
# yapf: enable
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -26,6 +30,7 @@ WNA16_SUPPORTED_TYPES_MAP = {
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
}
WNA16_ZP_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4
,
8
:
scalar_types
.
uint8
}
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
...
...
@@ -36,10 +41,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
strategy
:
str
,
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
,
symmetric
:
Optional
[
bool
]
=
True
,
actorder
:
Optional
[
ActivationOrdering
]
=
None
):
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
symmetric
=
symmetric
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
has_g_idx
=
actorder
==
ActivationOrdering
.
GROUP
...
...
@@ -53,7 +60,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
f
"Unsupported num_bits =
{
num_bits
}
. "
f
"Supported num_bits =
{
WNA16_SUPPORTED_TYPES_MAP
.
keys
()
}
"
)
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
]
self
.
quant_type
=
(
WNA16_ZP_SUPPORTED_TYPES_MAP
[
num_bits
]
if
not
self
.
symmetric
else
WNA16_SUPPORTED_TYPES_MAP
[
num_bits
])
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
...
...
@@ -75,7 +84,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_type
=
self
.
quant_type
,
act_type
=
params_dtype
,
group_size
=
self
.
group_size
,
zero_points
=
False
,
zero_points
=
not
self
.
symmetric
,
has_g_idx
=
self
.
has_g_idx
)
...
...
@@ -120,13 +129,37 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
dtype
=
params_dtype
,
)
}
zeros_args
=
{
"weight_loader"
:
weight_loader
,
"data"
:
torch
.
zeros
(
output_size_per_partition
//
self
.
pack_factor
,
scales_and_zp_size
,
dtype
=
torch
.
int32
,
)
}
if
not
partition_scales
:
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
**
weight_scale_args
)
if
not
self
.
symmetric
:
qzeros
=
PackedColumnParameter
(
output_dim
=
0
,
packed_dim
=
0
,
packed_factor
=
self
.
pack_factor
,
**
zeros_args
)
else
:
weight_scale
=
GroupQuantScaleParameter
(
output_dim
=
0
,
input_dim
=
1
,
**
weight_scale_args
)
if
not
self
.
symmetric
:
qzeros
=
PackedvLLMParameter
(
input_dim
=
1
,
output_dim
=
0
,
packed_dim
=
0
,
packed_factor
=
self
.
pack_factor
,
**
zeros_args
)
# A 2D array defining the original shape of the weights
# before packing
...
...
@@ -138,6 +171,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
if
not
self
.
symmetric
:
layer
.
register_parameter
(
"weight_zero_point"
,
qzeros
)
# group index (for activation reordering)
if
self
.
has_g_idx
:
weight_g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
...
...
@@ -151,7 +187,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
kernel
=
kernel_type
(
mp_linear_kernel_config
,
w_q_param_name
=
"weight_packed"
,
w_s_param_name
=
"weight_scale"
,
w_zp_param_name
=
None
,
w_zp_param_name
=
"weight_zero_point"
,
w_gidx_param_name
=
"weight_g_idx"
)
# Checkpoints are serialized in compressed-tensors format, which is
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
View file @
54a66e5f
...
...
@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
has_g_idx
and
\
c
.
partition_weight_shape
[
0
]
!=
c
.
full_weight_shape
[
0
]:
return
False
,
"Act reordering currently not supported by Machete, "
\
"when the input features are partitioned across "
\
"devices"
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by "
\
" Compressed Tensors + Machete. (Kernel supports it"
\
" but CompressedTensorsWNA16 does not so support has"
\
" not been added to MacheteWNA16Kernel yet"
return
False
,
"Zero points currently not supported by Machete"
if
c
.
weight_type
not
in
query_machete_supported_quant_types
(
c
.
zero_points
):
...
...
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
View file @
54a66e5f
...
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES
,
apply_gptq_marlin_linear
,
check_marlin_supports_shape
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_sort_g_idx
,
query_marlin_supported_quant_types
)
marlin_zero_points
,
query_marlin_supported_quant_types
,
unpack_cols
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
permute_param_layout_
)
...
...
@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
@
classmethod
def
can_implement
(
cls
,
c
:
MPLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
c
.
zero_points
:
return
False
,
"Zero points currently not supported by "
\
" MarlinLinearKernel. Will be added when AWQMarlin "
\
"is migrated over to using MPLinearKernel backend"
quant_types
=
query_marlin_supported_quant_types
(
c
.
zero_points
)
if
c
.
weight_type
not
in
quant_types
:
...
...
@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
if
self
.
w_zp_name
is
None
:
self
.
w_zp_name
=
"w_zp"
if
c
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
getattr
(
layer
,
self
.
w_gidx_name
))
self
.
_transform_param
(
layer
,
self
.
w_gidx_name
,
lambda
_
:
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
else
:
setattr
(
layer
,
self
.
w_gidx_name
,
marlin_make_empty_g_idx
(
device
))
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
if
c
.
zero_points
:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else
:
setattr
(
layer
,
self
.
w_zp_name
,
marlin_make_empty_g_idx
(
device
))
def
transform_w_q
(
x
):
assert
isinstance
(
x
,
BasevLLMParameter
)
permute_param_layout_
(
x
,
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
)
...
...
@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
group_size
=
c
.
group_size
)
return
x
if
c
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
getattr
(
layer
,
self
.
w_gidx_name
))
self
.
_transform_param
(
layer
,
self
.
w_gidx_name
,
lambda
_
:
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
else
:
setattr
(
layer
,
self
.
w_gidx_name
,
marlin_make_empty_g_idx
(
device
))
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
if
c
.
zero_points
:
grouped_k
=
(
c
.
partition_weight_shape
[
0
]
//
c
.
group_size
if
c
.
group_size
!=
-
1
else
1
)
self
.
_transform_param
(
layer
,
self
.
w_zp_name
,
lambda
x
:
\
marlin_zero_points
(
unpack_cols
(
x
.
t
(),
c
.
weight_type
.
size_bits
,
grouped_k
,
c
.
partition_weight_shape
[
1
]),
size_k
=
grouped_k
,
size_n
=
c
.
partition_weight_shape
[
1
],
num_bits
=
c
.
weight_type
.
size_bits
))
else
:
setattr
(
layer
,
self
.
w_zp_name
,
marlin_make_empty_g_idx
(
device
))
self
.
_transform_param
(
layer
,
self
.
w_q_name
,
transform_w_q
)
self
.
_transform_param
(
layer
,
self
.
w_s_name
,
transform_w_s
)
...
...
@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
wtype
=
c
.
weight_type
,
input_size_per_partition
=
c
.
partition_weight_shape
[
0
],
output_size_per_partition
=
c
.
partition_weight_shape
[
1
],
has_zp
=
self
.
config
.
zero_points
,
is_k_full
=
self
.
is_k_full
,
bias
=
bias
)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
54a66e5f
...
...
@@ -332,6 +332,7 @@ def apply_gptq_marlin_linear(
wtype
:
ScalarType
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
has_zp
:
bool
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
...
...
@@ -356,8 +357,8 @@ def apply_gptq_marlin_linear(
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_atomic_add
=
use_atomic_add
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment