Unverified Commit 70dc2fbe authored by kk, committed by GitHub

Change extend attention kernel launch parameter for ROCm platform to … (#2610)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
parent b438a2e5
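
The hunk below gates the new branch on is_hip_, alongside the existing is_cuda_available / CUDA_CAPABILITY checks; those flags are defined outside this hunk. A minimal sketch of how such platform flags are commonly derived in a PyTorch codebase (the derivations here are assumptions, not copied from the repository):

    import torch

    # Sketch (assumption): plausible definitions of the platform flags used
    # in the hunk below; the real definitions live elsewhere in the repo.
    is_cuda_available = torch.cuda.is_available()
    is_hip_ = torch.version.hip is not None  # True on ROCm builds of PyTorch
    CUDA_CAPABILITY = (
        torch.cuda.get_device_capability() if is_cuda_available else (0, 0)
    )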
@@ -292,27 +292,33 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
-        if Lq <= 256:
-            BLOCK_M, BLOCK_N = (128, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
-    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
     else:
-        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
-    num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
     extra_kargs = {}
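
The post-change selection logic lifts naturally into a pure function for off-device testing. A sketch mirroring the hunk above (the function name and standalone packaging are hypothetical; only the branch logic comes from the diff, with is_cuda_available folded in by passing cuda_major=0 when CUDA is absent):

    import triton

    def pick_extend_attention_tiles(Lq, Lk, is_hip, cuda_major):
        # Mirrors the post-change heuristic: ROCm gets fixed 64x64 tiles and
        # num_warps=4; CUDA keeps the capability/head-dim based selection.
        if is_hip:
            return (64, 64), 4
        if cuda_major >= 9:
            block_m, block_n = (128, 64) if Lq <= 256 else (32, 64)
        elif cuda_major >= 8:
            if Lq <= 128:
                block_m, block_n = (128, 128)
            elif Lq <= 256:
                block_m, block_n = (64, 64)
            else:
                block_m, block_n = (32, 64)
        else:
            block_m, block_n = (64, 64) if Lq <= 128 else (32, 32)
        return (block_m, block_n), (4 if Lk <= 64 else 8)

    # Example: on ROCm, the grid's third axis is sized with the fixed BLOCK_M.
    (BLOCK_M, BLOCK_N), num_warps = pick_extend_attention_tiles(
        Lq=128, Lk=128, is_hip=True, cuda_major=0
    )
    assert (BLOCK_M, BLOCK_N, num_warps) == (64, 64, 4)
    assert triton.cdiv(8192, BLOCK_M) == 128  # blocks along max_len_extend

Keeping the heuristic this small also makes the design choice visible: ROCm wavefronts differ from CUDA warps, so the change sidesteps the CUDA-capability table entirely rather than remapping it.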