Unverified Commit f1ff50c8 authored by Jingu Kang's avatar Jingu Kang Committed by GitHub
Browse files

[Bugfix] clamp dA_cumsum differences to prevent Inf in Mamba2 SSD kernels (#37501)


Signed-off-by: default avatarJingu Kang <jg.k@navercorp.com>
parent 757068dc
......@@ -356,7 +356,7 @@ def _chunk_scan_fwd_kernel(
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= fast_exp(dA_cs_m[:, None] - dA_cs_k[None, :])
cb *= fast_exp(tl.minimum(dA_cs_m[:, None] - dA_cs_k[None, :], 0.0))
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
......
......@@ -280,7 +280,7 @@ def _chunk_state_fwd_kernel(
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = fast_exp(dA_cs_last - dA_cs_k) * dt_k
scale = fast_exp(tl.minimum(dA_cs_last - dA_cs_k, 0.0)) * dt_k
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment