"vllm/vscode:/vscode.git/clone" did not exist on "0d8ce320a2e5eaf9fc025c6744a8a89419f59c38"
Unverified Commit c683d11c authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Refactor] Deprecate `head_first` for `chunk_gated_delta_rule` (#34263)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 3eff45d7
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
import warnings import warnings
import torch import torch
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o from .chunk_o import chunk_fwd_o
...@@ -119,21 +118,20 @@ def chunk_gated_delta_rule( ...@@ -119,21 +118,20 @@ def chunk_gated_delta_rule(
initial_state: torch.Tensor = None, initial_state: torch.Tensor = None,
output_final_state: bool = False, output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None, cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False, use_qk_l2norm_in_kernel: bool = False,
): ):
r""" r"""
Args: Args:
q (torch.Tensor): q (torch.Tensor):
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. Queries of shape `[B, T, H, K]`.
k (torch.Tensor): k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. Keys of shape `[B, T, H, K]`.
v (torch.Tensor): v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. Values of shape `[B, T, H, V]`.
g (torch.Tensor): g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. (forget) Gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor): beta (torch.Tensor):
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. Betas of shape `[B, T, H]`.
scale (Optional[int]): scale (Optional[int]):
Scale factor for the RetNet attention scores. Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
...@@ -146,13 +144,9 @@ def chunk_gated_delta_rule( ...@@ -146,13 +144,9 @@ def chunk_gated_delta_rule(
cu_seqlens (torch.LongTensor): cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training, Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API. consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
Default: `False`.
Returns: Returns:
o (torch.Tensor): o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor): final_state (torch.Tensor):
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`. Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
...@@ -189,24 +183,11 @@ def chunk_gated_delta_rule( ...@@ -189,24 +183,11 @@ def chunk_gated_delta_rule(
assert q.dtype != torch.float32, ( assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
) )
assert len(beta.shape) == 3, ( assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." if q.shape[1] < q.shape[2]:
)
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn( warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] " "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].", "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2, stacklevel=2,
) )
...@@ -235,6 +216,4 @@ def chunk_gated_delta_rule( ...@@ -235,6 +216,4 @@ def chunk_gated_delta_rule(
cu_seqlens, cu_seqlens,
use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel,
) )
if head_first:
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state return o, final_state
...@@ -867,7 +867,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -867,7 +867,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return [] return []
return None
# The result multimodal_embeddings is tuple of tensors, with each # The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video). # tensor corresponding to a multimodal data item (image or video).
......
...@@ -115,7 +115,6 @@ def fi_chunk_gated_delta_rule( ...@@ -115,7 +115,6 @@ def fi_chunk_gated_delta_rule(
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None, cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True, use_qk_l2norm_in_kernel: bool = True,
): ):
from flashinfer.gdn_prefill import ( from flashinfer.gdn_prefill import (
...@@ -172,7 +171,6 @@ class ChunkGatedDeltaRule(CustomOp): ...@@ -172,7 +171,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None, cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True, use_qk_l2norm_in_kernel: bool = True,
): ):
return fi_chunk_gated_delta_rule( return fi_chunk_gated_delta_rule(
...@@ -184,7 +182,6 @@ class ChunkGatedDeltaRule(CustomOp): ...@@ -184,7 +182,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state=initial_state, initial_state=initial_state,
output_final_state=output_final_state, output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
) )
...@@ -198,7 +195,6 @@ class ChunkGatedDeltaRule(CustomOp): ...@@ -198,7 +195,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None, cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True, use_qk_l2norm_in_kernel: bool = True,
): ):
return fla_chunk_gated_delta_rule( return fla_chunk_gated_delta_rule(
...@@ -210,7 +206,6 @@ class ChunkGatedDeltaRule(CustomOp): ...@@ -210,7 +206,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state=initial_state, initial_state=initial_state,
output_final_state=output_final_state, output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
) )
...@@ -790,7 +785,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ...@@ -790,7 +785,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
initial_state=initial_state, initial_state=initial_state,
output_final_state=True, output_final_state=True,
cu_seqlens=non_spec_query_start_loc, cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=True,
) )
# Init cache # Init cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment