Unverified Commit dd1161e3 authored by Jiashi Li's avatar Jiashi Li Committed by GitHub
Browse files

Merge pull request #14 from lancerts/minor-fix

minor fix test
parents accc1695 4fbaa952
...@@ -7,7 +7,7 @@ import triton ...@@ -7,7 +7,7 @@ import triton
from flash_mla import get_mla_metadata, flash_mla_with_kvcache from flash_mla import get_mla_metadata, flash_mla_with_kvcache
def scaled_dot_product_attention(query, key, value, is_causal=False): def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float() query = query.float()
key = key.float() key = key.float()
value = value.float() value = value.float()
...@@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): ...@@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
q[i].transpose(0, 1), q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal, is_causal=causal,
) )
out[i] = O.transpose(0, 1) out[i] = O.transpose(0, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment