Unverified Commit 269d9017 authored by Hongxia Yang's avatar Hongxia Yang Committed by GitHub
Browse files

[Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py...


[Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running llama4 models and unit test fix (#18100)
Signed-off-by: default avatarHongxia Yang <hongxia.yang@amd.com>
Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
parent 7951d787
...@@ -13,7 +13,9 @@ HEAD_SIZES = [128, 256] ...@@ -13,7 +13,9 @@ HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz
]
# one value large enough to test overflow in index calculation. # one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check # one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048] NUM_BLOCKS = [32768, 2048]
......
...@@ -64,6 +64,7 @@ def kernel_unified_attention_2d( ...@@ -64,6 +64,7 @@ def kernel_unified_attention_2d(
query_start_len_ptr, # [num_seqs+1] query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32, num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
): ):
q_block_global_idx = tl.program_id(0) q_block_global_idx = tl.program_id(0)
...@@ -94,15 +95,13 @@ def kernel_unified_attention_2d( ...@@ -94,15 +95,13 @@ def kernel_unified_attention_2d(
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return return
offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv) offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_d = tl.arange(0, HEAD_SIZE_PADDED)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + \ query_offset_1 = kv_head_idx * num_queries_per_kv + \
offs_m % num_queries_per_kv offs_m % num_queries_per_kv
query_offset = (query_offset_0[:, None] * query_stride_0 + query_offset = (query_offset_0[:, None] * query_stride_0 +
query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
...@@ -110,7 +109,7 @@ def kernel_unified_attention_2d( ...@@ -110,7 +109,7 @@ def kernel_unified_attention_2d(
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
# Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,) # Q : (BLOCK_M, HEAD_SIZE_PADDED)
Q = tl.load( Q = tl.load(
query_ptr + query_offset, query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
...@@ -119,12 +118,9 @@ def kernel_unified_attention_2d( ...@@ -119,12 +118,9 @@ def kernel_unified_attention_2d(
block_table_offset = seq_idx * block_table_stride block_table_offset = seq_idx * block_table_stride
M = tl.full([BLOCK_Q * num_queries_per_kv], M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
float("-inf"), L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED],
dtype=tl.float32)
# sequence len for this particular sequence # sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx) seq_len = tl.load(seq_lens_ptr + seq_idx)
...@@ -183,13 +179,12 @@ def kernel_unified_attention_2d( ...@@ -183,13 +179,12 @@ def kernel_unified_attention_2d(
else: else:
V = V_load V = V_load
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) seq_offset = j * BLOCK_SIZE + offs_n
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) # S : (BLOCK_M, BLOCK_SIZE)
S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE), S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
dtype=tl.float32)
S += scale * tl.dot(Q, K) S += scale * tl.dot(Q, K)
...@@ -207,29 +202,29 @@ def kernel_unified_attention_2d( ...@@ -207,29 +202,29 @@ def kernel_unified_attention_2d(
S += alibi_slope[:, None] * (seq_offset - context_len) S += alibi_slope[:, None] * (seq_offset - context_len)
# compute running maximum # compute running maximum
# m_j : (BLOCK_Q * num_queries_per_kv,) # m_j : (BLOCK_M,)
m_j = tl.maximum(M, tl.max(S, axis=1)) m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of # For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j 0 to avoid NaN # the entire row. In this case we need to set m_j 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0) m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
# P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) # P : (BLOCK_M, BLOCK_SIZE)
P = tl.exp(S - m_j[:, None]) P = tl.exp(S - m_j[:, None])
# l_j : (BLOCK_Q * num_queries_per_kv,) # l_j : (BLOCK_M,)
l_j = tl.sum(P, axis=1) l_j = tl.sum(P, axis=1)
# alpha : (BLOCK_Q * num_queries_per_kv, ) # alpha : (BLOCK_M, )
alpha = tl.exp(M - m_j) alpha = tl.exp(M - m_j)
# acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) # acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
# update constants # update constants
L = L * alpha + l_j L = L * alpha + l_j
M = m_j M = m_j
# acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,) # acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V) acc += tl.dot(P.to(V.dtype), V)
# epilogue # epilogue
...@@ -334,4 +329,5 @@ def unified_attention( ...@@ -334,4 +329,5 @@ def unified_attention(
query_start_len_ptr=cu_seqlens_q, query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q, BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs, num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment