Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
7f55c715
Commit
7f55c715
authored
Sep 30, 2025
by
Jiashi Li
Browse files
Fix error message
parent
e9b67321
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
5 deletions
+6
-5
csrc/pybind.cpp
csrc/pybind.cpp
+3
-3
csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
...ense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
+1
-1
csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp
.../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp
+2
-1
No files found.
csrc/pybind.cpp
View file @
7f55c715
...
@@ -41,7 +41,7 @@ struct Arch {
...
@@ -41,7 +41,7 @@ struct Arch {
}
}
};
};
// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e.
Hopper
Dense BF16,
Hopper
Sparse FP8, etc.)
// DecodingAttnImplMeta - A struct to hold metadata for a Decoding Attention Implementation (e.g.
SM90
Dense BF16,
SM90
Sparse FP8, etc.)
struct
DecodingAttnImplMeta
{
struct
DecodingAttnImplMeta
{
int
num_sm_parts
;
int
num_sm_parts
;
int
fixed_overhead_num_blocks
;
int
fixed_overhead_num_blocks
;
...
@@ -334,7 +334,7 @@ fwd_kvcache_mla(
...
@@ -334,7 +334,7 @@ fwd_kvcache_mla(
TORCH_CHECK
(
q_dtype
==
torch
::
kBFloat16
,
"Sparse FP8 MLA only supports BFloat16 on SM90"
);
TORCH_CHECK
(
q_dtype
==
torch
::
kBFloat16
,
"Sparse FP8 MLA only supports BFloat16 on SM90"
);
sm90
::
run_flash_splitkv_mla_fp8_sparse_kernel
(
params
,
stream
);
sm90
::
run_flash_splitkv_mla_fp8_sparse_kernel
(
params
,
stream
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"
Dense FP8 MLA is not supported
on SM90"
);
TORCH_CHECK
(
false
,
"
Only FP8 kvcahe is supported for sparse MLA
on SM90"
);
}
}
}
else
{
}
else
{
if
(
is_fp8
)
{
if
(
is_fp8
)
{
...
@@ -347,7 +347,7 @@ fwd_kvcache_mla(
...
@@ -347,7 +347,7 @@ fwd_kvcache_mla(
sm90
::
run_flash_splitkv_mla_kernel
<
cutlass
::
half_t
>
(
params
,
stream
);
sm90
::
run_flash_splitkv_mla_kernel
<
cutlass
::
half_t
>
(
params
,
stream
);
#endif
#endif
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported
tensor
dtype for
query
"
);
TORCH_CHECK
(
false
,
"Unsupported dtype for
dense MLA on SM90
"
);
}
}
}
}
}
}
...
...
csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp
View file @
7f55c715
...
@@ -949,7 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
...
@@ -949,7 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC
const
&
coord
,
TensorC
const
&
coord
,
TensorShape
const
&
tensor_shape
)
{
TensorShape
const
&
tensor_shape
)
{
//TODO Performance of FlashMLA on
hopper
is dropped with latest cutlass, so here revert the to the old version.
//
TODO
:
Performance of FlashMLA on
sm90
has dropped with the latest cutlass, so here we revert to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto
copy_op
=
make_cotiled_copy
(
auto
copy_op
=
make_cotiled_copy
(
...
...
csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp
View file @
7f55c715
...
@@ -953,7 +953,8 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
...
@@ -953,7 +953,8 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
TensorR
const
&
regs
,
TensorR
const
&
regs
,
TensorC
const
&
coord
,
TensorC
const
&
coord
,
TensorShape
const
&
tensor_shape
)
{
TensorShape
const
&
tensor_shape
)
{
//TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version.
// TODO: Performance of FlashMLA on sm90 has dropped with the latest cutlass, so here we revert to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto
copy_op
=
make_cotiled_copy
(
auto
copy_op
=
make_cotiled_copy
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment