Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e6bc468
Unverified
Commit
2e6bc468
authored
Sep 11, 2025
by
Lucas Wilkinson
Committed by
GitHub
Sep 11, 2025
Browse files
[Startup] Make DeepGEMM warmup scale with max-num-batched-tokens (#24693)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
fcba05c4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
9 deletions
+13
-9
vllm/model_executor/warmup/deep_gemm_warmup.py
vllm/model_executor/warmup/deep_gemm_warmup.py
+13
-9
No files found.
vllm/model_executor/warmup/deep_gemm_warmup.py
View file @
2e6bc468
...
@@ -10,6 +10,7 @@ import torch
...
@@ -10,6 +10,7 @@ import torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.deep_gemm_utils
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_utils
import
(
compute_aligned_M
,
deep_gemm_block_shape
)
compute_aligned_M
,
deep_gemm_block_shape
)
...
@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
...
@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
:
set
[
torch
.
Size
]
=
set
()
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
:
set
[
torch
.
Size
]
=
set
()
def
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
w1
:
torch
.
Tensor
,
def
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
w2
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
num_topk
:
int
,
max_tokens
:
int
):
w2_scale
:
torch
.
Tensor
,
num_topk
:
int
):
if
(
w1
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
if
(
w1
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and
w2
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
and
w2
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
return
...
@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
...
@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
num_experts
=
w1
.
size
(
0
)
num_experts
=
w1
.
size
(
0
)
device
=
w1
.
device
device
=
w1
.
device
# Assumes all ranks have the same max_num_batched_tokens
max_tokens_across_dp
=
get_dp_group
().
world_size
*
max_tokens
max_tokens
=
min
(
max_tokens_across_dp
,
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
)
# This is the maximum GroupedGemm M size that we expect to run
# This is the maximum GroupedGemm M size that we expect to run
# the grouped_gemm with.
# the grouped_gemm with.
MAX_M
=
compute_aligned_M
(
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
,
MAX_M
=
compute_aligned_M
(
max_tokens
,
num_topk
,
num_topk
,
num_experts
,
num_experts
,
block_m
,
block_m
,
...
@@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
...
@@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
_deepgemm_fp8_gemm_nt_warmup
(
w
=
w
,
ws
=
ws
,
max_tokens
=
max_tokens
)
_deepgemm_fp8_gemm_nt_warmup
(
w
=
w
,
ws
=
ws
,
max_tokens
=
max_tokens
)
def
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
:
torch
.
nn
.
Module
):
def
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
):
dg_modules
=
[
dg_modules
=
[
m
for
m
in
model
.
modules
()
m
for
m
in
model
.
modules
()
if
_fused_moe_grouped_gemm_may_use_deep_gemm
(
m
)
if
_fused_moe_grouped_gemm_may_use_deep_gemm
(
m
)
...
@@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
...
@@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
w13
,
w13_scale
,
w2
,
w2_scale
,
num_topk
=
(
w13
,
w13_scale
,
w2
,
w2_scale
,
num_topk
=
(
_extract_data_from_fused_moe_module
(
dgm
))
_extract_data_from_fused_moe_module
(
dgm
))
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
w13
,
w2
,
w13_scale
,
w2_scale
,
num_topk
)
w13
,
w2
,
w13_scale
,
w2_scale
,
num_topk
,
max_tokens
)
def
deep_gemm_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
):
def
deep_gemm_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
):
deepgemm_fp8_gemm_nt_warmup
(
model
,
max_tokens
)
deepgemm_fp8_gemm_nt_warmup
(
model
,
max_tokens
)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
,
max_tokens
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment