Unverified commit 9853a3c1, authored by Ibrahim Arshad and committed by GitHub
Browse files

fix(gdn): Align prefill warmup with real prefill path (#39169)


Signed-off-by: Ibrahim Arshad <38925737+ibrahim1023@users.noreply.github.com>
parent bb6047db
...@@ -702,19 +702,33 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -702,19 +702,33 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
num_v_heads = self.num_v_heads // self.tp_size num_v_heads = self.num_v_heads // self.tp_size
_, state_dtype = self.get_state_dtype() _, state_dtype = self.get_state_dtype()
# All kernels use BT = chunk_size (FLA_CHUNK_SIZE4), so a single pass with # All kernels use BT = chunk_size, so a single pass with T = chunk_size
# T = chunk_size is sufficient to populate every autotuner cache. # is sufficient to populate every autotuner cache. Mirror the real
# prefill path here: build q/k/v/g/beta via fused_post_conv_prep and
# then run chunk_gated_delta_rule with in-kernel L2 norm disabled.
T = FLA_CHUNK_SIZE T = FLA_CHUNK_SIZE
q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) dummy_mixed_qkv = torch.randn(
k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) T, mixed_qkv.shape[-1], device=device, dtype=dtype
v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype) )
# NOTE: g and beta must have the same dtypes as during
# inference, so we construct them with the same function
# (fused_gdn_gating). dummy_a and dummy_b are throwaway
# inputs required by that function.
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype) dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias) q, k, v, g, beta = fused_post_conv_prep(
conv_output=dummy_mixed_qkv,
a=dummy_a,
b=dummy_b,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_k_heads=num_k_heads,
head_k_dim=self.head_k_dim,
head_v_dim=self.head_v_dim,
apply_l2norm=True,
output_g_exp=False,
)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
g = g.unsqueeze(0)
beta = beta.unsqueeze(0)
state = torch.zeros( state = torch.zeros(
1, 1,
num_v_heads, num_v_heads,
...@@ -735,7 +749,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -735,7 +749,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
initial_state=state, initial_state=state,
output_final_state=True, output_final_state=True,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
use_qk_l2norm_in_kernel=True, use_qk_l2norm_in_kernel=False,
) )
except Exception: except Exception:
logger.warning( logger.warning(
...@@ -753,7 +767,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): ...@@ -753,7 +767,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
self.prefix, self.prefix,
) )
finally: finally:
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens del dummy_mixed_qkv, q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment