Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
b894e2da
Commit
b894e2da
authored
Mar 03, 2026
by
zhanghj2
Browse files
优化mtp场景
parent
7efb944d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
22 deletions
+23
-22
csrc/extension/flash_api.h
csrc/extension/flash_api.h
+5
-6
csrc/extension/flash_fwd_mla_kernel.h
csrc/extension/flash_fwd_mla_kernel.h
+2
-2
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+2
-2
tests/test_flash_mla_fp8.py
tests/test_flash_mla_fp8.py
+6
-0
tests/test_flash_mla_qkvfp8.py
tests/test_flash_mla_qkvfp8.py
+8
-12
No files found.
csrc/extension/flash_api.h
View file @
b894e2da
...
@@ -271,13 +271,12 @@ get_mla_decoding_metadata_dense_fp8(
...
@@ -271,13 +271,12 @@ get_mla_decoding_metadata_dense_fp8(
// This should match the logic in the MLA kernel.
// This should match the logic in the MLA kernel.
int
block_size_m
=
16
;
int
block_size_m
=
16
;
static
constexpr
int
block_size_n
=
64
;
static
constexpr
int
block_size_n
=
64
;
if
(
h_q
.
has_value
())
{
if
(
num_heads_per_head_k
>
32
)
{
if
(
h_q
.
value
()
>=
64
)
{
block_size_m
=
64
;
block_size_m
=
64
;
}
else
if
(
h_q
.
value
()
>
16
)
{
}
else
if
(
num_heads_per_head_k
>
16
)
{
block_size_m
=
32
;
block_size_m
=
32
;
}
}
}
static
constexpr
int
fixed_overhead_num_blocks
=
5
;
static
constexpr
int
fixed_overhead_num_blocks
=
5
;
CHECK_DEVICE
(
seqlens_k
);
CHECK_DEVICE
(
seqlens_k
);
...
...
csrc/extension/flash_fwd_mla_kernel.h
View file @
b894e2da
...
@@ -4314,7 +4314,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
...
@@ -4314,7 +4314,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
if
(
kv_cache_dtype
==
"auto"
)
{
if
(
kv_cache_dtype
==
"auto"
)
{
// printf(" seqlen_q %d \n", params.seqlen_q);
// printf(" seqlen_q %d \n", params.seqlen_q);
if
(
params
.
ngroups
>=
64
)
{
if
(
params
.
seqlen_q
>
32
)
{
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_tp1
<
576
,
64
,
64
,
8
,
T
,
512
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_tp1
<
576
,
64
,
64
,
8
,
T
,
512
>
;
run_flash_splitkv_fwd_mla_tp1
<
Kernel_traits
,
flash
::
SharedStorageMLATP1
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kAuto
>
(
params
,
stream
);
run_flash_splitkv_fwd_mla_tp1
<
Kernel_traits
,
flash
::
SharedStorageMLATP1
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kAuto
>
(
params
,
stream
);
}
else
{
}
else
{
...
@@ -4325,7 +4325,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
...
@@ -4325,7 +4325,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8
<
576
,
16
,
64
,
4
,
T
,
512
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8
<
576
,
16
,
64
,
4
,
T
,
512
>
;
run_flash_splitkv_fwd_mla
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E4M3
>
(
params
,
stream
);
run_flash_splitkv_fwd_mla
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E4M3
>
(
params
,
stream
);
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
params
.
ngroups
>=
64
)
{
if
(
params
.
seqlen_q
>
32
)
{
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8_TP1
<
576
,
64
,
64
,
8
,
T
,
512
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8_TP1
<
576
,
64
,
64
,
8
,
T
,
512
>
;
run_flash_splitkv_fwd_mla_tp1
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8_TP1
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E5M2
>
(
params
,
stream
);
run_flash_splitkv_fwd_mla_tp1
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8_TP1
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E5M2
>
(
params
,
stream
);
}
else
{
}
else
{
...
...
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
b894e2da
...
@@ -2781,10 +2781,10 @@ void run_mha_fwd_splitkv_mla_fp8(Flash_fwd_mla_params ¶ms, cudaStream_t str
...
@@ -2781,10 +2781,10 @@ void run_mha_fwd_splitkv_mla_fp8(Flash_fwd_mla_params ¶ms, cudaStream_t str
return
;
return
;
}
}
if
constexpr
(
std
::
is_same_v
<
T
,
cutlass
::
float_e4m3_t
>
)
{
if
constexpr
(
std
::
is_same_v
<
T
,
cutlass
::
float_e4m3_t
>
)
{
if
(
params
.
ngroups
>=
64
)
{
if
(
params
.
seqlen_q
>
32
)
{
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_qkvfp8_TP1
<
576
,
64
,
64
,
8
,
T
,
To
,
512
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_qkvfp8_TP1
<
576
,
64
,
64
,
8
,
T
,
To
,
512
>
;
run_flash_splitkv_fwd_mla_fp8_tp1
<
Kernel_traits
,
flash
::
SharedStorageMLAFloat8_TP1
<
Kernel_traits
>>
(
params
,
stream
);
run_flash_splitkv_fwd_mla_fp8_tp1
<
Kernel_traits
,
flash
::
SharedStorageMLAFloat8_TP1
<
Kernel_traits
>>
(
params
,
stream
);
}
else
if
(
params
.
ngroups
>
16
)
{
}
else
if
(
params
.
seqlen_q
>
16
)
{
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_qkvfp8_TP4
<
576
,
32
,
64
,
4
,
T
,
To
,
512
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_qkvfp8_TP4
<
576
,
32
,
64
,
4
,
T
,
To
,
512
>
;
run_flash_splitkv_fwd_mla_fp8_tp4
<
Kernel_traits
,
flash
::
SharedStorageMLAFloat8_TP4
<
Kernel_traits
>>
(
params
,
stream
);
run_flash_splitkv_fwd_mla_fp8_tp4
<
Kernel_traits
,
flash
::
SharedStorageMLAFloat8_TP4
<
Kernel_traits
>>
(
params
,
stream
);
}
else
{
}
else
{
...
...
tests/test_flash_mla_fp8.py
View file @
b894e2da
...
@@ -206,6 +206,12 @@ def main(torch_dtype, is_prof=False):
...
@@ -206,6 +206,12 @@ def main(torch_dtype, is_prof=False):
for
s_q
in
[
1
]:
# MTP = 1, 2
for
s_q
in
[
1
]:
# MTP = 1, 2
for
varlen
in
[
False
]:
for
varlen
in
[
False
]:
test_flash_mla_fp8_e5m2
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
test_flash_mla_fp8_e5m2
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
for
b
in
[
3
,
6
,
9
,
12
,
15
,
18
,
21
,
24
,
32
,
64
,
128
,
256
]:
for
s
in
[
4000
]:
for
h_q
in
[
16
]:
for
s_q
in
[
1
,
2
,
3
,
4
]:
# MTP = 1, 2
for
varlen
in
[
False
]:
test_flash_mla_fp8_e5m2
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
...
tests/test_flash_mla_qkvfp8.py
View file @
b894e2da
...
@@ -227,18 +227,14 @@ def main(torch_dtype, is_prof=False):
...
@@ -227,18 +227,14 @@ def main(torch_dtype, is_prof=False):
for
s_q
in
[
1
]:
# MTP = 1, 2
for
s_q
in
[
1
]:
# MTP = 1, 2
for
varlen
in
[
False
]:
for
varlen
in
[
False
]:
test_flash_mla
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
,
False
,
torch_dtype
)
test_flash_mla
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
,
False
,
torch_dtype
)
# for b in [1]:
# for s in [128]:
for
b
in
[
3
,
6
,
9
,
12
,
15
,
18
,
21
,
24
,
32
,
64
,
128
,
256
]:
# for h_q in [128]:
for
s
in
[
4000
]:
# for s_q in [2]: # MTP = 1, 2
for
h_q
in
[
16
]:
# for varlen in [False]:
for
s_q
in
[
1
,
2
,
3
,
4
]:
# MTP = 1, 2
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
for
varlen
in
[
False
]:
# for b in [1, 32]:
test_flash_mla
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
,
False
,
torch_dtype
)
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# '''
# '''
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment