Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6a6fc41c
Unverified
Commit
6a6fc41c
authored
Dec 12, 2025
by
Bhanu Prakash Voutharoja
Committed by
GitHub
Dec 12, 2025
Browse files
gptq marlin quantization support for fused moe with lora (#30254)
Signed-off-by:
Bhanu068
<
voutharoja.bhanu06@gmail.com
>
parent
f355ad54
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
146 additions
and
2 deletions
+146
-2
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+1
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+36
-0
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+109
-1
No files found.
csrc/moe/marlin_moe_wna16/ops.cu
View file @
6a6fc41c
vllm/model_executor/layers/fused_moe/config.py
View file @
6a6fc41c
...
...
@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
)
def
gptq_marlin_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
weight_bits
:
int
,
group_size
:
int
,
w1_zp
:
torch
.
Tensor
|
None
=
None
,
w2_zp
:
torch
.
Tensor
|
None
=
None
,
w1_bias
:
torch
.
Tensor
|
None
=
None
,
w2_bias
:
torch
.
Tensor
|
None
=
None
,
):
"""
Construct a quant config for gptq marlin quantization.
"""
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
w_shape
=
None
if
group_size
==
-
1
else
GroupShape
(
row
=
1
,
col
=
group_size
)
# Activations are NOT quantized for GPTQ (fp16/bf16)
a_shape
=
w_shape
# Same as weight shape for alignment
# Determine weight dtype
if
weight_bits
==
4
:
weight_dtype
=
"int4"
elif
weight_bits
==
8
:
weight_dtype
=
torch
.
int8
else
:
raise
ValueError
(
f
"Unsupported weight_bits:
{
weight_bits
}
"
)
return
FusedMoEQuantConfig
(
_a1
=
FusedMoEQuantDesc
(
dtype
=
None
,
shape
=
a_shape
),
_a2
=
FusedMoEQuantDesc
(
dtype
=
None
,
shape
=
a_shape
),
_w1
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w1_scale
,
None
,
w1_zp
,
w1_bias
),
_w2
=
FusedMoEQuantDesc
(
weight_dtype
,
w_shape
,
w2_scale
,
None
,
w2_zp
,
w2_bias
),
)
def
mxfp4_w4a16_moe_quant_config
(
w1_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
w2_scale
:
Union
[
torch
.
Tensor
,
"PrecisionConfig"
],
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
6a6fc41c
...
...
@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit
=
is_a_8bit
,
)
replace_parameter
(
layer
,
"w2_qweight"
,
marlin_w2_qweight
)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer
.
w13_weight
=
layer
.
w13_qweight
# Alias for modular kernel
layer
.
w2_weight
=
layer
.
w2_qweight
# Repack scales
marlin_w13_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w13_scales
,
...
...
@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
return
None
from
vllm.model_executor.layers.fused_moe.config
import
(
gptq_marlin_moe_quant_config
,
)
return
gptq_marlin_moe_quant_config
(
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
weight_bits
=
self
.
quant_config
.
weight_bits
,
group_size
=
self
.
quant_config
.
group_size
,
w1_zp
=
getattr
(
layer
,
"w13_qzeros"
,
None
)
if
not
self
.
quant_config
.
is_sym
else
None
,
w2_zp
=
getattr
(
layer
,
"w2_qzeros"
,
None
)
if
not
self
.
quant_config
.
is_sym
else
None
,
w1_bias
=
getattr
(
layer
,
"w13_bias"
,
None
),
w2_bias
=
getattr
(
layer
,
"w2_bias"
,
None
),
)
def
select_gemm_impl
(
self
,
prepare_finalize
,
layer
:
torch
.
nn
.
Module
,
):
"""
Select the GEMM implementation for GPTQ-Marlin MoE.
Returns MarlinExperts configured for GPTQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, GPTQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
if
not
self
.
moe
.
is_lora_enabled
:
raise
NotImplementedError
(
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
# The modular marlin kernels do not support 8-bit weights.
if
self
.
quant_config
.
weight_bits
==
8
:
raise
NotImplementedError
(
"GPTQ-Marlin kernel does not support 8-bit weights."
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
BatchedMarlinExperts
,
MarlinExperts
,
)
# Ensure quant config is initialized
assert
self
.
moe_quant_config
is
not
None
,
(
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx
=
(
getattr
(
layer
,
"w13_g_idx"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
w2_g_idx
=
(
getattr
(
layer
,
"w2_g_idx"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
w13_g_idx_sort_indices
=
(
getattr
(
layer
,
"w13_g_idx_sort_indices"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
w2_g_idx_sort_indices
=
(
getattr
(
layer
,
"w2_g_idx_sort_indices"
,
None
)
if
self
.
quant_config
.
desc_act
else
None
)
# Check if using batched expert format (for Expert Parallelism)
if
(
prepare_finalize
.
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
return
BatchedMarlinExperts
(
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
)
else
:
# Standard Marlin experts for GPTQ
return
MarlinExperts
(
quant_config
=
self
.
moe_quant_config
,
w13_g_idx
=
w13_g_idx
,
w2_g_idx
=
w2_g_idx
,
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
,
w2_g_idx_sort_indices
=
w2_g_idx_sort_indices
,
is_k_full
=
self
.
is_k_full
,
)
def
apply
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment