Unverified commit 457e4719, authored by rasmith and committed by GitHub
Browse files

[AMD][Kernel][Bugfix] Cast offsets tensor bn to tl.int64 to avoid GPU segfault (#23692)


Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent d328f789
...@@ -146,7 +146,7 @@ def _fwd_kernel(Q, ...@@ -146,7 +146,7 @@ def _fwd_kernel(Q,
start_n = tl.multiple_of(start_n, BLOCK_SIZE) start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ---- # -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
(start_n // BLOCK_SIZE) * stride_b_loc_s) (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
# [D,BLOCK_SIZE] # [D,BLOCK_SIZE]
off_k = ( off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
...@@ -367,7 +367,7 @@ def _fwd_kernel_flash_attn_v2( ...@@ -367,7 +367,7 @@ def _fwd_kernel_flash_attn_v2(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s, ((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len, mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) other=0).to(tl.int64)
off_k = ( off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d + (offs_d[:, None] // x) * stride_k_cache_d +
...@@ -575,7 +575,7 @@ def _fwd_kernel_alibi( ...@@ -575,7 +575,7 @@ def _fwd_kernel_alibi(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s, ((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len, mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) other=0).to(tl.int64)
off_k = ( off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d + (offs_d[:, None] // x) * stride_k_cache_d +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment