Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6f1355a1
Unverified
Commit
6f1355a1
authored
Nov 24, 2025
by
Michael Goin
Committed by
GitHub
Nov 24, 2025
Browse files
[Perf] Disable DeepGEMM MoE by default when TP=8 is used (#29346)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent
a4ad43ad
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing 1 changed file with 20 additions and 4 deletions
+20
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+20
-4
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
6f1355a1
...
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoeWeightScaleSupported
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
fp8_w8a8_moe_quant_config
,
...
...
@@ -118,7 +119,9 @@ class Fp8MoeBackend(Enum):
TRITON
=
6
def
get_fp8_moe_backend
(
block_quant
:
bool
)
->
Fp8MoeBackend
:
def
get_fp8_moe_backend
(
block_quant
:
bool
,
moe_parallel_config
:
FusedMoEParallelConfig
)
->
Fp8MoeBackend
:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
...
...
@@ -159,8 +162,19 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
logger
.
info_once
(
"Using Marlin backend for FP8 MoE"
)
return
Fp8MoeBackend
.
MARLIN
# deepGEMM on supported platforms with block-quantized weights
if
envs
.
VLLM_USE_DEEP_GEMM
and
envs
.
VLLM_MOE_USE_DEEP_GEMM
and
block_quant
:
# Determine if we should use DeepGEMM with block-quantized weights:
# - If explicitly set by user, respect their choice
# - If not explicitly set (default), disable when TP size is >= 8
moe_use_deep_gemm
=
envs
.
VLLM_MOE_USE_DEEP_GEMM
if
not
envs
.
is_set
(
"VLLM_MOE_USE_DEEP_GEMM"
)
and
moe_parallel_config
.
tp_size
>=
8
:
moe_use_deep_gemm
=
False
logger
.
info_once
(
"DeepGEMM MoE is disabled by default when TP size is >= 8. "
"Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it."
,
scope
=
"local"
,
)
if
envs
.
VLLM_USE_DEEP_GEMM
and
moe_use_deep_gemm
and
block_quant
:
if
not
has_deep_gemm
():
logger
.
warning_once
(
"DeepGEMM backend requested but not available."
,
scope
=
"local"
...
...
@@ -641,7 +655,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
quant_config
=
quant_config
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
:
bool
=
self
.
weight_block_size
is
not
None
self
.
fp8_backend
=
get_fp8_moe_backend
(
self
.
block_quant
)
self
.
fp8_backend
=
get_fp8_moe_backend
(
self
.
block_quant
,
layer
.
moe_parallel_config
)
self
.
use_marlin
=
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
self
.
flashinfer_moe_backend
:
FlashinferMoeBackend
|
None
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment