change / sglang · Commits

Commit ab1a4fa5 (Unverified)
Authored Jun 15, 2025 by JieXin Liang · Committed by GitHub, Jun 14, 2025
Parent: ed54bf9d

[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

Showing 7 changed files with 29 additions and 17 deletions (+29 -17)
Changed files:

  python/sglang/srt/layers/attention/cutlass_mla_backend.py   +3 -2
  sgl-kernel/benchmark/bench_cutlass_mla.py                    +1 -0
  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu              +10 -9
  sgl-kernel/csrc/common_extension.cc                          +1 -1
  sgl-kernel/include/sgl_kernel_ops.h                          +6 -2
  sgl-kernel/python/sgl_kernel/attention.py                    +7 -2
  sgl-kernel/tests/test_cutlass_mla.py                         +1 -1
python/sglang/srt/layers/attention/cutlass_mla_backend.py

```diff
@@ -108,7 +108,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             PAGE_SIZE,
         )
         workspace_size = cutlass_mla_get_workspace_size(
-            max_seqlen_pad * PAGE_SIZE, bs
+            max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
         )
         workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
@@ -138,7 +138,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         cuda_graph_kv_indices = block_kv_indices
         workspace_size = cutlass_mla_get_workspace_size(
-            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
         )
         self.cuda_graph_mla_workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
@@ -280,6 +280,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
             workspace=self.forward_metadata.workspace,
+            num_kv_splits=1,
         )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
```
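Reviewer note on the `num_kv_splits=1` pinning above: CUDA graph replay reuses the exact buffers recorded at capture time, so the workspace sized in the cuda-graph path must match what the kernel expects on every replay. With the previous auto value (`-1`), the split-KV heuristic could choose a different split count at replay than at capture and disagree with the captured workspace layout. A minimal sketch of the invariant, assuming the same `sgl_kernel` helper the backend calls (capacities are illustrative, not from this commit):

```python
import torch
from sgl_kernel import cutlass_mla_get_workspace_size

PAGE_SIZE = 128
max_bs, max_pages_per_seq = 64, 32  # capture-time capacities (illustrative)

# Size the workspace once, at capture time, with the same num_kv_splits the
# decode call will use at replay time; a graph-captured kernel cannot
# re-allocate, so the buffer must already fit the worst case.
ws_bytes = cutlass_mla_get_workspace_size(
    max_pages_per_seq * PAGE_SIZE, max_bs, num_kv_splits=1
)
workspace = torch.empty(ws_bytes, device="cuda", dtype=torch.uint8)
```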
sgl-kernel/benchmark/bench_cutlass_mla.py

```diff
@@ -95,6 +95,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
                 seq_lens,
                 block_table,
                 workspace,
+                1.44,
                 num_kv_splits,
             ),
             quantiles=quantiles,
```
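Note: the bare `1.44` is the new `sm_scale` positional argument, slotted between `workspace` and `num_kv_splits` to match the updated signature; the benchmark uses a fixed constant rather than a model-derived `1/sqrt(head_dim)` scale, which is fine for timing purposes.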
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

```diff
@@ -36,7 +36,8 @@ void cutlass_mla_decode(
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
-    torch::Tensor const& workspace) {
+    torch::Tensor const& workspace,
+    int64_t num_kv_splits) {
   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
 }
 int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
@@ -98,6 +99,7 @@ typename T::Fmha::Arguments args_from_options(
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
+    double sm_scale,
     int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = q_nope.device().index();
@@ -115,10 +117,7 @@ typename T::Fmha::Arguments args_from_options(
   auto [H, K, D, B] = problem_shape;
   auto [D_latent, D_rope] = D;
-  // the scale is based on the non-absorbed sizes, change as appropriate
-  // we can't determine this parameter from the info we have, it's an input
-  int D_non_latent = 128;
-  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));
+  float scale = float(sm_scale);
   using StrideQ = typename T::StrideQ;
   using StrideK = typename T::StrideK;
```
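This hunk is the heart of the `sm_scale` change: the kernel previously guessed the softmax scale from a hardcoded non-latent width of 128, which its own comment admitted could not be determined from the available shapes; the scale is now an explicit input. A sketch of the equivalence under the old assumption (values as in the deleted code; `D_rope = 64` is the rope width in the DeepSeek-style MLA layout). The remaining hunks of this file just thread the parameter through:

```python
import math

D_rope = 64          # rope dimension (assumed, DeepSeek-style MLA)
D_non_latent = 128   # the width the old kernel hardcoded

# What the deleted lines computed:
old_scale = 1.0 / math.sqrt(1.0 * (D_non_latent + D_rope))
# What a caller should now pass to reproduce the old behavior exactly:
sm_scale = 1.0 / math.sqrt(D_non_latent + D_rope)
assert old_scale == sm_scale  # identical value; the point is the caller controls it
```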
```diff
@@ -183,11 +182,12 @@ void runMla(
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     at::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits,
     cudaStream_t stream) {
   using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
   auto arguments =
-      args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
+      args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
   CUTLASS_CHECK(fmha.can_implement(arguments));
```
```diff
@@ -215,6 +215,7 @@ void cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
+    double sm_scale,
     int64_t num_kv_splits) {
   auto in_dtype = q_nope.dtype();
   at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
@@ -228,13 +229,13 @@ void cutlass_mla_decode(
   DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
     if (in_dtype == at::ScalarType::Half) {
       runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::BFloat16) {
       runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
       runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
     } else {
       TORCH_CHECK(false, "Unsupported input data type of MLA");
     }
```
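For orientation, `DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, ...)` turns the runtime condition into a compile-time bool, so the persistent scheduler variant is instantiated whenever KV is not manually split. A Python mirror of the selection logic (hypothetical helper; the authoritative dispatch is the C++ above):

```python
def select_mla_kernel(in_dtype: str, num_kv_splits: int) -> str:
    """Mirror of the C++ dispatch: dtype picks the CUTLASS element type,
    and num_kv_splits <= 1 selects the persistent (non-split-KV) variant."""
    elements = {
        "half": "cutlass::half_t",
        "bfloat16": "cutlass::bfloat16_t",
        "float8_e4m3fn": "cutlass::float_e4m3_t",
    }
    if in_dtype not in elements:
        raise TypeError("Unsupported input data type of MLA")
    not_manual_split_kv = num_kv_splits <= 1
    persistence = f"IsPersistent<{str(not_manual_split_kv).lower()}>"
    return f"runMla<{elements[in_dtype]}, IsPaged128, {persistence}>"

assert select_mla_kernel("bfloat16", 1).endswith("IsPersistent<true>>")
```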
sgl-kernel/csrc/common_extension.cc

```diff
@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
-      "page_table, Tensor! workspace, int num_kv_splits) -> ()");
+      "page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
```
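A schema-language `float` is a Python `float` and a C++ `double`, which is why the kernel entry point takes `double sm_scale`; a mismatch here would fail at registration or call time. Once `import sgl_kernel` has run the fragment, the raw op and its schema are inspectable (a sketch; `.default._schema` is standard but technically private PyTorch API):

```python
import torch
import sgl_kernel  # noqa: F401  (importing runs the TORCH_LIBRARY_FRAGMENT registration)

op = torch.ops.sgl_kernel.cutlass_mla_decode
print(op.default._schema)
# ... Tensor! workspace, float sm_scale, int num_kv_splits) -> ()
```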
sgl-kernel/include/sgl_kernel_ops.h (file mode changed 100755 → 100644)

```diff
@@ -111,9 +111,13 @@ void cutlass_mla_decode(
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
-    int64_t num_kv_splits = -1);
+    double sm_scale,
+    int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
 int64_t cutlass_mla_get_workspace_size(
-    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
+    int64_t max_seq_len,
+    int64_t num_batches,
+    int64_t sm_count = 0,
+    int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
 /*
  * From csrc/elementwise
  */
```
sgl-kernel/python/sgl_kernel/attention.py

```diff
@@ -58,7 +58,8 @@ def cutlass_mla_decode(
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
-    num_kv_splits: int = -1,
+    sm_scale: float,
+    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
 ) -> torch.Tensor:
     assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
     assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
@@ -118,13 +119,17 @@ def cutlass_mla_decode(
         seq_lens,
         page_table,
         workspace,
+        sm_scale,
         num_kv_splits,
     )
     return out[:, :H].contiguous()


 def cutlass_mla_get_workspace_size(
-    max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
+    max_seq_len: int,
+    num_batches: int,
+    sm_count: int = 0,
+    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
 ) -> int:
     assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
     assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
```
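With the wrapper updated, an end-to-end decode call threads `sm_scale` through explicitly and can rely on the CUDA-graph-safe `num_kv_splits=1` default. A minimal sketch, assuming the DeepSeek-style layout exercised by the test below (512 latent + 64 rope dims per cache entry, 128-token pages) and an SM100-class GPU with CUDA >= 12.4; all shapes are illustrative, not mandated by this commit:

```python
import math
import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs, h_q = 4, 128                  # batch size, query heads (illustrative)
d_latent, d_rope = 512, 64        # cache entry = latent + rope dims
page_size, seq_len = 128, 1024
pages_per_seq = seq_len // page_size

q_nope = torch.randn(bs, h_q, d_latent, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(bs, h_q, d_rope, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(
    bs * pages_per_seq, page_size, d_latent + d_rope,
    dtype=torch.bfloat16, device="cuda",
)
seq_lens = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
page_table = torch.arange(
    bs * pages_per_seq, dtype=torch.int32, device="cuda"
).view(bs, pages_per_seq)

ws_bytes = cutlass_mla_get_workspace_size(seq_len, bs, num_kv_splits=1)
workspace = torch.empty(ws_bytes, device="cuda", dtype=torch.uint8)

sm_scale = 1.0 / math.sqrt(128 + d_rope)  # non-absorbed head dim + rope dim
out = cutlass_mla_decode(
    q_nope, q_pe, kv_cache, seq_lens, page_table, workspace, sm_scale
)  # num_kv_splits now defaults to 1
print(out.shape)  # torch.Size([4, 128, 512])
```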
sgl-kernel/tests/test_cutlass_mla.py

```diff
@@ -93,7 +93,7 @@ def test_cutlass_mla_decode(
     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
     out = cutlass_mla_decode(
-        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
     )
     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
```