Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0cdbf5e6
Unverified
Commit
0cdbf5e6
authored
Aug 20, 2025
by
Michael Goin
Committed by
GitHub
Aug 20, 2025
Browse files
[Kernel/Quant] Remove the original marlin format and qqq (#23204)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
ebe56a00
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
0 additions
and
756 deletions
+0
-756
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+0
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+0
-6
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+0
-263
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+0
-275
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
...ecutor/layers/quantization/utils/marlin_utils_test_qqq.py
+0
-126
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+0
-85
No files found.
vllm/model_executor/layers/linear.py
View file @
0cdbf5e6
...
...
@@ -42,7 +42,6 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
0cdbf5e6
...
...
@@ -15,7 +15,6 @@ QuantizationMethods = Literal[
"fbgemm_fp8"
,
"modelopt"
,
"modelopt_fp4"
,
"marlin"
,
"bitblas"
,
"gguf"
,
"gptq_marlin_24"
,
...
...
@@ -25,7 +24,6 @@ QuantizationMethods = Literal[
"gptq"
,
"compressed-tensors"
,
"bitsandbytes"
,
"qqq"
,
"hqq"
,
"experts_int8"
,
"neuron_quant"
,
...
...
@@ -106,13 +104,11 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.hqq_marlin
import
HQQMarlinConfig
from
.inc
import
INCConfig
from
.ipex_quant
import
IPEXConfig
from
.marlin
import
MarlinConfig
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.mxfp4
import
Mxfp4Config
from
.neuron_quant
import
NeuronQuantConfig
from
.ptpc_fp8
import
PTPCFp8Config
from
.qqq
import
QQQConfig
from
.rtn
import
RTNConfig
from
.torchao
import
TorchAOConfig
from
.tpu_int8
import
Int8TpuConfig
...
...
@@ -125,7 +121,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt_fp4"
:
ModelOptNvFp4Config
,
"marlin"
:
MarlinConfig
,
"bitblas"
:
BitBLASConfig
,
"gguf"
:
GGUFConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
...
...
@@ -136,7 +131,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"ptpc_fp8"
:
PTPCFp8Config
,
"qqq"
:
QQQConfig
,
"hqq"
:
HQQMarlinConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
...
...
vllm/model_executor/layers/quantization/marlin.py
deleted
100644 → 0
View file @
ebe56a00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
logger
=
init_logger
(
__name__
)
class MarlinConfig(QuantizationConfig):
    """Quantization config for checkpoints stored in the Marlin format.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        """Store the config values and the fixed Marlin kernel constants.

        Args:
            group_size: Quantization group size; only 128 and -1
                (channelwise) are accepted.
            lm_head_quantized: Whether the LM head is quantized as well.

        Raises:
            ValueError: If ``group_size`` is neither 128 nor -1.
        """
        super().__init__()

        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size not in (128, -1):
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin, but got group_size of "
                f"{self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        group_part = f"MarlinConfig(group_size={self.group_size}, "
        head_part = f"lm_head_quantized={self.lm_head_quantized})"
        return group_part + head_part

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config filenames this method reads."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
        """Build a config from a parsed quantization config dict.

        Reads ``group_size`` (required) and ``lm_head`` (defaults to
        False).
        """
        return cls(
            cls.get_from_keys(config, ["group_size"]),
            cls.get_from_keys_or(config, ["lm_head"], default=False),
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg,
            user_quant) -> Optional[QuantizationMethods]:
        """Select this method when the checkpoint is in Marlin format.

        Returns the method name when the HF quantization config marks the
        checkpoint as Marlin-serialized and the user requested nothing,
        "gptq" or "marlin"; otherwise returns None.
        """
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_marlin_format: bool
        serialized_as_marlin = (
            hf_quant_cfg.get("checkpoint_format") == "marlin"
            or hf_quant_cfg.get("is_marlin_format", False))
        if not serialized_as_marlin:
            return None
        if user_quant not in (None, "gptq", "marlin"):
            return None

        name = cls.get_name()
        logger.info(
            "The model is serialized in {} format. Using {} kernel.".format(
                name, name))
        return name

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["MarlinLinearMethod"]:
        """Return a MarlinLinearMethod for supported layers, else None.

        Linear layers always qualify; a ParallelLMHead qualifies only
        when ``lm_head_quantized`` is set.
        """
        if isinstance(layer, LinearBase):
            return MarlinLinearMethod(self)
        if isinstance(layer, ParallelLMHead) and self.lm_head_quantized:
            return MarlinLinearMethod(self)
        return None
class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the shard shape and register the Marlin parameters.

        Registers three parameters on ``layer``: ``B`` (packed int32
        weights), ``s`` (scales in ``params_dtype``) and ``workspace``
        (zero-initialised int buffer). All tensors are created on "cuda".

        Raises:
            ValueError: If ``params_dtype`` is not float16 or one of the
                divisibility checks on the partition sizes fails.
        """
        del output_size  # Unused.
        weight_loader = extra_weight_attrs["weight_loader"]
        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {cfg.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {cfg.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {cfg.min_k_threads}.")
        if (cfg.group_size != -1
                and input_size_per_partition % cfg.group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {cfg.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        packed_rows = input_size_per_partition // cfg.tile_size
        packed_cols = (output_size_per_partition * cfg.tile_size //
                       cfg.pack_factor)
        qweight = PackedvLLMParameter(
            data=torch.empty(
                packed_rows,
                packed_cols,
                device="cuda",
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=cfg.pack_factor,
            marlin_tile_size=cfg.tile_size,
            weight_loader=weight_loader)

        # Determine if channelwise or not
        if cfg.group_size == -1:
            input_groups = 1
        else:
            input_groups = input_size_per_partition // cfg.group_size

        weight_scale_args = {
            "data":
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        # A single row of scales uses the channel-wise parameter type.
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel
        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
                                                       device="cuda",
                                                       dtype=torch.int),
                                      weight_loader=weight_loader)

        for param_name, param in (("B", qweight), ("s", scales),
                                  ("workspace", workspace)):
            layer.register_parameter(param_name, param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain non-trainable Parameters."""
        # required by torch.compile
        for attr in ("B", "s", "workspace"):
            loaded = getattr(layer, attr)
            setattr(layer, attr, Parameter(loaded.data, requires_grad=False))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM on ``x`` and add ``bias`` when given.

        ``x`` is flattened to 2-D over its last dimension for the kernel
        call and the result is viewed back with the leading dimensions of
        ``x``.
        """
        scales = layer.s

        x_2d = x.view(-1, x.shape[-1])
        size_m, size_k = x_2d.shape[0], x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, layer.B, scales, layer.workspace,
                                    size_m, size_n, size_k)

        out_shape = x.shape[:-1] + (output_2d.shape[1], )
        output = output_2d.view(out_shape)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
vllm/model_executor/layers/quantization/qqq.py
deleted
100644 → 0
View file @
ebe56a00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
logger
=
init_logger
(
__name__
)
# Constants for the Marlin-QQQ kernels; QQQConfig copies the first four
# onto each instance and validates its arguments against the last three.

# Tile size used by the QQQ kernels.
MARLIN_QQQ_TILE = 16
# Minimum out_features dimension (output size must be a multiple of it).
MARLIN_QQQ_MIN_THREAD_N = 64
# Minimum in_features dimension (input size must be a multiple of it).
MARLIN_QQQ_MIN_THREAD_K = 128
# Max parallel problems to solve at once; also scales the workspace size.
MARLIN_QQQ_MAX_PARALLEL = 16

# Accepted weight bit-widths.
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
# Accepted group sizes; -1 selects the branch without group scales.
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
# Accepted values of is_sym.
MARLIN_QQQ_SUPPORTED_SYM = [True]
class QQQConfig(QuantizationConfig):
    """Quantization config for QQQ checkpoints.

    Reference: https://arxiv.org/pdf/2406.09904
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        is_sym: bool = True,
    ) -> None:
        """Store and validate the QQQ settings, then set kernel constants.

        Raises:
            ValueError: If ``weight_bits``, ``group_size`` or ``is_sym`` is
                not in the corresponding ``MARLIN_QQQ_SUPPORTED_*`` list.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.is_sym = is_sym

        # Verify
        bits_ok = self.weight_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
        if not bits_ok:
            raise ValueError(
                f"QQQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
                "are supported.")
        group_ok = self.group_size in MARLIN_QQQ_SUPPORTED_GROUP_SIZES
        if not group_ok:
            raise ValueError(
                f"QQQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        sym_ok = self.is_sym in MARLIN_QQQ_SUPPORTED_SYM
        if not sym_ok:
            raise ValueError(
                f"QQQ does not support is_sym = {self.is_sym}. "
                f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // self.weight_bits

        # Tile size used by QQQ kernels.
        self.tile_size = MARLIN_QQQ_TILE

        # Min out_features dim
        self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = MARLIN_QQQ_MAX_PARALLEL

        # Permutation length used by the QQQ kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        template = "QQQConfig(weight_bits={}, group_size={})"
        return template.format(self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "qqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "QQQConfig":
        """Build a config from the ``wbits`` and ``group_size`` entries."""
        return cls(
            cls.get_from_keys(config, ["wbits"]),
            cls.get_from_keys(config, ["group_size"]),
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QQQLinearMethod"]:
        """Return a QQQLinearMethod for linear layers, otherwise None."""
        if not isinstance(layer, LinearBase):
            return None
        return QQQLinearMethod(self)
class QQQLinearMethod(LinearMethodBase):
    """Linear method for QQQ.

    Args:
        quant_config: The QQQ quantization config.
    """

    def __init__(self, quant_config: QQQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the shard shape and register the QQQ parameters.

        Registers ``B`` (packed int32 weights), ``s_channel`` (float
        per-channel scales), ``s_group`` (half group scales, empty when
        ``group_size`` is -1) and ``workspace`` (zero-initialised int
        buffer) on ``layer``. All tensors are created on "cuda".

        Raises:
            ValueError: If ``params_dtype`` is not float16 or one of the
                divisibility checks on the partition sizes fails.
        """
        weight_loader = extra_weight_attrs["weight_loader"]
        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {cfg.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {cfg.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {cfg.min_k_threads}.")
        if (cfg.group_size != -1
                and input_size_per_partition % cfg.group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {cfg.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        packed_rows = input_size_per_partition // cfg.tile_size
        packed_cols = (output_size_per_partition * cfg.tile_size //
                       cfg.pack_factor)
        qweight = PackedvLLMParameter(
            data=torch.empty(
                packed_rows,
                packed_cols,
                device="cuda",
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=cfg.pack_factor,
            marlin_tile_size=cfg.tile_size,
            weight_loader=weight_loader)

        s_channel = ChannelQuantScaleParameter(data=torch.empty(
            1,
            output_size_per_partition,
            device="cuda",
            dtype=torch.float,
        ),
                                               weight_loader=weight_loader,
                                               output_dim=1)

        # Without grouping the group-scale parameter is an empty tensor.
        if cfg.group_size == -1:
            s_group_attr = {
                "data": torch.tensor(
                    [],
                    device="cuda",
                    dtype=torch.half,
                ),
                "weight_loader": weight_loader
            }
            s_group = BasevLLMParameter(**s_group_attr)
        else:
            s_group_attr = {
                "data":
                torch.empty(
                    input_size_per_partition // cfg.group_size,
                    output_size_per_partition,
                    device="cuda",
                    dtype=torch.half,
                ),
                "weight_loader":
                weight_loader
            }
            s_group = GroupQuantScaleParameter(output_dim=1,
                                               input_dim=0,
                                               **s_group_attr)

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel
        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
                                                       device="cuda",
                                                       dtype=torch.int),
                                      weight_loader=weight_loader)

        for param_name, param in (
            ("B", qweight),
            ("s_channel", s_channel),
            ("s_group", s_group),
            ("workspace", workspace),
        ):
            layer.register_parameter(param_name, param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain non-trainable Parameters."""
        # required by torch.compile
        for attr in ("B", "s_channel", "s_group", "workspace"):
            loaded = getattr(layer, attr)
            setattr(layer, attr, Parameter(loaded.data, requires_grad=False))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to int8, run the QQQ GEMM and add ``bias`` if given.

        ``x`` is flattened to 2-D over its last dimension for the kernel
        calls and the result is viewed back with the leading dimensions of
        ``x``.
        """
        s_ch = layer.s_channel

        x_2d = x.view(-1, x.shape[-1])
        size_m, size_k = x_2d.shape[0], x_2d.shape[1]
        size_n = s_ch.shape[1]

        x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)

        output_2d = ops.marlin_qqq_gemm(x_int8, layer.B, s_tok, s_ch,
                                        layer.s_group, layer.workspace,
                                        size_m, size_n, size_k)

        out_shape = x.shape[:-1] + (output_2d.shape[1], )
        output = output_2d.view(out_shape)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
deleted
100644 → 0
View file @
ebe56a00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
import
torch
from
.marlin_utils_test
import
marlin_permute_weights
from
.quant_utils
import
get_pack_factor
,
qqq_quantize_weights
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Permute quantized weights with ``perm`` and pack them into int32.

    Args:
        q_w: Quantized weight tensor.
        size_k: Passed to ``marlin_permute_weights``; also compared with
            ``group_size`` to pick the packing branch.
        size_n: Passed to ``marlin_permute_weights``.
        num_bits: Bit-width of one quantized value.
        perm: Permutation passed to ``marlin_permute_weights``.
        group_size: Group size; equal to ``size_k`` selects the masked
            packing branch.

    Returns:
        An int32 tensor on the original device of the permuted weights,
        holding ``pack_factor`` values per element along dimension 1.
    """
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    if group_size == size_k:
        # Keep only the low `num_bits` bits of each value before shifting.
        # The mask follows `num_bits` instead of a fixed 0xF, so it matches
        # the shift width (identical for the 4-bit case).
        # NOTE(review): presumably needed because values here are signed
        # and wrap when cast to uint32 - confirm against the quantizer.
        value_mask = numpy.uint32((1 << num_bits) - 1)
        for i in range(pack_factor):
            q_packed |= (q_w[:, i::pack_factor] & value_mask) << num_bits * i
    else:
        for i in range(pack_factor):
            q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed
def get_qqq_scale_perms():
    """Return the two index permutations used to reorder QQQ scales.

    Returns:
        A ``(scale_perm, scale_perm_single)`` pair of lists: a 64-entry
        permutation of ``range(64)`` and a 32-entry permutation of
        ``range(32)``.
    """
    # Entry (i, j) of an 8x8 grid maps to i + 8 * j.
    scale_perm: list[int] = [
        i + 8 * j for i in range(8) for j in range(8)
    ]

    offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single: list[int] = [
        2 * i + offset for i in range(4) for offset in offsets
    ]
    return scale_perm, scale_perm_single
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Build the 1024-entry weight permutation for the QQQ layout.

    Args:
        num_bits: Weight bit-width; only 4 is supported.
        quant_type: Either "per-channel" or "per-group"; each selects a
            different interleave of every run of 8 indices.

    Returns:
        A 1-D tensor holding a permutation of ``range(1024)``.

    Raises:
        AssertionError: If ``quant_type`` is not one of the two supported
            values (when assertions are enabled).
        ValueError: If ``num_bits`` is not 4.
    """
    # Validate up front so bad arguments fail before the table is built.
    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits != 4:
        # ValueError is an Exception subclass, so existing handlers that
        # catch Exception keep working.
        raise ValueError("num_bits must be 4, got {}".format(num_bits))

    perm_list: list[int] = []
    for i in range(32):
        col = i // 4
        first_row = 4 * (i % 4)
        # Two blocks of four consecutive rows each.
        perm1: list[int] = [
            16 * row + col + 8 * block for block in (0, 1)
            for row in range(first_row, first_row + 4)
        ]
        # Repeat the pattern in each of the four 256-wide segments.
        for j in range(4):
            perm_list.extend(p + 256 * j for p in perm1)

    if quant_type == "per-channel":
        interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
    else:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])

    perm = numpy.array(perm_list)
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Reorder QQQ scales with the permutations of ``get_qqq_scale_perms``.

    ``s_channel`` is always permuted. ``s_group`` is permuted only when
    ``group_size`` is neither -1 nor at least ``size_k``; otherwise it is
    returned unchanged.

    Returns:
        The ``(s_group, s_channel)`` pair, each permuted result reshaped to
        ``(-1, size_n)`` and made contiguous.
    """
    scale_perm, scale_perm_single = get_qqq_scale_perms()

    uses_groups = group_size != -1 and group_size < size_k
    if uses_groups:
        group_rows = s_group.reshape((-1, len(scale_perm)))
        s_group = group_rows[:, scale_perm]

    # The per-channel scales use the same permutation in both cases.
    channel_rows = s_channel.reshape((-1, len(scale_perm_single)))
    s_channel = channel_rows[:, scale_perm_single]

    if uses_groups:
        s_group = s_group.reshape((-1, size_n)).contiguous()
    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Quantize ``w`` with QQQ and convert the result to the Marlin layout.

    Args:
        w: 2-D weight tensor of shape ``(size_k, size_n)``.
        num_bits: Weight bit-width.
        group_size: Group size; -1 is treated as ``size_k``.

    Returns:
        A list ``[w_ref, q_w, s_group, s_channel]`` on the device of ``w``:
        the dequantized reference weights, the packed permuted weights and
        the two permuted scale tensors.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size == size_k:
        quant_type = "per-channel"
    else:
        quant_type = "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    packed_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits,
                                    weight_perm, group_size)
    permuted_s_group, permuted_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    results = (w_ref, packed_q_w, permuted_s_group, permuted_s_channel)
    return [tensor.to(w.device) for tensor in results]
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
0cdbf5e6
...
...
@@ -9,8 +9,6 @@ import numpy
import
torch
from
vllm._custom_ops
import
cutlass_scaled_mm_supports_fp4
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
...
@@ -386,89 +384,6 @@ def gptq_quantize_weights(w: torch.Tensor,
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize a 2-D weight tensor with the QQQ scheme.

    When ``group_size`` is smaller than ``size_k`` the weights are quantized
    per group to ``[0, 2**num_bits - 1]`` and the dequantized result is then
    quantized again per channel to int8. Otherwise a single signed
    per-channel quantization is used and the group scales are empty.

    Args:
        w: Floating-point tensor of shape ``(size_k, size_n)``.
        num_bits: Bit-width; must be in ``MARLIN_QQQ_SUPPORTED_NUM_BITS``.
        group_size: Group size; -1 is treated as ``size_k``.

    Returns:
        A ``(w_ref, q_w, s_group, s_channel)`` tuple moved to the original
        device of ``w``.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
           f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size >= size_k:
        # Per-channel branch: one signed scale per output column.
        max_q_val = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.abs(w).max(dim=0, keepdim=True).values
        s_channel /= max_q_val

        # Quantize
        q_w = torch.clamp(
            torch.round(w / s_channel).int(), -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
        s_channel /= (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
    else:
        # Reshape to [groupsize, -1]
        grouped = w.reshape((-1, group_size, size_n))
        grouped = grouped.permute(1, 0, 2).reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.abs(grouped).max(dim=0, keepdim=True).values
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(grouped / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def ungroup(tensor):
            tensor = tensor.reshape((group_size, -1, size_n))
            tensor = tensor.permute(1, 0, 2)
            return tensor.reshape((size_k, size_n)).contiguous()

        q_w = ungroup(q_w)
        w_ref = ungroup(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.abs(w_ref).max(dim=0, keepdim=True).values
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        fused = s_group.reshape(-1, size_n).contiguous() / s_channel
        s_group = fused.to(dtype=torch.half)

    return tuple(
        tensor.to(device=orig_device)
        for tensor in (w_ref, q_w, s_group, s_channel))
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
orig_device
=
q_w
.
device
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment