change / sglang · Commits

Commit aa46ed34 (unverified)
Remove 200us slow concat kernel (part 1: kernel) (#7145)
Authored Jun 13, 2025 by fzyzcjy; committed by GitHub on Jun 13, 2025
Parent: 2f4ec752

Showing 6 changed files with 79 additions and 48 deletions (+79, -48)
  sgl-kernel/benchmark/bench_cutlass_mla.py          +16   -5
  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu    +29  -20
  sgl-kernel/csrc/common_extension.cc                 +1   -1
  sgl-kernel/include/sgl_kernel_ops.h                 +2   -1
  sgl-kernel/python/sgl_kernel/attention.py          +26  -20
  sgl-kernel/tests/test_cutlass_mla.py                +5   -1
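Summary of the interface change: cutlass_mla_decode no longer takes one fused query tensor q_nope_and_q_pe; the latent ("nope") and RoPE parts are passed as two separate tensors, so callers no longer concatenate them on every decode step (the ~200us concat the commit title refers to). A minimal before/after sketch; the sizes are illustrative and the kv_cache / page_table / workspace setup is omitted:

    import torch
    from sgl_kernel.attention import cutlass_mla_decode

    bs, h_q, d_latent, d_rope = 4, 128, 512, 64  # illustrative decode-step shapes
    q_nope = torch.randn(bs, h_q, d_latent, dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn(bs, h_q, d_rope, dtype=torch.bfloat16, device="cuda")

    # Before this commit: the caller had to materialize the fused tensor first.
    # q = torch.cat([q_nope, q_pe], dim=-1)
    # out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace, num_kv_splits)

    # After this commit: the two parts are passed directly, no concat needed.
    # out = cutlass_mla_decode(q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits)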
sgl-kernel/benchmark/bench_cutlass_mla.py  (view file @ aa46ed34)
@@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range))
 )
 def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     d = 576
+    dn = 64
     dv = 512
     h_q_map = {
@@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
-    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
+    qn = (
+        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda") * 100.0
+    )
+    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
     block_table = torch.randint(
         0,
         batch_size * block_num,
@@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     quantiles = [0.5, 0.2, 0.8]
     ms, min_ms, max_ms = triton.testing.do_bench(
         lambda: cutlass_mla_decode(
-            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+            qn.transpose(0, 1),
+            qr,
+            kv_cache,
+            seq_lens,
+            block_table,
+            workspace,
+            num_kv_splits,
         ),
         quantiles=quantiles,
     )
+    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
     gbps = (
         lambda ms: (
-            q.numel() * q.element_size()
-            + q.numel() * q.element_size() * dv / d
+            q_size
+            + q_size * dv / d
             + kv_cache.numel() * kv_cache.element_size()
         )
         * 1e-9
         / (ms * 1e-3)
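The "200us slow concat" in the commit title refers to the torch.cat that callers previously needed to build the fused query tensor. A quick, purely illustrative way to measure that overhead with the same do_bench helper used in this benchmark (not part of the repository; numbers vary by GPU and shape):

    import torch
    import triton

    bs, h_q, d_latent, d_rope = 64, 128, 512, 64  # illustrative decode-step shapes
    q_nope = torch.randn(bs, h_q, d_latent, dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn(bs, h_q, d_rope, dtype=torch.bfloat16, device="cuda")

    # Cost of materializing the fused query tensor once per decode step.
    ms = triton.testing.do_bench(lambda: torch.cat([q_nope, q_pe], dim=-1))
    print(f"torch.cat([q_nope, q_pe]) takes {ms * 1e3:.1f} us per call")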
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu  (view file @ aa46ed34)
@@ -22,6 +22,7 @@ limitations under the License.
 #include <torch/all.h>
 #include <cute/tensor.hpp>
+#include <iostream>
 #include "cutlass_sm100_mla/device/sm100_mla.hpp"
 #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
@@ -30,7 +31,8 @@ limitations under the License.
 #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
@@ -91,16 +93,17 @@ struct MlaSm100 {
 template <typename T>
 typename T::Fmha::Arguments args_from_options(
     at::Tensor const& out,
-    at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& q_nope,
+    at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = q_nope_and_q_pe.device().index();
+  hw_info.device_id = q_nope.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-  int batches = q_nope_and_q_pe.sizes()[0];
+  int batches = q_nope.sizes()[0];
   int page_count_per_seq = page_table.sizes()[1];
   int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
   int page_size = kv_c_and_k_pe_cache.sizes()[1];
@@ -122,8 +125,11 @@ typename T::Fmha::Arguments args_from_options(
   using StrideO = typename T::StrideO;
   using StrideLSE = typename T::StrideLSE;
-  StrideQ stride_Q = cute::make_tuple(
-      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
+  StrideQ stride_Q_nope = cute::make_tuple(
+      static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
+  StrideQ stride_Q_pe = cute::make_tuple(
+      static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
   StrideK stride_C = cute::make_tuple(
       static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
   StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
@@ -133,15 +139,16 @@ typename T::Fmha::Arguments args_from_options(
   using Element = typename T::Element;
   using ElementOut = typename T::ElementOut;
   using ElementAcc = typename T::ElementAcc;
-  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
+  auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
+  auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
   auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
   typename T::Fmha::Arguments arguments{
       problem_shape,
       {scale,
-       Q_ptr,
-       stride_Q,
-       Q_ptr + D_latent,
-       stride_Q,
+       Q_nope_ptr,
+       stride_Q_nope,
+       Q_pe_ptr,
+       stride_Q_pe,
        C_ptr,
        stride_C,
        C_ptr + D_latent,
@@ -170,7 +177,8 @@ typename T::Fmha::Arguments args_from_options(
 template <typename Element, bool IsPaged128, typename PersistenceOption>
 void runMla(
     at::Tensor const& out,
-    at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& q_nope,
+    at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
@@ -179,7 +187,7 @@ void runMla(
     cudaStream_t stream) {
   using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
+  auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
   CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -201,15 +209,16 @@ void runMla(
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
     int64_t num_kv_splits) {
-  auto in_dtype = q_nope_and_q_pe.dtype();
-  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
+  auto in_dtype = q_nope.dtype();
+  at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
   const int page_size = kv_c_and_k_pe_cache.sizes()[1];

   // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
@@ -219,13 +228,13 @@ void cutlass_mla_decode(
   DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
     if (in_dtype == at::ScalarType::Half) {
       runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::BFloat16) {
       runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
       runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else {
       TORCH_CHECK(false, "Unsupported input data type of MLA");
     }
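One detail worth noting in the new args_from_options: stride_Q_nope and stride_Q_pe come from the tensors' own strides rather than being derived from the fused (D_latent + D_rope) layout, so q_nope may be a non-contiguous view (the updated test passes a transposed buffer). A standalone, purely illustrative check of the strides involved:

    import torch

    # The test builds q_nope with storage shape (h_q, bs, dv) and then views it as (bs, h_q, dv),
    # so stride(0) is smaller than stride(1).
    h_q, bs, dv, d_rope = 128, 4, 512, 64
    q_nope = torch.empty((h_q, bs, dv), dtype=torch.bfloat16, device="cuda").transpose(0, 1)
    q_pe = torch.empty((bs, h_q, d_rope), dtype=torch.bfloat16, device="cuda")

    print(q_nope.shape, q_nope.stride())  # torch.Size([4, 128, 512]) (512, 2048, 1)
    print(q_pe.shape, q_pe.stride())      # torch.Size([4, 128, 64]) (8192, 64, 1)

    # These per-tensor strides are what the kernel now forwards as stride_Q_nope / stride_Q_pe
    # instead of assuming one fused (D_latent + D_rope)-wide row layout.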
sgl-kernel/csrc/common_extension.cc  (view file @ aa46ed34)
@@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
-      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
      "page_table, Tensor! workspace, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
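The schema registered here is what the Python wrapper dispatches to via torch.ops. A small sketch for inspecting it, assuming that importing the sgl_kernel package is what registers this library fragment:

    import torch
    import sgl_kernel  # noqa: F401  (assumed to register the sgl_kernel torch library on import)

    op = torch.ops.sgl_kernel.cutlass_mla_decode.default
    # Expected to show the updated signature with separate q_nope and q_pe arguments:
    # cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, ...)
    print(op._schema if hasattr(op, "_schema") else op)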
sgl-kernel/include/sgl_kernel_ops.h  (view file @ aa46ed34)
@@ -105,7 +105,8 @@ void merge_state_v2(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);

 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
sgl-kernel/python/sgl_kernel/attention.py  (view file @ aa46ed34)
@@ -52,34 +52,42 @@ def merge_state_v2(
 def cutlass_mla_decode(
-    q_nope_and_q_pe: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
     num_kv_splits: int = -1,
 ) -> torch.Tensor:
-    assert (
-        q_nope_and_q_pe.ndim == 3
-    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
+    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
+    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
     assert (
         kv_c_and_k_pe_cache.ndim == 3
     ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"

-    B_q, H, D_q = q_nope_and_q_pe.shape
+    B_q, H, D_q_nope = q_nope.shape
+    B_q_2, H_2, D_q_pe = q_pe.shape
+    assert (B_q == B_q_2) and (H == H_2)
     _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

     D_latent = 512
     D_rope = 64
-    assert D_q == D_ckv and D_q == D_latent + D_rope, (
-        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
-        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
-    )
+    assert D_q_nope == D_latent
+    assert D_q_pe == D_rope
+    assert D_ckv == D_latent + D_rope

     MAX_HEADS = 128
     assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
     if H < MAX_HEADS:
-        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
-        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
-        q_nope_and_q_pe = q_nope_and_q_pe_padded
+        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
+        q_nope_padded[:, :H] = q_nope
+        q_nope = q_nope_padded
+
+        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
+        q_pe_padded[:, :H] = q_pe
+        q_pe = q_pe_padded

     assert len(page_table.shape) == 2
     B_block_table, block_num = page_table.shape
@@ -88,14 +96,11 @@ def cutlass_mla_decode(
     assert block_num % (128 / PAGE_SIZE) == 0

     # TODO(kaixih@nvidia): support fp8
-    assert q_nope_and_q_pe.dtype in (
+    assert q_nope.dtype in (
         torch.float16,
         torch.bfloat16,
-    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
-    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
-        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
-        f"but got {kv_c_and_k_pe_cache.dtype}."
-    )
+    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
+    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype

     assert (
         seq_lens.dtype == torch.int32
     ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
@@ -103,11 +108,12 @@ def cutlass_mla_decode(
         page_table.dtype == torch.int32
     ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

-    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
+    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))

     torch.ops.sgl_kernel.cutlass_mla_decode.default(
         out,
-        q_nope_and_q_pe,
+        q_nope,
+        q_pe,
         kv_c_and_k_pe_cache,
         seq_lens,
         page_table,
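Putting the new wrapper signature together, a caller-side sketch. Shapes follow the asserts above (D_latent = 512, D_rope = 64, cache pages of (PAGE_SIZE, 576)); the workspace size below is a placeholder, since the cutlass_mla_get_workspace_size arguments are outside this diff and real code should query that helper instead:

    import torch
    from sgl_kernel.attention import cutlass_mla_decode

    bs, h_q, d_latent, d_rope = 2, 128, 512, 64
    page_size, pages_per_seq = 128, 16

    q_nope = torch.randn(bs, h_q, d_latent, dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn(bs, h_q, d_rope, dtype=torch.bfloat16, device="cuda")
    kv_c_and_k_pe_cache = torch.randn(
        bs * pages_per_seq, page_size, d_latent + d_rope, dtype=torch.bfloat16, device="cuda"
    )
    seq_lens = torch.full((bs,), 1024, dtype=torch.int32, device="cuda")
    page_table = torch.randint(
        0, bs * pages_per_seq, (bs, pages_per_seq), dtype=torch.int32, device="cuda"
    )
    # Placeholder size; use cutlass_mla_get_workspace_size in real code.
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    out = cutlass_mla_decode(
        q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits=1
    )
    print(out.shape)  # (2, 128, 512) when h_q == MAX_HEADS == 128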
sgl-kernel/tests/test_cutlass_mla.py  (view file @ aa46ed34)
@@ -86,10 +86,14 @@ def test_cutlass_mla_decode(
     )
     workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

+    q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
+    q_nope.copy_(q[:, :, :dv])
+    q_pe = q[:, :, dv:].clone()
+
     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
     out = cutlass_mla_decode(
-        q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
     )

     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
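The test still builds the fused q for the reference path (ref_mla consumes it) and derives the two kernel inputs from it, with q_nope deliberately stored transposed so a non-contiguous layout is exercised. A standalone illustration (outside the repository) that this split is exact:

    import torch

    # Mirrors how the test derives the two query parts from the fused reference tensor q:
    # the first dv channels are the latent ("nope") part, the remainder is the RoPE part.
    bs, h_q, d, dv = 2, 128, 576, 512
    q = torch.randn(bs, h_q, d, dtype=torch.bfloat16, device="cuda")

    q_nope = torch.empty((h_q, bs, dv), dtype=q.dtype, device=q.device).transpose(0, 1)
    q_nope.copy_(q[:, :, :dv])   # non-contiguous destination, exercising the stride path
    q_pe = q[:, :, dv:].clone()

    # Reassembling the parts recovers the original fused tensor exactly.
    assert torch.equal(torch.cat([q_nope, q_pe], dim=-1), q)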