Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
7949f854
Commit
7949f854
authored
Mar 04, 2026
by
zhanghj2
Browse files
get_mla_decoding_metadata_dense_fp8和社区保持一致
parent
b894e2da
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
6 additions
and
9 deletions
+6
-9
csrc/extension/flash_api.h
csrc/extension/flash_api.h
+1
-3
flash_mla/flash_mla_interface.py
flash_mla/flash_mla_interface.py
+3
-4
tests/test_flash_mla_fp8.py
tests/test_flash_mla_fp8.py
+1
-1
tests/test_flash_mla_qkvfp8.py
tests/test_flash_mla_qkvfp8.py
+1
-1
No files found.
csrc/extension/flash_api.h
View file @
7949f854
...
@@ -265,9 +265,7 @@ std::vector<at::Tensor>
...
@@ -265,9 +265,7 @@ std::vector<at::Tensor>
get_mla_decoding_metadata_dense_fp8
(
get_mla_decoding_metadata_dense_fp8
(
at
::
Tensor
&
seqlens_k
,
at
::
Tensor
&
seqlens_k
,
const
int
num_heads_per_head_k
,
const
int
num_heads_per_head_k
,
const
int
num_heads_k
,
const
int
num_heads_k
)
{
const
std
::
optional
<
int
>
h_q
)
{
// This should match the logic in the MLA kernel.
// This should match the logic in the MLA kernel.
int
block_size_m
=
16
;
int
block_size_m
=
16
;
static
constexpr
int
block_size_n
=
64
;
static
constexpr
int
block_size_n
=
64
;
...
...
flash_mla/flash_mla_interface.py
View file @
7949f854
...
@@ -213,9 +213,8 @@ def flash_mla_sparse_fwd(
...
@@ -213,9 +213,8 @@ def flash_mla_sparse_fwd(
def
get_mla_decoding_metadata_dense_fp8
(
def
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
num_heads_per_head_k
:
int
,
num_heads_per_head_k
:
int
,
num_heads_k
:
int
,
num_heads_k
:
int
num_heads_q
:
int
=
16
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Arguments:
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
cache_seqlens: (batch_size), dtype torch.int32.
...
@@ -226,7 +225,7 @@ def get_mla_decoding_metadata_dense_fp8(
...
@@ -226,7 +225,7 @@ def get_mla_decoding_metadata_dense_fp8(
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
"""
return
flash_mla_cuda
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
,
num_heads_q
)
return
flash_mla_cuda
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
)
...
...
tests/test_flash_mla_fp8.py
View file @
7949f854
...
@@ -79,7 +79,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
...
@@ -79,7 +79,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_decoding_metadata_dense_fp8
(
tile_scheduler_metadata
,
num_splits
=
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
,
h_q
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
)
# print("q:", q.shape, q.dtype, q)
# print("q:", q.shape, q.dtype, q)
...
...
tests/test_flash_mla_qkvfp8.py
View file @
7949f854
...
@@ -88,7 +88,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
...
@@ -88,7 +88,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_decoding_metadata_dense_fp8
(
tile_scheduler_metadata
,
num_splits
=
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
,
h_q
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
)
init_dtype
=
q
.
dtype
init_dtype
=
q
.
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment