Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd2ccf8d
Unverified
Commit
dd2ccf8d
authored
Jun 24, 2025
by
Jun-Howie
Committed by
GitHub
Jun 24, 2025
Browse files
Feat Dynamic Quantization for MoE Layers in GPTQ Marlin Backend (#19395)
parent
a3bc76e4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
3 deletions
+29
-3
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+29
-3
No files found.
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
dd2ccf8d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
copy
import
deepcopy
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
torch
...
...
@@ -9,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
UnquantizedFusedMoEMethod
)
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
...
@@ -19,7 +21,7 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig
,
choose_mp_linear_kernel
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.gptq_utils
import
(
get_linear_quant_method
)
get_dynamic_override
,
get_linear_quant_method
,
override_config
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_marlin_supported
,
check_moe_marlin_supports_layer
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
,
...
...
@@ -35,6 +37,29 @@ from vllm.scalar_type import scalar_types
logger
=
init_logger
(
__name__
)
def get_moe_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    moe_method_cls: type,
):
    """Select the quantization method object for a fused-MoE layer.

    Args:
        config: Base quantization config. This function only hands a deep
            copy of it to the dynamic-override helpers, so per-layer
            overrides are applied to the copy rather than to ``config``.
        layer: Module being configured. Only ``FusedMoE`` instances are
            handled here.
        prefix: Layer name used to look up dynamic per-module rules.
        moe_method_cls: Class instantiated with the (possibly overridden)
            config copy when the layer stays quantized.

    Returns:
        ``None`` if ``layer`` is not a ``FusedMoE``;
        ``UnquantizedFusedMoEMethod(layer.moe_config)`` if the dynamic
        override lookup yields ``False`` for ``prefix``; otherwise
        ``moe_method_cls(cloned_config)``.
    """
    # Bail out before copying: the deep copy is only needed for MoE layers.
    if not isinstance(layer, FusedMoE):
        return None

    cloned_config = deepcopy(config)

    # False = skip module, None = no override, else = Positive match
    dynamic_override = get_dynamic_override(cloned_config, layer_name=prefix)
    if dynamic_override == False:  # noqa: E712
        return UnquantizedFusedMoEMethod(layer.moe_config)

    if prefix:
        # Dynamic per module/layer rules may override base config
        override_config(cloned_config, prefix=prefix)

    return moe_method_cls(cloned_config)
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
...
...
@@ -163,7 +188,8 @@ class GPTQMarlinConfig(QuantizationConfig):
"Falling back to Moe WNA16 kernels."
)
return
MoeWNA16Config
.
from_config
(
self
.
full_config
).
get_quant_method
(
layer
,
prefix
)
return
GPTQMarlinMoEMethod
(
self
)
return
get_moe_quant_method
(
self
,
layer
,
prefix
,
GPTQMarlinMoEMethod
)
return
get_linear_quant_method
(
self
,
layer
,
prefix
,
GPTQMarlinLinearMethod
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment