Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd64aebf
Commit
dd64aebf
authored
Nov 01, 2025
by
jujl1
Browse files
feat: w8a8_marlin 接入,通过-q slimquant_marlin开启,优化w4a8_marlin代码
parent
4ba4b755
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
330 additions
and
9 deletions
+330
-9
vllm/config.py
vllm/config.py
+4
-3
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+6
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_marlin.py
...ntization/compressed_tensors/compressed_tensors_marlin.py
+96
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+207
-0
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+1
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+14
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+2
-1
No files found.
vllm/config.py
View file @
dd64aebf
...
...
@@ -892,8 +892,8 @@ class ModelConfig:
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"slimquant_w4a8"
,
"slimquant_
w4a8
_marlin"
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"slimquant_w4a8"
,
"slimquant_w4a8
_marlin
"
,
"slimquant_
compressed_tensors
_marlin"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
cast
(
me_quant
.
QuantizationMethods
,
...
...
@@ -920,7 +920,8 @@ class ModelConfig:
"awq_marlin"
,
"ipex"
,
"moe_wna16"
,
"slimquant_w4a8_marlin"
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
]
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
dd64aebf
...
...
@@ -38,7 +38,9 @@ QuantizationMethods = Literal[
"rtn"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
]
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
...
...
@@ -97,6 +99,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.bitsandbytes
import
BitsAndBytesConfig
from
.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
)
from
.compressed_tensors.compressed_tensors_marlin
import
(
SlimQuantCompressedTensorsMarlinConfig
)
from
.deepspeedfp
import
DeepSpeedFPConfig
from
.experts_int8
import
ExpertsInt8Config
from
.fbgemm_fp8
import
FBGEMMFp8Config
...
...
@@ -154,6 +158,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"blockwise_int8"
:
BlockInt8Config
,
"slimquant_w4a8"
:
SlimQuantW4A8Int8Config
,
"slimquant_w4a8_marlin"
:
SlimQuantW4A8Int8MarlinConfig
,
"slimquant_compressed_tensors_marlin"
:
SlimQuantCompressedTensorsMarlinConfig
,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_marlin.py
0 → 100644
View file @
dd64aebf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
cast
import
torch
from
compressed_tensors.config
import
SparsityCompressionConfig
from
compressed_tensors.quantization
import
QuantizationArgs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.vocab_parallel_embedding
import
UnquantizedEmbeddingMethod
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
CompressedTensorsConfig
,
CompressedTensorsLinearMethod
,
CompressedTensorsKVCacheMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin
import
(
CompressedTensorsMarlinMoEMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
should_ignore_layer
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
import
os
from
vllm
import
_custom_ops
as
ops
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Names exported by ``from ... import *``.
# ``CompressedTensorsLinearMethod`` is kept for backward compatibility; the
# config class that this module actually defines is exported as well.
__all__ = [
    "CompressedTensorsLinearMethod",
    "SlimQuantCompressedTensorsMarlinConfig",
]

# NOTE(review): the two names below are not referenced by the code visible in
# this module; they are kept because other modules may import them.
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
    """Compressed-tensors config that sends FusedMoE layers to the Marlin
    MoE method.

    Construction is delegated unchanged to ``CompressedTensorsConfig``; this
    subclass only changes the registered method name, the override rule used
    to select it, and the per-layer quantize-method dispatch.
    """

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
    ):
        # All arguments are forwarded positionally, in the same order.
        super().__init__(
            target_scheme_map,
            ignore,
            quant_format,
            sparsity_scheme_map,
            sparsity_ignore_list,
            kv_cache_scheme,
            config,
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Return this method's name when it should replace the checkpoint's.

        The override applies only when the checkpoint config reports
        ``quant_method == "compressed-tensors"`` and the user-requested
        quantization is ``"slimquant_marlin"``; otherwise ``None`` is returned.
        """
        use_marlin = (
            hf_quant_cfg.get("quant_method") == "compressed-tensors"
            and user_quant == "slimquant_marlin"
        )
        return cls.get_name() if use_marlin else None

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name under which this config is registered."""
        return "slimquant_compressed_tensors_marlin"

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns ``None`` for layer types not handled here.
        """
        # Imported here rather than at module level to avoid a circular import.
        from vllm.attention.layer import Attention

        # Layers excluded from quantization get the unquantized method.
        # NOTE(review): ``UnquantizedEmbeddingMethod`` is what the original
        # code returns here (it carried ``UnquantizedLinearMethod()`` only as
        # a trailing comment) -- confirm this choice is intentional.
        is_skipped = should_ignore_layer(
            prefix,
            ignore=self.ignore,
            fused_mapping=self.packed_modules_mapping,
        )
        if is_skipped:
            return UnquantizedEmbeddingMethod()

        if isinstance(layer, LinearBase):
            linear_scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if linear_scheme is None:
                # No scheme resolved for this layer: leave it unquantized.
                return UnquantizedEmbeddingMethod()
            # The resolved scheme is stored on the layer itself.
            layer.scheme = linear_scheme
            return CompressedTensorsLinearMethod(self)

        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)

        if isinstance(layer, FusedMoE):
            return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)

        return None
\ No newline at end of file
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
dd64aebf
...
...
@@ -90,8 +90,6 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MoEMethod
(
quant_config
)
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MoEMethod
(
quant_config
)
else
:
raise
RuntimeError
(
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
0 → 100644
View file @
dd64aebf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
from
enum
import
Enum
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.quantization
import
(
QuantizationStrategy
)
from
vllm.logger
import
init_logger
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
)
# Created before the optional import below so that an import failure can be
# reported through the logger instead of ``print``.
logger = init_logger(__name__)

# lmslim is an optional dependency that provides the int8 Marlin fused-MoE
# kernel. The import stays best-effort (any exception is tolerated) so this
# module can still be imported without it; in that case
# ``fused_experts_impl_int8_marlin`` is left undefined and calling
# ``CompressedTensorsW8A8Int8MarlinMoEMethod.apply`` will fail.
try:
    from lmslim.layers.fused_moe.fuse_moe_int8_marlin import (
        fused_experts_impl_int8_marlin)
except Exception as exc:
    logger.warning(
        "Please install lmslim if you want to infer the quantitative model "
        "of moe. (import failed: %s)", exc)

# Names exported by ``from ... import *``: the factory/base class and the
# concrete W8A8 int8 method.
__all__ = [
    "CompressedTensorsMarlinMoEMethod",
    "CompressedTensorsW8A8Int8MarlinMoEMethod",
]
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
    """Base class of the Marlin MoE methods; also acts as their factory."""

    @staticmethod
    def get_moe_method(
        quant_config: "SlimQuantCompressedTensorsMarlinConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
    ) -> "CompressedTensorsMarlinMoEMethod":
        """Build the MoE method matching the config's ``"Linear"`` scheme.

        Args:
            quant_config: Config whose ``target_scheme_map["Linear"]`` entry
                supplies the ``"weights"`` and ``"input_activations"``
                quantization arguments.
            layer: The MoE layer (not inspected by this function).

        Returns:
            A ``CompressedTensorsW8A8Int8MarlinMoEMethod`` when the scheme is
            dynamic per-token W8A8 according to
            ``quant_config._is_dynamic_token_w8a8``.

        Raises:
            RuntimeError: For any other scheme.
        """
        linear_scheme = quant_config.target_scheme_map["Linear"]
        weight_quant = linear_scheme.get("weights")
        input_quant = linear_scheme.get("input_activations")

        # Only dynamic per-token W8A8 is handled by this backend.
        if not quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
            raise RuntimeError(
                f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}"
            )
        return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
    """W8A8 int8 fused-MoE method backed by ``fused_experts_impl_int8_marlin``.

    Only channel-wise weight quantization combined with dynamic per-token
    activation quantization is accepted; anything else is rejected in
    ``__init__``.
    """

    def __init__(
        self,
        quant_config: "SlimQuantCompressedTensorsMarlinConfig"  # type: ignore # noqa E501
    ):
        """Validate the ``"Linear"`` scheme of ``quant_config``.

        Raises:
            ValueError: If the weights are not channel-quantized, the
                activations are not token-quantized, or the activation
                quantization is static.
        """
        self.quant_config = quant_config
        linear_scheme = self.quant_config.target_scheme_map["Linear"]
        self.weight_quant = linear_scheme.get("weights")
        self.input_quant = linear_scheme.get("input_activations")

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the expert weights and per-channel scales on ``layer``.

        Registers ``w13_weight``, ``w2_weight``, ``w13_weight_scale`` and
        ``w2_weight_scale`` as non-trainable parameters and sets
        ``w13_input_scale`` / ``w2_input_scale`` to ``None``.
        """
        # The caller-supplied dtype is ignored: weights are always int8.
        params_dtype = torch.int8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES: one float32 scale per output channel, initialised
        # to 1.
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = torch.nn.Parameter(torch.ones(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        hidden_size,
                                                        1,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        # This update happens after the weight attrs were set above, so only
        # the scale parameters receive the "quant_method" attribute.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES: none, because activation quantization is dynamic.
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    @staticmethod
    def _to_marlin_layout(stacked_weight: torch.Tensor) -> torch.Tensor:
        """Re-tile each slice along dim 0 and stack the results again."""
        return torch.stack(
            [
                get_w8a8_int8_marlin_weights(stacked_weight[expert_idx])
                for expert_idx in range(stacked_weight.shape[0])
            ],
            dim=0,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``w13_weight`` / ``w2_weight`` with their re-tiled form."""
        # Both tensors are converted before either parameter is reassigned.
        w13_marlin = self._to_marlin_layout(layer.w13_weight)
        w2_marlin = self._to_marlin_layout(layer.w2_weight)
        layer.w13_weight = Parameter(w13_marlin, requires_grad=False)
        layer.w2_weight = Parameter(w2_marlin, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
        shared_output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the int8 Marlin fused-experts kernel.

        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used here.

        Raises:
            NotImplementedError: If ``enable_eplb`` is true.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsW8A8Int8MarlinMoEMethod` yet.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts_impl_int8_marlin(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            # NOTE(review): the caller's ``use_nn_moe`` is not forwarded; the
            # kernel is always called with False -- confirm this is intended.
            use_nn_moe=False,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor)
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
dd64aebf
...
...
@@ -99,7 +99,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
QuantizationMethods
]:
if
hf_quant_cfg
.
get
(
"quant_method"
)
==
"slimquant_w4a8"
\
and
user_quant
==
"slimquant_w4a8_marlin"
:
and
user_quant
in
(
"slimquant_w4a8_marlin"
,
"slimquant_marlin"
)
:
return
cls
.
get_name
()
return
None
def
get_quant_method
(
...
...
@@ -233,7 +233,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet."
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
dd64aebf
...
...
@@ -25,6 +25,20 @@ USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
and
torch
.
__version__
[
0
:
3
]
>=
"2.7"
and
current_platform
.
has_device_capability
(
94
))
def get_w8a8_int8_marlin_weights(weight, k_tile=64):
    """Re-tile a 2-D weight along its second dimension in blocks of ``k_tile``.

    The input of shape ``(size_n, size_k)`` is transposed to
    ``(size_k, size_n)``, split into ``size_k // k_tile`` row blocks, and each
    block is flattened so that the ``k_tile`` values of one column are
    contiguous.

    Args:
        weight: 2-D tensor of shape ``(size_n, size_k)``.
            (The original code carried the example shape ``7168, 512``.)
        k_tile: Block size along ``size_k``; must be positive and divide
            ``size_k`` exactly.

    Returns:
        Tensor of shape ``(size_k // k_tile, size_n * k_tile)``.

    Raises:
        ValueError: If ``k_tile`` is not positive, or ``size_k`` is smaller
            than ``k_tile`` or not a multiple of it.
    """
    if k_tile <= 0:
        raise ValueError(f"k_tile must be positive, got {k_tile}")

    weight = weight.T
    size_k, size_n = weight.shape
    # The original ``assert size_k // k_tile`` only rejected size_k < k_tile;
    # a non-multiple would fail later inside ``reshape`` with an opaque error.
    if size_k < k_tile or size_k % k_tile != 0:
        raise ValueError(
            f"size_k ({size_k}) must be a non-zero multiple of k_tile "
            f"({k_tile})")

    num_tiles = size_k // k_tile
    weight = weight.reshape(num_tiles, k_tile, size_n)
    weight = weight.transpose(1, 2)
    return weight.reshape(num_tiles, size_n * k_tile)
def
sparse_cutlass_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
...
...
vllm/platforms/rocm.py
View file @
dd64aebf
...
...
@@ -191,7 +191,8 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
]
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment