Unverified Commit dd1161e3 authored by Jiashi Li's avatar Jiashi Li Committed by GitHub
Browse files

Merge pull request #14 from lancerts/minor-fix

minor fix test
parents accc1695 4fbaa952
......@@ -7,7 +7,7 @@ import triton
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
def scaled_dot_product_attention(query, key, value, is_causal=False):
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
......@@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment