Commit 7f55c715 authored by Jiashi Li's avatar Jiashi Li
Browse files

Fix error message

parent e9b67321
...@@ -41,7 +41,7 @@ struct Arch { ...@@ -41,7 +41,7 @@ struct Arch {
} }
}; };
// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. Hopper Dense BF16, Hopper Sparse FP8, etc.) // DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. SM90 Dense BF16, SM90 Sparse FP8, etc.)
struct DecodingAttnImplMeta { struct DecodingAttnImplMeta {
int num_sm_parts; int num_sm_parts;
int fixed_overhead_num_blocks; int fixed_overhead_num_blocks;
...@@ -334,7 +334,7 @@ fwd_kvcache_mla( ...@@ -334,7 +334,7 @@ fwd_kvcache_mla(
TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90"); TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90");
sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream); sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
} else { } else {
TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90"); TORCH_CHECK(false, "Only FP8 kvcahe is supported for sparse MLA on SM90");
} }
} else { } else {
if (is_fp8) { if (is_fp8) {
...@@ -347,7 +347,7 @@ fwd_kvcache_mla( ...@@ -347,7 +347,7 @@ fwd_kvcache_mla(
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream); sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
#endif #endif
} else { } else {
TORCH_CHECK(false, "Unsupported tensor dtype for query"); TORCH_CHECK(false, "Unsupported dtype for dense MLA on SM90");
} }
} }
} }
......
...@@ -949,7 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ...@@ -949,7 +949,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized {
TensorC const& coord, TensorC const& coord,
TensorShape const& tensor_shape) { TensorShape const& tensor_shape) {
//TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version. // TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy( auto copy_op = make_cotiled_copy(
......
...@@ -953,7 +953,8 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { ...@@ -953,7 +953,8 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
TensorR const& regs, TensorR const& regs,
TensorC const& coord, TensorC const& coord,
TensorShape const& tensor_shape) { TensorShape const& tensor_shape) {
//TODO Performance of FlashMLA on hopper is dropped with latest cutlass, so here revert the to the old version.
// TODO: Performance of FlashMLA on sm90 is dropped with latest cutlass, so here revert the to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); // Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy( auto copy_op = make_cotiled_copy(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment