Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
a8393a04
Commit
a8393a04
authored
Feb 22, 2026
by
zhanghj2
Browse files
支持nhead<16
parent
945ced44
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
4 additions
and
4 deletions
+4
-4
csrc/api/sparse_fwd.h
csrc/api/sparse_fwd.h
+1
-1
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
+1
-1
csrc/sm90/prefill/sparse/phase1.cuh
csrc/sm90/prefill/sparse/phase1.cuh
+2
-2
No files found.
csrc/api/sparse_fwd.h
View file @
a8393a04
...
...
@@ -137,7 +137,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
};
std
::
vector
<
FwdFeatures
>
required_features
;
if
(
h_q
=
=
16
)
{
if
(
h_q
<
=
16
)
{
required_features
.
push_back
(
FwdFeatures
::
HEAD_16
);
}
else
if
(
h_q
==
64
)
{
required_features
.
push_back
(
FwdFeatures
::
HEAD_64
);
...
...
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
View file @
a8393a04
...
...
@@ -934,7 +934,7 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa
KU_ASSERT
(
params
.
topk
%
TOPK_BLOCK_SIZE
==
0
);
KU_ASSERT
(
params
.
d_qk
==
HEAD_DIM_K
);
KU_ASSERT
(
params
.
d_v
==
HEAD_DIM_V
);
KU_ASSERT
(
params
.
h_q
%
BLOCK_M
==
0
);
//
KU_ASSERT(params.h_q % BLOCK_M == 0);
if
constexpr
(
MODEL_TYPE
==
ModelType
::
MODEL1
)
{
constexpr
int
BYTES_PER_TOKEN
=
HEAD_DIM_NOPE
+
2
*
HEAD_DIM_ROPE
+
8
;
KU_ASSERT
(
params
.
stride_kv_row
==
BYTES_PER_TOKEN
,
"Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"
);
// Each block must be contiguous
...
...
csrc/sm90/prefill/sparse/phase1.cuh
View file @
a8393a04
...
...
@@ -483,10 +483,10 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_ASSERT
(
params
.
h_kv
==
1
);
KU_ASSERT
(
params
.
topk
%
(
2
*
B_TOPK
)
==
0
);
// To save some boundary checks
KU_ASSERT
(
params
.
topk
>
0
);
KU_ASSERT
(
params
.
h_q
%
B_H
==
0
);
//
KU_ASSERT(params.h_q % B_H == 0);
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
// 做了lds复用
dim3
grid
(
params
.
s_q
,
params
.
h_q
/
B_H
,
1
);
dim3
grid
(
params
.
s_q
,
(
params
.
h_q
+
B_H
-
1
)
/
B_H
,
1
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
KU_CHECK_KERNEL_LAUNCH
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment