Unverified Commit 6d80ae83 authored by Burkhard Ringlein's avatar Burkhard Ringlein Committed by GitHub
Browse files

[Bugfix] Fixing division by zero in triton_attn if query_heads/kv_heads > 16 (#23424)


Signed-off-by: default avatarBurkhard Ringlein <ngl@zurich.ibm.com>
parent 4ba0c587
......@@ -674,7 +674,8 @@ def unified_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = q.shape[2]
BLOCK_M = 16
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
num_queries_per_kv)
BLOCK_Q = BLOCK_M // num_queries_per_kv
# Ideally we would launch with kernel with:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment