"vllm/engine/async_llm_engine.py" did not exist on "c3442c1f6fabe54adb82d2d676920c5f31b9834e"
Unverified Commit c683d11c authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Refactor] Deprecate `head_first` for `chunk_gated_delta_rule` (#34263)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 3eff45d7
......@@ -10,7 +10,6 @@
import warnings
import torch
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
......@@ -119,21 +118,20 @@ def chunk_gated_delta_rule(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
Queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
Keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
Values of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
(forget) Gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
Betas of shape `[B, T, H]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
......@@ -146,13 +144,9 @@ def chunk_gated_delta_rule(
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
......@@ -189,24 +183,11 @@ def chunk_gated_delta_rule(
assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
)
assert len(beta.shape) == 3, (
"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
if q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
......@@ -235,6 +216,4 @@ def chunk_gated_delta_rule(
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state
......@@ -867,7 +867,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return []
return None
# The result multimodal_embeddings is tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
......
......@@ -115,7 +115,6 @@ def fi_chunk_gated_delta_rule(
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True,
):
from flashinfer.gdn_prefill import (
......@@ -172,7 +171,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True,
):
return fi_chunk_gated_delta_rule(
......@@ -184,7 +182,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
......@@ -198,7 +195,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = True,
):
return fla_chunk_gated_delta_rule(
......@@ -210,7 +206,6 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
head_first=head_first,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
......@@ -790,7 +785,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Init cache
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment