Unverified commit 23591e63, authored by CarstyYou and committed by GitHub
Browse files

[Bugfix][Kernel] Fix negative memory offset in GDN Triton kernel (#33326)


Signed-off-by: CarstyYou <186021327+CarstyYou@users.noreply.github.com>
parent 0493d897
...@@ -106,13 +106,14 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -106,13 +106,14 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else: else:
i_t = 0 i_t = 0
p_h0 = ( # Load state index and check for PAD_SLOT_ID (-1)
h0 state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( tl.int64
tl.int64
)
* stride_init_state_token
) )
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else: else:
p_h0 = h0 + bos * HV * K * V p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
...@@ -149,17 +150,19 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( ...@@ -149,17 +150,19 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens # keep the states for multi-query tokens
if INPLACE_FINAL_STATE: if INPLACE_FINAL_STATE:
p_ht = ( # Load state index and check for PAD_SLOT_ID (-1)
ht final_state_idx = tl.load(
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( ssm_state_indices + i_n * stride_indices_seq + i_t
tl.int64 ).to(tl.int64)
) # Only store if state index is valid (not PAD_SLOT_ID)
* stride_final_state_token if final_state_idx >= 0:
) p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else: else:
p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
p_q += H * K p_q += H * K
p_k += H * K p_k += H * K
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment