Unverified commit b779eb33, authored by Xu Jinyang, committed by GitHub
Browse files

[Model] Sync upstream BT=chunk_size fix for GDN chunk_fwd_kernel_o, simplify...


[Model] Sync upstream BT=chunk_size fix for GDN chunk_fwd_kernel_o, simplify warmup to single pass (#38343)
Signed-off-by: AuYang <459461160@qq.com>
Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
parent 077a9a8e
...@@ -16,7 +16,7 @@ from vllm.triton_utils import tl, triton ...@@ -16,7 +16,7 @@ from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices from .index import prepare_chunk_indices
from .op import exp from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper from .utils import FLA_CHUNK_SIZE, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
...@@ -146,11 +146,11 @@ def chunk_fwd_o( ...@@ -146,11 +146,11 @@ def chunk_fwd_o(
g: torch.Tensor | None = None, # cumsum of log decay g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None, scale: float | None = None,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64, chunk_size: int = FLA_CHUNK_SIZE,
) -> torch.Tensor: ) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1] B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2] H = v.shape[-2]
BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) BT = chunk_size
chunk_indices = ( chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
) )
......
...@@ -24,10 +24,12 @@ logger = logging.getLogger(__name__) ...@@ -24,10 +24,12 @@ logger = logging.getLogger(__name__)
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
# Default chunk size used across FLA triton kernels (kda, chunk, chunk_o, etc.)
FLA_CHUNK_SIZE = 64
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
""" """
......
...@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fla.ops import ( ...@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fla.ops import (
fused_sigmoid_gating_delta_rule_update, fused_sigmoid_gating_delta_rule_update,
) )
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
from vllm.model_executor.layers.fla.ops.utils import FLA_CHUNK_SIZE
from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -581,11 +582,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -581,11 +582,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
results are cached globally, so only the first layer incurs results are cached globally, so only the first layer incurs
actual benchmarking cost. actual benchmarking cost.
Most kernels use a fixed ``BT = chunk_size`` (64), but All kernels including ``chunk_fwd_kernel_o`` now use a fixed
``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence ``BT = chunk_size`` (64). A single warmup pass with T = 64
length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT`` is sufficient to populate the autotuner cache.
is part of its autotune key, we run warmup passes with T = 16,
32, and 64 to cover all possible ``BT`` values.
The decode path uses ``fused_sigmoid_gating_delta_rule_update`` The decode path uses ``fused_sigmoid_gating_delta_rule_update``
which has fixed kernel parameters (no autotuning), so only the which has fixed kernel parameters (no autotuning), so only the
...@@ -601,66 +600,58 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -601,66 +600,58 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
num_v_heads = self.num_v_heads // self.tp_size num_v_heads = self.num_v_heads // self.tp_size
_, state_dtype = self.get_state_dtype() _, state_dtype = self.get_state_dtype()
# Run warmup for each possible BT value of chunk_fwd_kernel_o: # All kernels use BT = chunk_size (FLA_CHUNK_SIZE = 64), so a single pass with
# T=16 → BT=16, T=32 → BT=32, T=64 → BT=64. # T = chunk_size is sufficient to populate every autotuner cache.
# Other kernels always use BT=chunk_size(64), so their autotune T = FLA_CHUNK_SIZE
# cache is populated on the first pass and reused thereafter. q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
for T in (16, 32, 64): k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
q = torch.randn( v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype)
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype # NOTE: g and beta must have the same dtypes as during
) # inference, so we construct them with the same function
k = torch.randn( # (fused_gdn_gating). dummy_a and dummy_b are throwaway
1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype # inputs required by that function.
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
state = torch.zeros(
1,
num_v_heads,
self.head_v_dim,
self.head_k_dim,
device=device,
dtype=state_dtype,
)
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
try:
self.chunk_gated_delta_rule(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=state,
output_final_state=True,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
) )
v = torch.randn( except Exception:
1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype logger.warning(
"GDN prefill kernel warmup (T=%d) failed for "
"layer %s. First inference may OOM due to "
"autotuner.",
T,
self.prefix,
exc_info=True,
) )
# NOTE: g and beta must have the same dtypes as during else:
# inference, so we construct them with the same function logger.debug(
# (fused_gdn_gating). dummy_a and dummy_b are throwaway "GDN prefill kernel warmup (T=%d) completed for layer %s",
# inputs required by that function. T,
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) self.prefix,
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
state = torch.zeros(
1,
num_v_heads,
self.head_v_dim,
self.head_k_dim,
device=device,
dtype=state_dtype,
) )
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32) finally:
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
try:
self.chunk_gated_delta_rule(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=state,
output_final_state=True,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
except Exception:
logger.warning(
"GDN prefill kernel warmup (T=%d) failed for "
"layer %s. First inference may OOM due to "
"autotuner.",
T,
self.prefix,
exc_info=True,
)
else:
logger.debug(
"GDN prefill kernel warmup (T=%d) completed for layer %s",
T,
self.prefix,
)
finally:
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment