Commit 31021d81 authored by zhuwenwen's avatar zhuwenwen
Browse files

update mla interface

parent f6aa3d19
...@@ -77,9 +77,9 @@ def get_mla_metadata( ...@@ -77,9 +77,9 @@ def get_mla_metadata(
- num_splits: (batch_size + 1), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32.
""" """
if current_platform.is_rocm(): if current_platform.is_rocm():
return flash_mla_cuda.get_mla_metadata(cache_seqlens, return flash_mla_cuda.get_mla_metadata(
num_q_tokens_per_head_k, cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
num_heads_k) is_fp8_kvcache, topk)
else: else:
return torch.ops._flashmla_C.get_mla_decoding_metadata( return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
...@@ -160,8 +160,9 @@ def flash_mla_with_kvcache( ...@@ -160,8 +160,9 @@ def flash_mla_with_kvcache(
else: else:
if current_platform.is_rocm(): if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale, q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata,
causal, tile_scheduler_metadata, num_splits) num_splits, softmax_scale, causal, is_fp8_kvcache,
indices)
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
...@@ -196,6 +197,10 @@ def flash_mla_sparse_prefill( ...@@ -196,6 +197,10 @@ def flash_mla_sparse_prefill(
- max_logits: [s_q, h_q], float - max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp - lse: [s_q, h_q], float, 2-based log-sum-exp
""" """
if current_platform.is_rocm():
return flash_mla_cuda.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
else:
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v) sm_scale, d_v)
return results return results
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment