Commit 7949f854 authored by zhanghj2
Browse files

get_mla_decoding_metadata_dense_fp8和社区保持一致 (Keep get_mla_decoding_metadata_dense_fp8 consistent with the community version)

parent b894e2da
......@@ -265,9 +265,7 @@ std::vector<at::Tensor>
get_mla_decoding_metadata_dense_fp8(
at::Tensor &seqlens_k,
const int num_heads_per_head_k,
const int num_heads_k,
const std::optional<int> h_q
) {
const int num_heads_k) {
// This should match the logic in the MLA kernel.
int block_size_m = 16;
static constexpr int block_size_n = 64;
......
......@@ -213,9 +213,8 @@ def flash_mla_sparse_fwd(
def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
num_heads_q : int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_heads_k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
......@@ -226,7 +225,7 @@ def get_mla_decoding_metadata_dense_fp8(
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k, num_heads_q)
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k)
......
......@@ -79,7 +79,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q
cache_seqlens, s_q * h_q // h_kv, h_kv
)
# print("q:", q.shape, q.dtype, q)
......
......@@ -88,7 +88,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q
cache_seqlens, s_q * h_q // h_kv, h_kv
)
init_dtype = q.dtype
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment