Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef283548
Unverified
Commit
ef283548
authored
Sep 30, 2025
by
Pavani Majety
Committed by
GitHub
Sep 30, 2025
Browse files
[Bugfix] Fix accuracy issue of TRTLLM FP8 MOE and improve logging (#25895)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
f4db5e6d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
17 deletions
+29
-17
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+23
-16
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+6
-1
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
ef283548
...
...
@@ -434,14 +434,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
flashinfer_moe_backend
:
Optional
[
FlashinferMoeBackend
]
=
None
self
.
fused_experts
:
Optional
[
mk
.
FusedMoEModularKernel
]
=
None
# type: ignore
if
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
():
self
.
flashinfer_moe_backend
=
get_flashinfer_moe_backend
()
logger
.
info_once
(
f
"Using FlashInfer
{
self
.
flashinfer_moe_backend
.
value
}
kernels"
)
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
...
...
@@ -450,14 +445,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
# First check for Flashinfer MOE on Blackwell GPUs
self
.
flashinfer_moe_backend
:
Optional
[
FlashinferMoeBackend
]
=
None
if
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
and
has_flashinfer_moe
()):
self
.
flashinfer_moe_backend
=
get_flashinfer_moe_backend
()
logger
.
info_once
(
f
"Detected Blackwell GPUs, using FlashInfer "
f
"
{
self
.
flashinfer_moe_backend
.
value
}
kernels for FP8 MOE."
)
# Check for DeepGemm support.
self
.
allow_deep_gemm
=
False
if
envs
.
VLLM_USE_DEEP_GEMM
:
if
not
has_deep_gemm
():
logger
.
warning_once
(
"Failed to import DeepGemm kernels."
)
elif
not
self
.
block_quant
:
logger
.
warning_once
(
"Model is not block quantized. Not using "
"DeepGemm kernels"
)
logger
.
warning_once
(
"Model is not block quantized. Not using"
" DeepGemm kernels"
)
elif
self
.
flashinfer_moe_backend
:
logger
.
info_once
(
"DeepGemm disabled: FlashInfer MOE is"
" enabled."
)
elif
(
is_deep_gemm_supported
()):
logger
.
info_once
(
"Using DeepGemm kernels for Fp8MoEMethod."
)
self
.
allow_deep_gemm
=
True
...
...
@@ -471,15 +479,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger
.
debug_once
(
"Model is not block quantized. Not using "
"CutlassBlockScaledGroupedGemm kernels"
)
elif
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)):
and
current_platform
.
is_device_capability
(
100
)
and
not
self
.
flashinfer_moe_backend
):
logger
.
info_once
(
"Using CutlassBlockScaledGroupedGemm kernels for Fp8
MoEMethod.
"
)
"Using CutlassBlockScaledGroupedGemm kernels for Fp8
MOE
"
"on SM100."
)
self
.
allow_cutlass_block_scaled_grouped_gemm
=
True
else
:
logger
.
warning_once
(
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform."
)
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
...
...
@@ -934,7 +939,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
assert
(
renormalize
and
use_grouped_topk
and
custom_routing_function
is
None
)
result
=
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
e_score_correction_bias
=
(
e_score_correction_bias
.
to
(
x
.
dtype
)
if
e_score_correction_bias
is
not
None
else
None
)
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
=
router_logits
.
to
(
torch
.
float32
),
routing_bias
=
e_score_correction_bias
,
x
=
x
,
...
...
vllm/utils/deep_gemm.py
View file @
ef283548
...
...
@@ -27,7 +27,8 @@ def is_deep_gemm_supported() -> bool:
is_supported_arch
=
current_platform
.
is_cuda
()
and
(
current_platform
.
is_device_capability
(
90
)
or
current_platform
.
is_device_capability
(
100
))
return
envs
.
VLLM_USE_DEEP_GEMM
and
has_deep_gemm
()
and
is_supported_arch
return
(
envs
.
VLLM_USE_DEEP_GEMM
and
has_deep_gemm
()
and
is_supported_arch
and
not
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
)
@
functools
.
cache
...
...
@@ -46,6 +47,10 @@ def is_deep_gemm_e8m0_used() -> bool:
logger
.
info_once
(
"DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found"
)
return
False
if
envs
.
VLLM_USE_FLASHINFER_MOE_FP8
:
logger
.
info_once
(
"DeepGEMM E8M0 disabled: FlashInfer MOE is enabled."
)
return
False
if
current_platform
.
is_device_capability
(
100
)
and
\
envs
.
VLLM_USE_DEEP_GEMM_E8M0
:
logger
.
info_once
(
"DeepGEMM E8M0 enabled on Blackwell GPU."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment