Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4c22ebe2
Unverified
Commit
4c22ebe2
authored
Sep 06, 2025
by
hlu1
Committed by
GitHub
Sep 06, 2025
Browse files
Disable kernel cutlass_mla_decode on SM103 (#10058)
Signed-off-by:
Hao Lu
<
14827759+hlu1@users.noreply.github.com
>
parent
a5a03209
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
2 deletions
+8
-2
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
+5
-0
sgl-kernel/tests/test_cutlass_mla.py
sgl-kernel/tests/test_cutlass_mla.py
+3
-2
No files found.
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
View file @
4c22ebe2
...
...
@@ -26,6 +26,7 @@ limitations under the License.
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
#include "utils.h"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
...
...
@@ -217,6 +218,10 @@ void cutlass_mla_decode(
torch
::
Tensor
const
&
workspace
,
double
sm_scale
,
int64_t
num_kv_splits
)
{
auto
sm_version
=
getSMVersion
();
// On SM103a, half of the accuracy tests are failing.
TORCH_CHECK
(
sm_version
==
100
,
"cutlass_mla_decode is only supported on compute capability 10.0, but found sm version "
,
sm_version
);
auto
in_dtype
=
q_nope
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
q_nope
.
get_device
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
q_nope
.
get_device
());
...
...
sgl-kernel/tests/test_cutlass_mla.py
View file @
4c22ebe2
...
...
@@ -4,9 +4,10 @@ import torch.nn.functional as F
from
sgl_kernel
import
cutlass_mla_decode
,
cutlass_mla_get_workspace_size
from
torch
import
Tensor
if
torch
.
cuda
.
get_device_capability
()
<
(
10
,
0
):
# Disable tests on SM103 until the accuracy issues are fixed.
if
torch
.
cuda
.
get_device_capability
()
!=
(
10
,
0
):
pytest
.
skip
(
reason
=
"Cutlass MLA Requires compute capability of 10
or above
."
,
reason
=
"Cutlass MLA Requires compute capability of 10."
,
allow_module_level
=
True
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment