Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c683d11c
Unverified
Commit
c683d11c
authored
Feb 19, 2026
by
Wentao Ye
Committed by
GitHub
Feb 19, 2026
Browse files
[Refactor] Deprecate `head_first` for `chunk_gated_delta_rule` (#34263)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
3eff45d7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
36 deletions
+8
-36
vllm/model_executor/layers/fla/ops/chunk.py
vllm/model_executor/layers/fla/ops/chunk.py
+8
-29
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+0
-1
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+0
-6
No files found.
vllm/model_executor/layers/fla/ops/chunk.py
View file @
c683d11c
...
...
@@ -10,7 +10,6 @@
import
warnings
import
torch
from
einops
import
rearrange
from
.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
from
.chunk_o
import
chunk_fwd_o
...
...
@@ -119,21 +118,20 @@ def chunk_gated_delta_rule(
initial_state
:
torch
.
Tensor
=
None
,
output_final_state
:
bool
=
False
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
head_first
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
):
r
"""
Args:
q (torch.Tensor):
q
ueries of shape `[B, T, H,
K]` if `head_first=False` else `[B, H, T,
K]`.
Q
ueries of shape `[B, T, H, K]`.
k (torch.Tensor):
k
eys of shape `[B, T, H,
K]` if `head_first=False` else `[B, H, T,
K]`.
K
eys of shape `[B, T, H, K]`.
v (torch.Tensor):
v
alues of shape `[B, T, H,
V]` if `head_first=False` else `[B, H, T,
V]`.
V
alues of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget)
g
ating tensor (in log space!) of shape `[B, T, H]`
if `head_first=False` else `[B, H, T]`
.
(forget)
G
ating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
b
etas of shape `[B, T, H]`
if `head_first=False` else `[B, H, T]`
.
B
etas of shape `[B, T, H]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
...
...
@@ -146,13 +144,9 @@ def chunk_gated_delta_rule(
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H,
V]` if `head_first=False` else `[B, H, T,
V]`.
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
...
...
@@ -189,24 +183,11 @@ def chunk_gated_delta_rule(
assert
q
.
dtype
!=
torch
.
float32
,
(
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
)
assert
len
(
beta
.
shape
)
==
3
,
(
"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if
head_first
:
raise
DeprecationWarning
(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
,
stacklevel
=
2
,
)
q
,
k
,
v
,
beta
,
g
=
map
(
lambda
x
:
rearrange
(
x
,
"b h t ... -> b t h ..."
),
(
q
,
k
,
v
,
beta
,
g
)
)
if
not
head_first
and
q
.
shape
[
1
]
<
q
.
shape
[
2
]:
assert
len
(
beta
.
shape
)
==
3
,
"beta must be of shape [B, T, H]."
if
q
.
shape
[
1
]
<
q
.
shape
[
2
]:
warnings
.
warn
(
f
"Input tensor shape suggests potential format mismatch: seq_len (
{
q
.
shape
[
1
]
}
) < num_heads (
{
q
.
shape
[
2
]
}
). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
,
stacklevel
=
2
,
)
...
...
@@ -235,6 +216,4 @@ def chunk_gated_delta_rule(
cu_seqlens
,
use_qk_l2norm_in_kernel
,
)
if
head_first
:
o
=
rearrange
(
o
,
"b t h ... -> b h t ..."
)
return
o
,
final_state
vllm/model_executor/models/llava_onevision.py
View file @
c683d11c
...
...
@@ -867,7 +867,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
mm_input_by_modality
:
return
[]
return
None
# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
...
...
vllm/model_executor/models/qwen3_next.py
View file @
c683d11c
...
...
@@ -115,7 +115,6 @@ def fi_chunk_gated_delta_rule(
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
head_first
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
True
,
):
from
flashinfer.gdn_prefill
import
(
...
...
@@ -172,7 +171,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
head_first
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
True
,
):
return
fi_chunk_gated_delta_rule
(
...
...
@@ -184,7 +182,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
head_first
=
head_first
,
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
,
)
...
...
@@ -198,7 +195,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
head_first
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
True
,
):
return
fla_chunk_gated_delta_rule
(
...
...
@@ -210,7 +206,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
head_first
=
head_first
,
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
,
)
...
...
@@ -790,7 +785,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
initial_state
=
initial_state
,
output_final_state
=
True
,
cu_seqlens
=
non_spec_query_start_loc
,
head_first
=
False
,
use_qk_l2norm_in_kernel
=
True
,
)
# Init cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment