Commit 7949f854 authored by zhanghj2
Browse files

get_mla_decoding_metadata_dense_fp8和社区保持一致

parent b894e2da
...@@ -265,9 +265,7 @@ std::vector<at::Tensor> ...@@ -265,9 +265,7 @@ std::vector<at::Tensor>
get_mla_decoding_metadata_dense_fp8( get_mla_decoding_metadata_dense_fp8(
at::Tensor &seqlens_k, at::Tensor &seqlens_k,
const int num_heads_per_head_k, const int num_heads_per_head_k,
const int num_heads_k, const int num_heads_k) {
const std::optional<int> h_q
) {
// This should match the logic in the MLA kernel. // This should match the logic in the MLA kernel.
int block_size_m = 16; int block_size_m = 16;
static constexpr int block_size_n = 64; static constexpr int block_size_n = 64;
......
...@@ -213,9 +213,8 @@ def flash_mla_sparse_fwd( ...@@ -213,9 +213,8 @@ def flash_mla_sparse_fwd(
def get_mla_decoding_metadata_dense_fp8( def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor, cache_seqlens: torch.Tensor,
num_heads_per_head_k: int, num_heads_per_head_k: int,
num_heads_k: int, num_heads_k: int
num_heads_q : int = 16, ) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Arguments: Arguments:
cache_seqlens: (batch_size), dtype torch.int32. cache_seqlens: (batch_size), dtype torch.int32.
...@@ -226,7 +225,7 @@ def get_mla_decoding_metadata_dense_fp8( ...@@ -226,7 +225,7 @@ def get_mla_decoding_metadata_dense_fp8(
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32.
""" """
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k, num_heads_q) return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k)
......
...@@ -79,7 +79,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i ...@@ -79,7 +79,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8( tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q cache_seqlens, s_q * h_q // h_kv, h_kv
) )
# print("q:", q.shape, q.dtype, q) # print("q:", q.shape, q.dtype, q)
......
...@@ -88,7 +88,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa ...@@ -88,7 +88,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8( tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q cache_seqlens, s_q * h_q // h_kv, h_kv
) )
init_dtype = q.dtype init_dtype = q.dtype
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment