Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e64fe4a
Unverified
Commit
3e64fe4a
authored
Mar 12, 2026
by
Xu Jinyang
Committed by
GitHub
Mar 12, 2026
Browse files
[Bugfix] Warm up Triton autotuner for GDN layers during V1 profiling (#36599)
Signed-off-by:
AuYang
<
459461160@qq.com
>
parent
8cb24d3a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
98 additions
and
1 deletion
+98
-1
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+98
-1
No files found.
vllm/model_executor/models/qwen3_next.py
View file @
3e64fe4a
...
@@ -645,6 +645,101 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -645,6 +645,101 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
core_attn_out
=
rearrange
(
core_attn_out
,
"... h d -> ... (h d)"
)
core_attn_out
=
rearrange
(
core_attn_out
,
"... h d -> ... (h d)"
)
output
[:
num_tokens
],
_
=
self
.
out_proj
(
core_attn_out
)
output
[:
num_tokens
],
_
=
self
.
out_proj
(
core_attn_out
)
def
_warmup_prefill_kernels
(
self
,
mixed_qkv
:
torch
.
Tensor
)
->
None
:
"""Warm up GDN prefill kernels during V1 profiling.
During V1 profile runs, ``_forward_core`` returns early because
``attn_metadata`` is ``None``, so the autotuned kernels used by
``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
``chunk_scaled_dot_kkt``) are never invoked. After profiling,
vLLM allocates KV cache using most of the remaining GPU memory.
When the first real inference triggers the autotuner it OOMs
because there is not enough memory left for benchmarking.
This method runs minimal forward passes through
``chunk_gated_delta_rule`` with small dummy tensors to force
autotuning while GPU memory is still plentiful. The autotuner
results are cached globally, so only the first layer incurs
actual benchmarking cost.
Most kernels use a fixed ``BT = chunk_size`` (64), but
``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT``
is part of its autotune key, we run warmup passes with T = 16,
32, and 64 to cover all possible ``BT`` values.
The decode path uses ``fused_sigmoid_gating_delta_rule_update``
which has fixed kernel parameters (no autotuning), so only the
prefill (chunked) path needs warming up.
"""
if
hasattr
(
self
,
"_prefill_kernels_warmed_up"
):
return
self
.
_prefill_kernels_warmed_up
=
True
device
=
mixed_qkv
.
device
dtype
=
mixed_qkv
.
dtype
num_k_heads
=
self
.
num_k_heads
//
self
.
tp_size
num_v_heads
=
self
.
num_v_heads
//
self
.
tp_size
_
,
state_dtype
=
self
.
get_state_dtype
()
# Run warmup for each possible BT value of chunk_fwd_kernel_o:
# T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
# Other kernels always use BT=chunk_size(64), so their autotune
# cache is populated on the first pass and reused thereafter.
for
T
in
(
16
,
32
,
64
):
q
=
torch
.
randn
(
1
,
T
,
num_k_heads
,
self
.
head_k_dim
,
device
=
device
,
dtype
=
dtype
)
k
=
torch
.
randn
(
1
,
T
,
num_k_heads
,
self
.
head_k_dim
,
device
=
device
,
dtype
=
dtype
)
v
=
torch
.
randn
(
1
,
T
,
num_v_heads
,
self
.
head_v_dim
,
device
=
device
,
dtype
=
dtype
)
g
=
torch
.
randn
(
1
,
T
,
num_v_heads
,
device
=
device
,
dtype
=
dtype
)
beta
=
torch
.
randn
(
1
,
T
,
num_v_heads
,
device
=
device
,
dtype
=
dtype
)
state
=
torch
.
zeros
(
1
,
num_v_heads
,
self
.
head_v_dim
,
self
.
head_k_dim
,
device
=
device
,
dtype
=
state_dtype
,
)
cu_seqlens
=
torch
.
tensor
([
0
,
T
],
device
=
device
,
dtype
=
torch
.
long
)
try
:
self
.
chunk_gated_delta_rule
(
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
beta
=
beta
,
initial_state
=
state
,
output_final_state
=
False
,
cu_seqlens
=
cu_seqlens
,
use_qk_l2norm_in_kernel
=
True
,
)
except
Exception
:
logger
.
warning
(
"GDN prefill kernel warmup (T=%d) failed for "
"layer %s. First inference may OOM due to "
"autotuner."
,
T
,
self
.
prefix
,
exc_info
=
True
,
)
else
:
logger
.
debug
(
"GDN prefill kernel warmup (T=%d) completed for layer %s"
,
T
,
self
.
prefix
,
)
finally
:
del
q
,
k
,
v
,
g
,
beta
,
state
,
cu_seqlens
torch
.
accelerator
.
empty_cache
()
def
_forward_core
(
def
_forward_core
(
self
,
self
,
mixed_qkv
:
torch
.
Tensor
,
mixed_qkv
:
torch
.
Tensor
,
...
@@ -659,7 +754,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
...
@@ -659,7 +754,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
# V1 profile run
# V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation.
self
.
_warmup_prefill_kernels
(
mixed_qkv
)
return
return
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment