Unverified commit 78434b92, authored by rasmith and committed by GitHub
Browse files

[CI][AMD][BugFix][Kernel] Cast induction variable to int64 on MI350 for...


[CI][AMD][BugFix][Kernel] Cast induction variable to int64 on MI350 for chunk_gated_delta_rule_fwd_kernel_h_blockdim64 to avoid illegal memory access (#39087)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 2488d1dc
...@@ -129,22 +129,42 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -129,22 +129,42 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
# main recurrence # main recurrence
for i_t in range(NT): for i_t in range(NT):
p_h1 = tl.make_block_ptr( p_h1 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0) h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 0),
(BV, 64),
(1, 0),
) )
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64: if K > 64:
p_h2 = tl.make_block_ptr( p_h2 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0) h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 64),
(BV, 64),
(1, 0),
) )
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128: if K > 128:
p_h3 = tl.make_block_ptr( p_h3 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0) h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 128),
(BV, 64),
(1, 0),
) )
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192: if K > 192:
p_h4 = tl.make_block_ptr( p_h4 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0) h + i_t.to(tl.int64) * stride_h,
(V, K),
(K, 1),
(i_v * BV, 192),
(BV, 64),
(1, 0),
) )
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
...@@ -182,9 +202,9 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( ...@@ -182,9 +202,9 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
) )
tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1 last_idx = min((i_t.to(tl.int64) + 1) * BT, T) - 1
if USE_G: if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T m_t = (i_t.to(tl.int64) * BT + tl.arange(0, BT)) < T
b_g_last = tl.load(g + bos * H + last_idx * H + i_h) b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr( p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment