Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
b94fdd0f
Commit
b94fdd0f
authored
Jan 29, 2026
by
zhanghj2
Browse files
prefill支持head 16
parent
38421051
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
55 deletions
+9
-55
csrc/api/sparse_fwd.h
csrc/api/sparse_fwd.h
+5
-52
tests/test_flash_mla_sparse_prefill.py
tests/test_flash_mla_sparse_prefill.py
+4
-3
No files found.
csrc/api/sparse_fwd.h
View file @
b94fdd0f
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
enum
class
FwdFeatures
:
int
{
enum
class
FwdFeatures
:
int
{
HEAD_16
,
HEAD_64
,
HEAD_64
,
HEAD_128
,
HEAD_128
,
...
@@ -26,6 +27,7 @@ class FwdImplBase : public ImplBase<
...
@@ -26,6 +27,7 @@ class FwdImplBase : public ImplBase<
class
Fwd_Sm90_Impl
:
public
FwdImplBase
{
class
Fwd_Sm90_Impl
:
public
FwdImplBase
{
DECLARE_SUPPORTED_FEATURES
(
DECLARE_SUPPORTED_FEATURES
(
FwdFeatures
::
HEAD_16
,
FwdFeatures
::
HEAD_64
,
FwdFeatures
::
HEAD_64
,
FwdFeatures
::
HEAD_128
,
FwdFeatures
::
HEAD_128
,
FwdFeatures
::
HEAD_DIM_512
,
FwdFeatures
::
HEAD_DIM_512
,
...
@@ -45,57 +47,6 @@ protected:
...
@@ -45,57 +47,6 @@ protected:
}
}
};
};
class
Fwd_Sm100_Head64_Impl
:
public
FwdImplBase
{
DECLARE_SUPPORTED_FEATURES
(
FwdFeatures
::
HEAD_64
,
FwdFeatures
::
HEAD_DIM_512
,
FwdFeatures
::
HEAD_DIM_576
,
FwdFeatures
::
ATTN_SINK
,
FwdFeatures
::
SINK_LSE
,
FwdFeatures
::
TOPK_LENGTH
)
protected:
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
DISPATCH_HEAD_DIM
(
params
.
d_qk
,
HEAD_DIM_QK
,
[
&
]()
{
// sm100::fwd::head64::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class
Fwd_Sm100_Head128_Impl
:
public
FwdImplBase
{
DECLARE_SUPPORTED_FEATURES
(
FwdFeatures
::
HEAD_128
,
FwdFeatures
::
HEAD_DIM_512
,
FwdFeatures
::
HEAD_DIM_576
,
FwdFeatures
::
ATTN_SINK
,
FwdFeatures
::
SINK_LSE
,
FwdFeatures
::
TOPK_LENGTH
)
protected:
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
DISPATCH_HEAD_DIM
(
params
.
d_qk
,
HEAD_DIM_QK
,
[
&
]()
{
// sm100::fwd::head128::run_fwd_phase1_kernel<HEAD_DIM_QK>(params);
});
}
};
class
Fwd_Sm100_Head128_Small_TopK_Impl
:
public
FwdImplBase
{
DECLARE_SUPPORTED_FEATURES
(
FwdFeatures
::
HEAD_128
,
FwdFeatures
::
HEAD_DIM_512
,
FwdFeatures
::
ATTN_SINK
,
FwdFeatures
::
SINK_LSE
,
FwdFeatures
::
TOPK_LENGTH
)
protected:
void
run_
(
const
SparseAttnFwdParams
&
params
,
const
std
::
vector
<
FeatureT
>
&
required_features
)
override
{
// sm100::fwd_for_small_topk::head128::run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(params);
}
};
static
std
::
vector
<
at
::
Tensor
>
sparse_attn_prefill_interface
(
static
std
::
vector
<
at
::
Tensor
>
sparse_attn_prefill_interface
(
const
at
::
Tensor
&
q
,
const
at
::
Tensor
&
q
,
const
at
::
Tensor
&
kv
,
const
at
::
Tensor
&
kv
,
...
@@ -187,7 +138,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
...
@@ -187,7 +138,9 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
};
};
std
::
vector
<
FwdFeatures
>
required_features
;
std
::
vector
<
FwdFeatures
>
required_features
;
if
(
h_q
==
64
)
{
if
(
h_q
==
16
)
{
required_features
.
push_back
(
FwdFeatures
::
HEAD_16
);
}
else
if
(
h_q
==
64
)
{
required_features
.
push_back
(
FwdFeatures
::
HEAD_64
);
required_features
.
push_back
(
FwdFeatures
::
HEAD_64
);
}
else
if
(
h_q
==
128
)
{
}
else
if
(
h_q
==
128
)
{
required_features
.
push_back
(
FwdFeatures
::
HEAD_128
);
required_features
.
push_back
(
FwdFeatures
::
HEAD_128
);
...
...
tests/test_flash_mla_sparse_prefill.py
View file @
b94fdd0f
...
@@ -64,7 +64,7 @@ if __name__ == '__main__':
...
@@ -64,7 +64,7 @@ if __name__ == '__main__':
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
num_runs
=
0
,
d_qk
=
d_qk
)
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
num_runs
=
0
,
d_qk
=
d_qk
)
for
d_qk
in
[
512
,
576
]
for
d_qk
in
[
512
,
576
]
for
h_q
in
[
for
h_q
in
[
128
,
64
16
,
128
,
64
]
]
for
s_kv
,
topk
in
[
for
s_kv
,
topk
in
[
# Regular shapes
# Regular shapes
...
@@ -92,7 +92,7 @@ if __name__ == '__main__':
...
@@ -92,7 +92,7 @@ if __name__ == '__main__':
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
num_runs
=
0
,
have_attn_sink
=
have_attn_sink
,
have_topk_length
=
have_topk_length
,
d_qk
=
d_qk
)
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
num_runs
=
0
,
have_attn_sink
=
have_attn_sink
,
have_topk_length
=
have_topk_length
,
d_qk
=
d_qk
)
for
d_qk
in
[
512
,
576
]
for
d_qk
in
[
512
,
576
]
for
h_q
in
[
for
h_q
in
[
128
,
64
16
,
128
,
64
]
]
for
s_kv
,
topk
in
[
for
s_kv
,
topk
in
[
(
592
,
128
),
(
592
,
128
),
...
@@ -114,7 +114,7 @@ if __name__ == '__main__':
...
@@ -114,7 +114,7 @@ if __name__ == '__main__':
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
is_all_indices_invalid
=
True
,
num_runs
=
0
,
have_attn_sink
=
True
,
have_topk_length
=
True
,
d_qk
=
d_qk
)
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
is_all_indices_invalid
=
True
,
num_runs
=
0
,
have_attn_sink
=
True
,
have_topk_length
=
True
,
d_qk
=
d_qk
)
for
d_qk
in
[
512
,
576
]
for
d_qk
in
[
512
,
576
]
for
h_q
in
[
for
h_q
in
[
128
,
64
16
,
128
,
64
]
]
for
s_q
,
s_kv
,
topk
in
[
for
s_q
,
s_kv
,
topk
in
[
(
1
,
128
,
128
),
(
1
,
128
,
128
),
...
@@ -150,6 +150,7 @@ if __name__ == '__main__':
...
@@ -150,6 +150,7 @@ if __name__ == '__main__':
(
512
,
64
,
512
,
[
8192
,
32768
,
49152
,
65536
]),
(
512
,
64
,
512
,
[
8192
,
32768
,
49152
,
65536
]),
# MODEL1 CONFIG2
# MODEL1 CONFIG2
(
512
,
128
,
1024
,
[
8192
,
32768
,
49152
,
65536
]),
(
512
,
128
,
1024
,
[
8192
,
32768
,
49152
,
65536
]),
(
512
,
16
,
1024
,
[
8192
,
32768
,
49152
,
65536
]),
]
]
performance_cases
=
[
performance_cases
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment