Unverified commit d4cb783c, authored by Vibhav Agarwal and committed by GitHub
Browse files

[Bugfix] Fix GDN FLA kernel crashes with NULL_BLOCK_ID=0 CUDA graph padding (#39064)


Signed-off-by: Vibhav Agarwal <vibhavagarwal5@gmail.com>
Co-authored-by: vibhav-agarwal <vibhav.agarwal@glance.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent eb92ba74
......@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
......@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
# Only store if state index is valid (not NULL_BLOCK_ID=0)
if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
......@@ -292,7 +292,8 @@ def fused_recurrent_gated_delta_rule_packed_decode_kernel(
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
p_o = o + (i_n * HV + i_hv) * V + o_v
if state_idx < 0:
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
tl.store(p_o, zero, mask=mask_v)
return
......
......@@ -106,12 +106,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
......@@ -155,12 +155,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
# Only store if state index is valid (not NULL_BLOCK_ID=0)
if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment