Unverified commit 9853a3c1, authored by Ibrahim Arshad and committed by GitHub
Browse files

fix(gdn): Align prefill warmup with real prefill path (#39169)


Signed-off-by: Ibrahim Arshad <38925737+ibrahim1023@users.noreply.github.com>
parent bb6047db
......@@ -702,19 +702,33 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
num_v_heads = self.num_v_heads // self.tp_size
_, state_dtype = self.get_state_dtype()
# All kernels use BT = chunk_size (FLA_CHUNK_SIZE), so a single pass with
# T = chunk_size is sufficient to populate every autotuner cache.
# All kernels use BT = chunk_size, so a single pass with T = chunk_size
# is sufficient to populate every autotuner cache. Mirror the real
# prefill path here: build q/k/v/g/beta via fused_post_conv_prep and
# then run chunk_gated_delta_rule with in-kernel L2 norm disabled.
T = FLA_CHUNK_SIZE
q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype)
# NOTE: g and beta must have the same dtypes as during
# inference, so we construct them with the same function
# (fused_gdn_gating). dummy_a and dummy_b are throwaway
# inputs required by that function.
dummy_mixed_qkv = torch.randn(
T, mixed_qkv.shape[-1], device=device, dtype=dtype
)
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
q, k, v, g, beta = fused_post_conv_prep(
conv_output=dummy_mixed_qkv,
a=dummy_a,
b=dummy_b,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_k_heads=num_k_heads,
head_k_dim=self.head_k_dim,
head_v_dim=self.head_v_dim,
apply_l2norm=True,
output_g_exp=False,
)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
g = g.unsqueeze(0)
beta = beta.unsqueeze(0)
state = torch.zeros(
1,
num_v_heads,
......@@ -735,7 +749,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
initial_state=state,
output_final_state=True,
cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True,
use_qk_l2norm_in_kernel=False,
)
except Exception:
logger.warning(
......@@ -753,7 +767,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
self.prefix,
)
finally:
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
del dummy_mixed_qkv, q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
torch.accelerator.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment