"vscode:/vscode.git/clone" did not exist on "dec8b94466abeaad4c9c5be4929221c4e59b4049"
Commit 62420b27 authored by zhuwenwen
Browse files

sync flashmla get_mla_metadata

parent 93f78899
...@@ -70,7 +70,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, ...@@ -70,7 +70,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata( tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv) cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, False, None)
def flash_mla(): def flash_mla():
return flash_mla_with_kvcache( return flash_mla_with_kvcache(
......
...@@ -44,6 +44,9 @@ def get_mla_metadata( ...@@ -44,6 +44,9 @@ def get_mla_metadata(
cache_seqlens: torch.Tensor, cache_seqlens: torch.Tensor,
num_heads_per_head_k: int, num_heads_per_head_k: int,
num_heads_k: int, num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Arguments: Arguments:
...@@ -59,7 +62,7 @@ def get_mla_metadata( ...@@ -59,7 +62,7 @@ def get_mla_metadata(
if current_platform.is_rocm(): if current_platform.is_rocm():
return flash_mla_cuda.get_mla_metadata(cache_seqlens, return flash_mla_cuda.get_mla_metadata(cache_seqlens,
num_heads_per_head_k, num_heads_per_head_k,
num_heads_k) num_heads_k, num_heads_q, is_fp8_kvcache, topk)
else: else:
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
num_heads_per_head_k, num_heads_per_head_k,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment