Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4022a9d2
"docs/benchmarking/cli.md" did not exist on "7cd95dc8a327d1366fb5a7a9425ef995c3f71dbd"
Unverified
Commit
4022a9d2
authored
Nov 04, 2025
by
Varun Sundar Rabindranath
Committed by
GitHub
Nov 04, 2025
Browse files
[BugFix][Performance] Restore flashinfer autotuning for all scenarios (#27904)
parent
53f6e81d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
44 deletions
+14
-44
tests/quantization/test_blackwell_moe.py
tests/quantization/test_blackwell_moe.py
+2
-14
vllm/model_executor/layers/fused_moe/trtllm_moe.py
vllm/model_executor/layers/fused_moe/trtllm_moe.py
+9
-2
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+2
-2
vllm/model_executor/warmup/kernel_warmup.py
vllm/model_executor/warmup/kernel_warmup.py
+1
-26
No files found.
tests/quantization/test_blackwell_moe.py
View file @
4022a9d2
...
...
@@ -172,21 +172,9 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch
can_initialize
(
"openai/gpt-oss-20b"
,
hf_overrides
=
HF_OVERRIDE_TEXT
)
def
test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm
(
monkeypatch
:
pytest
.
MonkeyPatch
):
monkeypatch
.
setenv
(
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ALL2ALL_BACKEND"
,
"deepep_high_throughput"
)
can_initialize
(
"openai/gpt-oss-20b"
,
extra_args
=
[
"--data-parallel-size"
,
"2"
,
"--enable-expert-parallel"
],
hf_overrides
=
HF_OVERRIDE_TEXT
,
)
def
test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm
(
monkeypatch
:
pytest
.
MonkeyPatch
):
monkeypatch
.
setenv
(
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ALL2ALL_BACKEND"
,
"deepep_high_throughput"
)
def
test_gptoss_eager
(
monkeypatch
:
pytest
.
MonkeyPatch
):
can_initialize
(
"openai/gpt-oss-20b"
,
extra_args
=
[
"--data-parallel-size"
,
"2"
,
"--enable-expert-parallel"
],
hf_overrides
=
HF_OVERRIDE_TEXT
,
extra_args
=
[
"--enforce-eager"
],
)
vllm/model_executor/layers/fused_moe/trtllm_moe.py
View file @
4022a9d2
...
...
@@ -127,10 +127,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"routing_method_type"
:
1
,
"do_finalize"
:
True
,
"output"
:
output
,
"tune_max_num_tokens"
:
self
.
max_capture_size
,
"tune_max_num_tokens"
:
max
(
self
.
max_capture_size
,
1
),
}
from
flashinfer
import
trtllm_fp4_block_scale_routed_moe
from
vllm.utils.flashinfer
import
autotune
with
autotune
(
False
):
# Enable autotune when,
# https://github.com/flashinfer-ai/flashinfer/issues/2023 is
# resolved.
trtllm_fp4_block_scale_routed_moe
(
**
kwargs
)
return
output
vllm/model_executor/layers/quantization/mxfp4.py
View file @
4022a9d2
...
...
@@ -1047,7 +1047,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None
,
1
if
renormalize
else
0
,
# routing_method_type, renormalize
True
,
# do finalize
tune_max_num_tokens
=
self
.
max_capture_size
,
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
)[
0
]
return
trtllm_gen_output
elif
(
...
...
@@ -1122,7 +1122,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tp_rank
=
self
.
moe
.
tp_rank
,
ep_size
=
self
.
moe
.
ep_size
,
ep_rank
=
self
.
moe
.
ep_rank
,
tune_max_num_tokens
=
self
.
max_capture_size
,
tune_max_num_tokens
=
max
(
self
.
max_capture_size
,
1
),
**
extra_kwargs
,
)
...
...
vllm/model_executor/warmup/kernel_warmup.py
View file @
4022a9d2
...
...
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING
import
torch
import
vllm.envs
as
envs
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.warmup.deep_gemm_warmup
import
deep_gemm_warmup
from
vllm.platforms
import
current_platform
...
...
@@ -25,26 +24,6 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
def
flashinfer_autotune_supported
(
vllm_config
:
VllmConfig
)
->
bool
:
"""
Record known issues with vllm + flashinfer autotune here. Return True if
and only if flashinfer autotune will run through without issues.
"""
is_tp_or_dp
=
(
vllm_config
.
parallel_config
.
data_parallel_size
>
1
)
or
(
vllm_config
.
parallel_config
.
tensor_parallel_size
>
1
)
is_fi_mxfp4_backend
=
(
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
or
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
)
or
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
100
)
)
# on >=sm100, default mxfp4 backend is flashinfer
is_eager
=
vllm_config
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
return
not
(
is_tp_or_dp
and
is_fi_mxfp4_backend
and
is_eager
)
def
kernel_warmup
(
worker
:
"Worker"
):
# Deep GEMM warmup
do_deep_gemm_warmup
=
(
...
...
@@ -58,11 +37,7 @@ def kernel_warmup(worker: "Worker"):
deep_gemm_warmup
(
model
,
max_tokens
)
# FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
if
(
has_flashinfer
()
and
current_platform
.
has_device_capability
(
90
)
and
flashinfer_autotune_supported
(
worker
.
vllm_config
)
):
if
has_flashinfer
()
and
current_platform
.
has_device_capability
(
90
):
flashinfer_autotune
(
worker
.
model_runner
)
# FlashInfer attention warmup
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment