Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
0651671f
"vscode:/vscode.git/clone" did not exist on "34489f466e8f6ddf3b7318cd1556a5df8759c005"
Commit
0651671f
authored
Jan 29, 2026
by
zhanghj2
Browse files
sparse decode支持head16
parent
b94fdd0f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
37 additions
and
52 deletions
+37
-52
csrc/api/common.h
csrc/api/common.h
+3
-0
csrc/api/sparse_decode.h
csrc/api/sparse_decode.h
+12
-1
csrc/sm90/decode/sparse_fp8/config.h
csrc/sm90/decode/sparse_fp8/config.h
+3
-49
csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu
...decode/sparse_fp8/instantiations/model1_persistent_h16.cu
+8
-0
csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu
...90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu
+7
-0
setup.py
setup.py
+2
-0
tests/test_flash_mla_sparse_decoding.py
tests/test_flash_mla_sparse_decoding.py
+2
-2
No files found.
csrc/api/common.h
View file @
0651671f
...
...
@@ -57,6 +57,9 @@ inline int int64_stride_to_int(int64_t orig_stride) {
} else if (NUM_HEADS == 64) { \
static constexpr int CONSTEXPR_NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_HEADS <= 16) { \
static constexpr int CONSTEXPR_NAME = 16; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported num_heads_q: ", NUM_HEADS); \
} \
...
...
csrc/api/sparse_decode.h
View file @
0651671f
...
...
@@ -10,6 +10,7 @@
// Feature set of sparse decoding kernels
enum
class
DecodeFeatures
:
int
{
HEAD_16
,
HEAD_64
,
HEAD_128
,
...
...
@@ -41,6 +42,7 @@ public:
class
Decode_Sm90_Impl
:
public
DecodeImplBase
{
DECLARE_SUPPORTED_FEATURES
(
DecodeFeatures
::
HEAD_16
,
DecodeFeatures
::
HEAD_64
,
DecodeFeatures
::
HEAD_128
,
DecodeFeatures
::
HEAD_DIM_512
,
...
...
@@ -56,6 +58,13 @@ class Decode_Sm90_Impl : public DecodeImplBase {
public:
DecodeImplMeta
get_meta
(
int
h_q
,
int
s_q
)
override
{
Arch
arch
=
Arch
();
if
(
h_q
<=
16
)
{
return
{
std
::
max
(
arch
.
num_sms
*
2
/
s_q
/
(
h_q
/
16
),
1
),
5
,
64
};
}
return
{
std
::
max
(
arch
.
num_sms
/
s_q
/
(
h_q
/
64
),
1
),
5
,
...
...
@@ -218,7 +227,9 @@ sparse_attn_decode_interface(
}
std
::
vector
<
DecodeFeatures
>
features
;
if
(
h_q
==
64
)
{
if
(
h_q
<=
16
)
{
features
.
push_back
(
DecodeFeatures
::
HEAD_16
);
}
else
if
(
h_q
==
64
)
{
features
.
push_back
(
DecodeFeatures
::
HEAD_64
);
}
else
if
(
h_q
==
128
)
{
features
.
push_back
(
DecodeFeatures
::
HEAD_128
);
...
...
csrc/sm90/decode/sparse_fp8/config.h
View file @
0651671f
...
...
@@ -100,58 +100,9 @@ struct SharedMemoryPlan {
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
};
// struct {
// cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
// // cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_sum;
// // cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_row_max;
// };
// struct {
// cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
// };
};
// array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// union {
// array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
// array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
// array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
// } u;
// CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
// bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
// float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
// transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};
// template<
// typename Shape_Q, typename TMA_Q
// >
// using TiledMMA_QK = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_QK_rQ = decltype(make_tiled_mma(
// GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
// Declaration of a static, force-inlined device function. It receives the
// decode parameters and scheduler metadata by const reference, plus the index
// of the batch entry to work on.
// NOTE(review): judging by the name, this handles one row block of the
// split-KV sparse FP8 MLA decode -- confirm against the definition.
static __device__ __forceinline__ void compute_attn_1rowblock_splitkv_sparse_mla_fp8(
    const SparseAttnDecodeParams& params,
    const DecodingSchedMeta& sched_meta,
    int batch_idx);
...
...
@@ -163,4 +114,7 @@ static void run(const SparseAttnDecodeParams &params);
};
}
csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu
0 → 100644
View file @
0651671f
#include "../splitkv_mla.cuh"
// Explicit instantiation of run_flash_splitkv_mla_fp8_sparse_kernel for
// ModelType::MODEL1 with an integer template argument of 16, so that this
// translation unit emits the code for that specialization.
// NOTE(review): the 16 is presumably the query-head count (this is the "_h16"
// instantiation file) -- confirm against the template in splitkv_mla.cuh.
namespace sm90::decode::sparse_fp8 {

template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 16>(
    const SparseAttnDecodeParams& params);

}  // namespace sm90::decode::sparse_fp8
csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu
0 → 100644
View file @
0651671f
#include "../splitkv_mla.cuh"
// Explicit instantiation of run_flash_splitkv_mla_fp8_sparse_kernel for
// ModelType::V32 with an integer template argument of 16, so that this
// translation unit emits the code for that specialization.
// NOTE(review): the 16 is presumably the query-head count (this is the "_h16"
// instantiation file) -- confirm against the template in splitkv_mla.cuh.
namespace sm90::decode::sparse_fp8 {

template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 16>(
    const SparseAttnDecodeParams& params);

}  // namespace sm90::decode::sparse_fp8
setup.py
View file @
0651671f
...
...
@@ -58,8 +58,10 @@ ext_modules.append(
"csrc/sm90/decode/dense/instantiations/bf16.cu"
,
# # sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu"
,
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu"
,
...
...
tests/test_flash_mla_sparse_decoding.py
View file @
0651671f
...
...
@@ -27,7 +27,7 @@ def gen_testcase() -> List[RawTestParam]:
for
have_extra_k
in
([
False
,
True
]
if
d_qk
==
512
else
[
False
]):
for
have_extra_topk_len
in
([
False
,
True
]
if
have_extra_k
else
[
False
]):
for
have_topk_len
in
([
False
,
True
]
if
d_qk
==
512
else
[
False
]):
for
h_q
in
[
64
,
128
]:
for
h_q
in
[
16
,
64
,
128
]:
cur_correctness_cases
=
[
RawTestParam
(
b
,
h_q
,
s_q
,
1
,
s_k
,
is_varlen
,
topk
,
have_topk_length
=
have_topk_len
,
...
...
@@ -119,7 +119,7 @@ def gen_testcase() -> List[RawTestParam]:
]
+
[
# Peak perf cases
RawTestParam
(
74
*
2
,
h_q
,
2
,
1
,
32768
,
True
,
topk
=
16384
,
d_qk
=
d_qk
)
for
h_q
in
[
64
,
128
]
for
h_q
in
[
16
,
64
,
128
]
for
d_qk
in
[
512
,
576
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment