Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
0651671f
Commit
0651671f
authored
Jan 29, 2026
by
zhanghj2
Browse files
sparse decode支持head16
parent
b94fdd0f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
37 additions
and
52 deletions
+37
-52
csrc/api/common.h
csrc/api/common.h
+3
-0
csrc/api/sparse_decode.h
csrc/api/sparse_decode.h
+12
-1
csrc/sm90/decode/sparse_fp8/config.h
csrc/sm90/decode/sparse_fp8/config.h
+3
-49
csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu
...decode/sparse_fp8/instantiations/model1_persistent_h16.cu
+8
-0
csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu
...90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu
+7
-0
setup.py
setup.py
+2
-0
tests/test_flash_mla_sparse_decoding.py
tests/test_flash_mla_sparse_decoding.py
+2
-2
No files found.
csrc/api/common.h
View file @
0651671f
...
...
@@ -57,6 +57,9 @@ inline int int64_stride_to_int(int64_t orig_stride) {
} else if (NUM_HEADS == 64) { \
static constexpr int CONSTEXPR_NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_HEADS <= 16) { \
static constexpr int CONSTEXPR_NAME = 16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \
} \
...
...
csrc/api/sparse_decode.h
View file @
0651671f
...
...
@@ -10,6 +10,7 @@
// Feature set of sparse decoding kernels
enum
class
DecodeFeatures
:
int
{
HEAD_16
,
HEAD_64
,
HEAD_128
,
...
...
@@ -41,6 +42,7 @@ public:
class
Decode_Sm90_Impl
:
public
DecodeImplBase
{
DECLARE_SUPPORTED_FEATURES
(
DecodeFeatures
::
HEAD_16
,
DecodeFeatures
::
HEAD_64
,
DecodeFeatures
::
HEAD_128
,
DecodeFeatures
::
HEAD_DIM_512
,
...
...
@@ -56,6 +58,13 @@ class Decode_Sm90_Impl : public DecodeImplBase {
public:
DecodeImplMeta
get_meta
(
int
h_q
,
int
s_q
)
override
{
Arch
arch
=
Arch
();
if (h_q <= 16) {
    // Every head count in [1, 16] is dispatched to the 16-head kernel, so it
    // occupies one 16-head group. The previous divisor `h_q / 16` truncates
    // to 0 for h_q < 16, which made this expression an integer division by
    // zero. Use a ceiling division clamped to at least 1 instead; for
    // h_q == 16 the result is unchanged.
    const int num_head_groups = std::max((h_q + 15) / 16, 1);
    return {std::max(arch.num_sms * 2 / s_q / num_head_groups, 1), 5, 64};
}
return
{
std
::
max
(
arch
.
num_sms
/
s_q
/
(
h_q
/
64
),
1
),
5
,
...
...
@@ -218,7 +227,9 @@ sparse_attn_decode_interface(
}
std
::
vector
<
DecodeFeatures
>
features
;
if
(
h_q
==
64
)
{
if
(
h_q
<=
16
)
{
features
.
push_back
(
DecodeFeatures
::
HEAD_16
);
}
else
if
(
h_q
==
64
)
{
features
.
push_back
(
DecodeFeatures
::
HEAD_64
);
}
else
if
(
h_q
==
128
)
{
features
.
push_back
(
DecodeFeatures
::
HEAD_128
);
...
...
csrc/sm90/decode/sparse_fp8/config.h
View file @
0651671f
...
...
@@ -100,58 +100,9 @@ struct SharedMemoryPlan {
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
};
// struct {
// cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
// // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
// };
// struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
// };
};
// array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// union {
// array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
// array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
// array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
// } u;
// CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
// bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
// float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
// transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};
// template<
// typename Shape_Q, typename TMA_Q
// >
// using TiledMMA_QK = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_QK_rQ = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
static
__device__
__forceinline__
void
compute_attn_1rowblock_splitkv_sparse_mla_fp8
(
const
SparseAttnDecodeParams
&
params
,
const
DecodingSchedMeta
&
sched_meta
,
int
batch_idx
);
...
...
@@ -163,4 +114,7 @@ static void run(const SparseAttnDecodeParams &params);
};
}
csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu
0 → 100644
View file @
0651671f
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of run_flash_splitkv_mla_fp8_sparse_kernel for
// ModelType::MODEL1 with 16 as the second template argument, so the
// specialization is compiled in this translation unit.
// NOTE(review): the 16 presumably is the query-head count handled by the
// kernel (file is named *_h16.cu) -- confirm against splitkv_mla.cuh.
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(
    const SparseAttnDecodeParams& params);

}  // namespace sm90::decode::sparse_fp8
csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu
0 → 100644
View file @
0651671f
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {

// Explicit instantiation of run_flash_splitkv_mla_fp8_sparse_kernel for
// ModelType::V32 with 16 as the second template argument, so the
// specialization is compiled in this translation unit.
// NOTE(review): the 16 presumably is the query-head count handled by the
// kernel (file is named *_h16.cu) -- confirm against splitkv_mla.cuh.
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(
    const SparseAttnDecodeParams& params);

}  // namespace sm90::decode::sparse_fp8
setup.py
View file @
0651671f
...
...
@@ -58,8 +58,10 @@ ext_modules.append(
"csrc/sm90/decode/dense/instantiations/bf16.cu"
,
# # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu"
,
...
...
tests/test_flash_mla_sparse_decoding.py
View file @
0651671f
...
...
@@ -27,7 +27,7 @@ def gen_testcase() -> List[RawTestParam]:
for
have_extra_k
in
([
False
,
True
]
if
d_qk
==
512
else
[
False
]):
for
have_extra_topk_len
in
([
False
,
True
]
if
have_extra_k
else
[
False
]):
for
have_topk_len
in
([
False
,
True
]
if
d_qk
==
512
else
[
False
]):
for
h_q
in
[
64
,
128
]:
for
h_q
in
[
16
,
64
,
128
]:
cur_correctness_cases
=
[
RawTestParam
(
b
,
h_q
,
s_q
,
1
,
s_k
,
is_varlen
,
topk
,
have_topk_length
=
have_topk_len
,
...
...
@@ -119,7 +119,7 @@ def gen_testcase() -> List[RawTestParam]:
]
+
[
# Peak perf cases
RawTestParam
(
74
*
2
,
h_q
,
2
,
1
,
32768
,
True
,
topk
=
16384
,
d_qk
=
d_qk
)
for
h_q
in
[
64
,
128
]
for
h_q
in
[
16
,
64
,
128
]
for
d_qk
in
[
512
,
576
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment