Commit 31021d81 authored by zhuwenwen
Browse files

update mla interface

parent f6aa3d19
......@@ -77,9 +77,9 @@ def get_mla_metadata(
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if current_platform.is_rocm():
return flash_mla_cuda.get_mla_metadata(cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k)
return flash_mla_cuda.get_mla_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
is_fp8_kvcache, topk)
else:
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
......@@ -160,8 +160,9 @@ def flash_mla_with_kvcache(
else:
if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits)
q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata,
num_splits, softmax_scale, causal, is_fp8_kvcache,
indices)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
......@@ -196,8 +197,12 @@ def flash_mla_sparse_prefill(
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
if current_platform.is_rocm():
return flash_mla_cuda.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
else:
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
return results
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment