Commit 62420b27 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync flashmla get_mla_metadata

parent 93f78899
......@@ -70,7 +70,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, False, None)
def flash_mla():
return flash_mla_with_kvcache(
......
......@@ -44,6 +44,9 @@ def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
......@@ -59,7 +62,7 @@ def get_mla_metadata(
if current_platform.is_rocm():
return flash_mla_cuda.get_mla_metadata(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
num_heads_k, num_heads_q, is_fp8_kvcache, topk)
else:
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
num_heads_per_head_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment