Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c1ffcb55
Unverified
Commit
c1ffcb55
authored
Oct 03, 2025
by
Wentao Ye
Committed by
GitHub
Oct 03, 2025
Browse files
[Refactor] Optimize FP8 MOE Backend Choice and Log (#26044)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
0879736a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
46 deletions
+71
-46
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+71
-46
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
c1ffcb55
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
import
torch
import
torch
...
@@ -68,6 +69,65 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
...
@@ -68,6 +69,65 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
Fp8MoeBackend
(
Enum
):
NONE
=
0
FLASHINFER_TRTLLM
=
1
FLASHINFER_CUTLASS
=
2
DEEPGEMM
=
3
CUTLASS_BLOCK_SCALED_GROUPED_GEMM
=
4
MARLIN
=
5
TRITON
=
6
def
get_fp8_moe_backend
(
block_quant
:
bool
)
->
Fp8MoeBackend
:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# prefer FlashInfer backends when available and enabled on supported GPUs
if
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
()):
backend
=
get_flashinfer_moe_backend
()
if
backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
logger
.
info_once
(
"Using FlashInfer FP8 MoE TRTLLM backend for SM100"
)
return
Fp8MoeBackend
.
FLASHINFER_TRTLLM
else
:
logger
.
info_once
(
"Using FlashInfer FP8 MoE CUTLASS backend for SM100"
)
return
Fp8MoeBackend
.
FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
if
current_platform
.
is_rocm
():
use_marlin
=
False
if
use_marlin
:
logger
.
info_once
(
"Using Marlin backend for FP8 MoE"
)
return
Fp8MoeBackend
.
MARLIN
# deepGEMM on supported platforms with block-quantized weights
if
envs
.
VLLM_USE_DEEP_GEMM
and
block_quant
:
if
not
has_deep_gemm
():
logger
.
warning_once
(
"DeepGEMM backend requested but not available."
)
elif
is_deep_gemm_supported
():
logger
.
info_once
(
"Using DeepGEMM backend for FP8 MoE"
)
return
Fp8MoeBackend
.
DEEPGEMM
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
block_quant
):
logger
.
info_once
(
"Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE"
)
return
Fp8MoeBackend
.
CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# default to Triton
logger
.
info_once
(
"Using Triton backend for FP8 MoE"
)
return
Fp8MoeBackend
.
TRITON
class
Fp8Config
(
QuantizationConfig
):
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
"""Config class for FP8."""
...
@@ -453,54 +513,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -453,54 +513,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
fused_experts
:
Optional
[
self
.
fused_experts
:
Optional
[
mk
.
FusedMoEModularKernel
]
=
None
# type: ignore
mk
.
FusedMoEModularKernel
]
=
None
# type: ignore
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
self
.
fp8_backend
=
get_fp8_moe_backend
(
self
.
block_quant
)
# kernel for fast weight-only FP8 quantization
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
# First check for Flashinfer MOE on Blackwell GPUs
self
.
use_marlin
=
(
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
)
self
.
flashinfer_moe_backend
:
Optional
[
FlashinferMoeBackend
]
=
None
self
.
flashinfer_moe_backend
:
Optional
[
FlashinferMoeBackend
]
=
None
if
(
current_platform
.
is_cuda
()
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
and
current_platform
.
is_device_capability
(
100
)
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
TENSORRT_LLM
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
()):
elif
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_CUTLASS
:
self
.
flashinfer_moe_backend
=
get_flashinfer_moe_backend
()
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
CUTLASS
logger
.
info_once
(
f
"Detected Blackwell GPUs, using FlashInfer "
self
.
allow_deep_gemm
=
(
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
)
f
"
{
self
.
flashinfer_moe_backend
.
value
}
kernels for FP8 MOE."
)
self
.
allow_cutlass_block_scaled_grouped_gemm
=
(
self
.
fp8_backend
==
Fp8MoeBackend
.
CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# Check for DeepGemm support.
)
self
.
allow_deep_gemm
=
False
if
envs
.
VLLM_USE_DEEP_GEMM
:
if
not
has_deep_gemm
():
logger
.
warning_once
(
"Failed to import DeepGemm kernels."
)
elif
not
self
.
block_quant
:
logger
.
warning_once
(
"Model is not block quantized. Not using"
" DeepGemm kernels"
)
elif
self
.
flashinfer_moe_backend
:
logger
.
info_once
(
"DeepGemm disabled: FlashInfer MOE is"
" enabled."
)
elif
(
is_deep_gemm_supported
()):
logger
.
debug_once
(
"DeepGemm kernels available for Fp8MoEMethod."
)
self
.
allow_deep_gemm
=
True
else
:
logger
.
warning_once
(
"DeepGemm not supported on the current platform."
)
# Check for CutlassBlockScaledGroupedGemm support.
self
.
allow_cutlass_block_scaled_grouped_gemm
=
False
if
not
self
.
block_quant
:
logger
.
debug_once
(
"Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels"
)
elif
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
not
self
.
flashinfer_moe_backend
):
logger
.
debug_once
(
"CutlassBlockScaledGroupedGemm available for Fp8MoEMethod."
)
self
.
allow_cutlass_block_scaled_grouped_gemm
=
True
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
intermediate_size_per_partition
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment