Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4a8412f7
Unverified
Commit
4a8412f7
authored
Dec 17, 2025
by
Matthew Bonanni
Committed by
GitHub
Dec 17, 2025
Browse files
[UX] Reduce DeepGEMM warmup log output to single progress bar (#30903)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
0c738b58
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
99 additions
and
42 deletions
+99
-42
vllm/model_executor/warmup/deep_gemm_warmup.py
vllm/model_executor/warmup/deep_gemm_warmup.py
+99
-42
No files found.
vllm/model_executor/warmup/deep_gemm_warmup.py
View file @
4a8412f7
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
,
is_global_first_rank
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
DeepGemmExperts
from
vllm.model_executor.layers.fused_moe.deep_gemm_utils
import
compute_aligned_M
from
vllm.model_executor.layers.fused_moe.deep_gemm_utils
import
compute_aligned_M
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
,
FusedMoEModularMethod
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
,
FusedMoEModularMethod
...
@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
...
@@ -175,7 +175,30 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
FP8_GEMM_NT_WARMUP_CACHE
:
set
[
torch
.
Size
]
=
set
()
FP8_GEMM_NT_WARMUP_CACHE
:
set
[
torch
.
Size
]
=
set
()
def
_deepgemm_fp8_gemm_nt_warmup
(
w
:
torch
.
Tensor
,
ws
:
torch
.
Tensor
,
max_tokens
:
int
):
def _get_fp8_gemm_nt_m_values(w: torch.Tensor, max_tokens: int) -> list[int]:
    """Return the list of M (token-count) values to warm up for weight ``w``.

    The selection is driven by ``envs.VLLM_DEEP_GEMM_WARMUP``:

    * ``"relax"`` -- delegate to ``_generate_optimal_warmup_m_values`` with
      ``max_tokens``, the first dimension of ``w`` and the device of ``w``.
    * ``"full"``  -- every value from 1 to ``max_tokens`` inclusive, so that
      no token size is left to be JIT-compiled in the hot path.

    Any other setting trips an assertion.
    """
    # Two-way unpacking: ``w`` must be 2-D; only the first dimension is used.
    n, _ = w.size()
    warmup_mode = envs.VLLM_DEEP_GEMM_WARMUP

    if warmup_mode != "relax":
        assert warmup_mode == "full", (
            "Expected "
            'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
            f"{warmup_mode}"
        )
        return list(range(1, max_tokens + 1))

    return _generate_optimal_warmup_m_values(max_tokens, n, w.device)
def
_deepgemm_fp8_gemm_nt_warmup
(
w
:
torch
.
Tensor
,
ws
:
torch
.
Tensor
,
max_tokens
:
int
,
pbar
:
tqdm
|
None
=
None
,
):
if
w
.
size
()
in
FP8_GEMM_NT_WARMUP_CACHE
:
if
w
.
size
()
in
FP8_GEMM_NT_WARMUP_CACHE
:
return
return
...
@@ -189,26 +212,13 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
...
@@ -189,26 +212,13 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
)
)
out
=
torch
.
empty
((
max_tokens
,
n
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
empty
((
max_tokens
,
n
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
# Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax".
m_values
=
_get_fp8_gemm_nt_m_values
(
w
,
max_tokens
)
# Otherwise warmup all token sizes to avoid JIT compilation in hotpath
if
envs
.
VLLM_DEEP_GEMM_WARMUP
==
"relax"
:
m_values
=
_generate_optimal_warmup_m_values
(
max_tokens
,
n
,
device
)
desc
=
f
"DeepGemm(fp8_gemm_nt) warmup (W=
{
w
.
size
()
}
) [relaxed]"
else
:
assert
envs
.
VLLM_DEEP_GEMM_WARMUP
==
"full"
,
(
"Expected "
'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got '
f
"
{
envs
.
VLLM_DEEP_GEMM_WARMUP
}
"
)
m_values
=
list
(
range
(
1
,
max_tokens
+
1
))
desc
=
f
"DeepGemm(fp8_gemm_nt) warmup (W=
{
w
.
size
()
}
) [all tokens]"
pbar
=
tqdm
(
total
=
len
(
m_values
),
desc
=
desc
)
for
num_tokens
in
m_values
:
for
num_tokens
in
m_values
:
fp8_gemm_nt
(
fp8_gemm_nt
(
(
a1q
[:
num_tokens
],
a1q_scales
[:
num_tokens
]),
(
w
,
ws
),
out
[:
num_tokens
]
(
a1q
[:
num_tokens
],
a1q_scales
[:
num_tokens
]),
(
w
,
ws
),
out
[:
num_tokens
]
)
)
if
pbar
is
not
None
:
pbar
.
update
(
1
)
pbar
.
update
(
1
)
FP8_GEMM_NT_WARMUP_CACHE
.
add
(
w
.
size
())
FP8_GEMM_NT_WARMUP_CACHE
.
add
(
w
.
size
())
...
@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
...
@@ -217,20 +227,12 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
:
set
[
torch
.
Size
]
=
set
()
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
:
set
[
torch
.
Size
]
=
set
()
def
_
deepgemm
_grouped_
fp8_
gemm_
nt_contiguous_warmup
(
def
_
get
_grouped_gemm_
params
(
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
num_topk
:
int
,
num_topk
:
int
,
max_tokens
:
int
,
max_tokens
:
int
,
):
)
->
tuple
[
int
,
int
,
torch
.
Tensor
]:
if
(
w1
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and
w2
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
assert
w1
.
size
(
0
)
==
w2
.
size
(
0
),
"w1 and w2 must have the same number of experts"
assert
w1
.
size
(
0
)
==
w2
.
size
(
0
),
"w1 and w2 must have the same number of experts"
block_m
=
get_mk_alignment_for_contiguous_layout
()[
0
]
block_m
=
get_mk_alignment_for_contiguous_layout
()[
0
]
...
@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
...
@@ -253,6 +255,27 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
)
)
expert_ids
=
torch
.
repeat_interleave
(
expert_ids_block
,
block_m
,
dim
=
0
)
expert_ids
=
torch
.
repeat_interleave
(
expert_ids_block
,
block_m
,
dim
=
0
)
return
MAX_M
,
block_m
,
expert_ids
def
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
num_topk
:
int
,
max_tokens
:
int
,
pbar
:
tqdm
|
None
=
None
,
):
if
(
w1
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
and
w2
.
size
()
in
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
):
return
MAX_M
,
block_m
,
expert_ids
=
_get_grouped_gemm_params
(
w1
,
w2
,
num_topk
,
max_tokens
)
device
=
w1
.
device
def
_warmup
(
w
:
torch
.
Tensor
,
w_scale
:
torch
.
Tensor
):
def
_warmup
(
w
:
torch
.
Tensor
,
w_scale
:
torch
.
Tensor
):
_
,
n
,
k
=
w
.
size
()
_
,
n
,
k
=
w
.
size
()
a1q
=
torch
.
empty
((
MAX_M
,
k
),
device
=
device
,
dtype
=
torch
.
float8_e4m3fn
)
a1q
=
torch
.
empty
((
MAX_M
,
k
),
device
=
device
,
dtype
=
torch
.
float8_e4m3fn
)
...
@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
...
@@ -261,15 +284,8 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
)
)
out
=
torch
.
empty
((
MAX_M
,
n
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
out
=
torch
.
empty
((
MAX_M
,
n
),
device
=
device
,
dtype
=
torch
.
bfloat16
)
# Generate M values in block_m increments (already optimized for MoE)
m_values
=
list
(
range
(
block_m
,
MAX_M
+
1
,
block_m
))
m_values
=
list
(
range
(
block_m
,
MAX_M
+
1
,
block_m
))
pbar
=
tqdm
(
total
=
len
(
m_values
),
desc
=
f
"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W=
{
w
.
size
()
}
) "
f
"[
{
len
(
m_values
)
}
values, block_m=
{
block_m
}
]"
,
)
for
num_tokens
in
m_values
:
for
num_tokens
in
m_values
:
m_grouped_fp8_gemm_nt_contiguous
(
m_grouped_fp8_gemm_nt_contiguous
(
(
a1q
[:
num_tokens
],
a1q_scales
[:
num_tokens
]),
(
a1q
[:
num_tokens
],
a1q_scales
[:
num_tokens
]),
...
@@ -277,6 +293,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
...
@@ -277,6 +293,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
out
[:
num_tokens
],
out
[:
num_tokens
],
expert_ids
[:
num_tokens
],
expert_ids
[:
num_tokens
],
)
)
if
pbar
is
not
None
:
pbar
.
update
(
1
)
pbar
.
update
(
1
)
for
w
,
ws
in
[(
w1
,
w1_scale
),
(
w2
,
w2_scale
)]:
for
w
,
ws
in
[(
w1
,
w1_scale
),
(
w2
,
w2_scale
)]:
...
@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
...
@@ -285,16 +302,18 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
.
add
(
w
.
size
())
GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
.
add
(
w
.
size
())
def
deepgemm_fp8_gemm_nt_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
):
def
deepgemm_fp8_gemm_nt_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
,
pbar
:
tqdm
|
None
=
None
):
dg_modules
=
[
m
for
m
in
model
.
modules
()
if
_fp8_linear_may_use_deep_gemm
(
m
)]
dg_modules
=
[
m
for
m
in
model
.
modules
()
if
_fp8_linear_may_use_deep_gemm
(
m
)]
for
dgm
in
dg_modules
:
for
dgm
in
dg_modules
:
w
,
ws
,
_
=
_extract_data_from_linear_base_module
(
dgm
)
w
,
ws
,
_
=
_extract_data_from_linear_base_module
(
dgm
)
_deepgemm_fp8_gemm_nt_warmup
(
w
=
w
,
ws
=
ws
,
max_tokens
=
max_tokens
)
_deepgemm_fp8_gemm_nt_warmup
(
w
=
w
,
ws
=
ws
,
max_tokens
=
max_tokens
,
pbar
=
pbar
)
def
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
def
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
,
pbar
:
tqdm
|
None
=
None
):
):
dg_modules
=
[
dg_modules
=
[
m
for
m
in
model
.
modules
()
if
_fused_moe_grouped_gemm_may_use_deep_gemm
(
m
)
m
for
m
in
model
.
modules
()
if
_fused_moe_grouped_gemm_may_use_deep_gemm
(
m
)
...
@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
...
@@ -305,10 +324,48 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
dgm
dgm
)
)
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
_deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
w13
,
w2
,
w13_scale
,
w2_scale
,
num_topk
,
max_tokens
w13
,
w2
,
w13_scale
,
w2_scale
,
num_topk
,
max_tokens
,
pbar
=
pbar
)
)
def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
    """Count the warmup iterations still outstanding for ``model``.

    Walks ``model.modules()`` and sums, for every weight shape that is not
    already present in the module-level warmup caches, the number of M values
    that would be warmed up for it. The caches themselves are only copied,
    never modified, so calling this function has no side effects on them.

    Args:
        model: Model whose sub-modules are inspected.
        max_tokens: Upper bound on the token count, forwarded to the helpers
            that determine the M values.

    Returns:
        Total number of warmup iterations across FP8 linear and grouped
        (fused MoE) GEMM modules.
    """
    # Work on copies so shapes counted here are de-duplicated without
    # touching the real caches.
    fp8_shapes: set[torch.Size] = set(FP8_GEMM_NT_WARMUP_CACHE)
    grouped_shapes: set[torch.Size] = set(GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE)

    iterations = 0
    for module in model.modules():
        if _fp8_linear_may_use_deep_gemm(module):
            weight, _, _ = _extract_data_from_linear_base_module(module)
            shape = weight.size()
            if shape not in fp8_shapes:
                iterations += len(_get_fp8_gemm_nt_m_values(weight, max_tokens))
                fp8_shapes.add(shape)
            continue

        if not _fused_moe_grouped_gemm_may_use_deep_gemm(module):
            continue

        w13, _, w2, _, num_topk = _extract_data_from_fused_moe_module(module)
        moe_shapes = (w13.size(), w2.size())
        if all(shape in grouped_shapes for shape in moe_shapes):
            continue

        max_m, block_m, _ = _get_grouped_gemm_params(w13, w2, num_topk, max_tokens)
        # Number of multiples of block_m in [block_m, max_m].
        steps = (max_m - block_m) // block_m + 1

        # Checked in order (w13 first), so identical w13/w2 shapes are
        # counted only once.
        for shape in moe_shapes:
            if shape not in grouped_shapes:
                iterations += steps
                grouped_shapes.add(shape)

    return iterations
def
deep_gemm_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
):
def
deep_gemm_warmup
(
model
:
torch
.
nn
.
Module
,
max_tokens
:
int
):
deepgemm_fp8_gemm_nt_warmup
(
model
,
max_tokens
)
total
=
_count_warmup_iterations
(
model
,
max_tokens
)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
,
max_tokens
)
if
total
==
0
:
return
# Only show progress bar on rank 0 to avoid cluttered output
if
is_global_first_rank
():
with
tqdm
(
total
=
total
,
desc
=
"DeepGEMM warmup"
)
as
pbar
:
deepgemm_fp8_gemm_nt_warmup
(
model
,
max_tokens
,
pbar
)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
,
max_tokens
,
pbar
)
else
:
deepgemm_fp8_gemm_nt_warmup
(
model
,
max_tokens
,
None
)
deepgemm_grouped_fp8_gemm_nt_contiguous_warmup
(
model
,
max_tokens
,
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment