sglang / Commits / aa46ed34

Remove 200us slow concat kernel (part 1: kernel) (#7145)

Unverified commit aa46ed34, authored Jun 13, 2025 by fzyzcjy, committed by GitHub. Parent: 2f4ec752

Showing 6 changed files with 79 additions and 48 deletions (+79 -48):
  sgl-kernel/benchmark/bench_cutlass_mla.py        +16 -5
  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu  +29 -20
  sgl-kernel/csrc/common_extension.cc              +1  -1
  sgl-kernel/include/sgl_kernel_ops.h              +2  -1
  sgl-kernel/python/sgl_kernel/attention.py        +26 -20
  sgl-kernel/tests/test_cutlass_mla.py             +5  -1
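The thread through all six files is one API change: the fused q_nope_and_q_pe query tensor is split into separate q_nope and q_pe arguments, so callers no longer have to materialize the fused tensor with the roughly 200us concat the commit title refers to (the caller-side removal is part 2). A minimal caller-side sketch of the difference, hypothetical code rather than anything in this commit, with the import path as in the repo's benchmark and the other arguments assumed set up as in the test below:

import torch
from sgl_kernel import cutlass_mla_decode

B, H, D_latent, D_rope = 4, 128, 512, 64
q_nope = torch.randn(B, H, D_latent, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(B, H, D_rope, dtype=torch.bfloat16, device="cuda")

# Old signature: a fused (B, H, D_latent + D_rope) tensor had to be built
# first, typically via a concat on the hot path:
#   q = torch.cat([q_nope, q_pe], dim=-1)
#   out = cutlass_mla_decode(q, kv_cache, seq_lens, page_table, workspace, num_kv_splits)

# New signature: both tensors are passed as-is; the kernel receives a
# separate pointer and stride triple for each, so no copy is needed:
#   out = cutlass_mla_decode(q_nope, q_pe, kv_cache, seq_lens, page_table, workspace, num_kv_splits)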
sgl-kernel/benchmark/bench_cutlass_mla.py

@@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range))
     )
 )
 def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     d = 576
+    dn = 64
     dv = 512
     h_q_map = {
@@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

-    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
+    qn = (
+        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
+        * 100.0
+    )
+    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
     block_table = torch.randint(
         0,
         batch_size * block_num,
@@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     quantiles = [0.5, 0.2, 0.8]
     ms, min_ms, max_ms = triton.testing.do_bench(
         lambda: cutlass_mla_decode(
-            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+            qn.transpose(0, 1),
+            qr,
+            kv_cache,
+            seq_lens,
+            block_table,
+            workspace,
+            num_kv_splits,
         ),
         quantiles=quantiles,
     )
+    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
     gbps = (
         lambda ms: (
-            q.numel() * q.element_size()
-            + q.numel() * q.element_size() * dv / d
+            q_size
+            + q_size * dv / d
             + kv_cache.numel() * kv_cache.element_size()
         )
         * 1e-9
         / (ms * 1e-3)
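The benchmark's bandwidth model changes accordingly: query bytes are summed over the two tensors (q_size), output bytes are approximated as q_size * dv / d (only the dv latent channels are written out), and the KV cache is counted once. A worked example with made-up sizes, mirroring the lambda above (all numbers illustrative):

# bf16 => 2 bytes per element; batch 4, 128 heads, d = 576, dv = 512.
q_size = 4 * 128 * 576 * 2            # bytes of q_nope + q_pe combined
kv_bytes = 1024 * 128 * 576 * 2       # e.g. 1024 pages of 128 tokens each
ms = 0.5                              # hypothetical do_bench median, in ms
gbps = (q_size + q_size * 512 / 576 + kv_bytes) * 1e-9 / (ms * 1e-3)
print(f"{gbps:.1f} GB/s")             # ~304 GB/s for these numbers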
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

@@ -22,6 +22,7 @@ limitations under the License.
 #include <torch/all.h>

 #include <cute/tensor.hpp>
+#include <iostream>

 #include "cutlass_sm100_mla/device/sm100_mla.hpp"
 #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
@@ -30,7 +31,8 @@ limitations under the License.
 #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
@@ -91,16 +93,17 @@ struct MlaSm100 {
 template <typename T>
 typename T::Fmha::Arguments args_from_options(
     at::Tensor const& out,
-    at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& q_nope,
+    at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = q_nope_and_q_pe.device().index();
+  hw_info.device_id = q_nope.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

-  int batches = q_nope_and_q_pe.sizes()[0];
+  int batches = q_nope.sizes()[0];
   int page_count_per_seq = page_table.sizes()[1];
   int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
   int page_size = kv_c_and_k_pe_cache.sizes()[1];
@@ -122,8 +125,11 @@ typename T::Fmha::Arguments args_from_options(
   using StrideO = typename T::StrideO;
   using StrideLSE = typename T::StrideLSE;
-  StrideQ stride_Q = cute::make_tuple(
-      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
+  StrideQ stride_Q_nope = cute::make_tuple(
+      static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
+  StrideQ stride_Q_pe = cute::make_tuple(
+      static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
   StrideK stride_C = cute::make_tuple(
       static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
   StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
@@ -133,15 +139,16 @@ typename T::Fmha::Arguments args_from_options(
   using Element = typename T::Element;
   using ElementOut = typename T::ElementOut;
   using ElementAcc = typename T::ElementAcc;
-  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
+  auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
+  auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
   auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
   typename T::Fmha::Arguments arguments{
       problem_shape,
       {scale,
-       Q_ptr,
-       stride_Q,
-       Q_ptr + D_latent,
-       stride_Q,
+       Q_nope_ptr,
+       stride_Q_nope,
+       Q_pe_ptr,
+       stride_Q_pe,
        C_ptr,
        stride_C,
        C_ptr + D_latent,
@@ -170,7 +177,8 @@ typename T::Fmha::Arguments args_from_options(
 template <typename Element, bool IsPaged128, typename PersistenceOption>
 void runMla(
     at::Tensor const& out,
-    at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& q_nope,
+    at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
@@ -179,7 +187,7 @@ void runMla(
     cudaStream_t stream) {
   using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
+  auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
   CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -201,15 +209,16 @@ void runMla(
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
     int64_t num_kv_splits) {
-  auto in_dtype = q_nope_and_q_pe.dtype();
-  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
+  auto in_dtype = q_nope.dtype();
+  at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
   const int page_size = kv_c_and_k_pe_cache.sizes()[1];

   // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
@@ -219,13 +228,13 @@ void cutlass_mla_decode(
   DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
     if (in_dtype == at::ScalarType::Half) {
       runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::BFloat16) {
       runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
       runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else {
       TORCH_CHECK(false, "Unsupported input data type of MLA");
     }
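The substantive change in the CUDA file is the stride plumbing: stride_Q used to be derived from the fused per-head layout (D_latent + D_rope elements per head), while stride_Q_nope and stride_Q_pe are now read from each tensor's own PyTorch strides. That is what lets a non-contiguous q_nope, such as the transposed view the updated test passes, be consumed in place. A small Python sketch of the stride values involved (illustrative, not from the commit):

import torch

# Mirror the test: an (h_q, bs, dv) buffer viewed as (bs, h_q, dv).
q_nope = torch.empty((128, 4, 512), dtype=torch.bfloat16).transpose(0, 1)
print(q_nope.shape)                            # torch.Size([4, 128, 512])
print(q_nope.is_contiguous())                  # False
# The triple fed to cute::make_tuple for stride_Q_nope is
# (q_nope.stride(1), 1, q_nope.stride(0)):
print(q_nope.stride(1), 1, q_nope.stride(0))   # 2048 1 512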
sgl-kernel/csrc/common_extension.cc

@@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
-      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor! workspace, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
sgl-kernel/include/sgl_kernel_ops.h

@@ -105,7 +105,8 @@ void merge_state_v2(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
sgl-kernel/python/sgl_kernel/attention.py

@@ -52,34 +52,42 @@ def merge_state_v2(
 def cutlass_mla_decode(
-    q_nope_and_q_pe: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
     num_kv_splits: int = -1,
 ) -> torch.Tensor:
-    assert (
-        q_nope_and_q_pe.ndim == 3
-    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
+    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
+    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
     assert (
         kv_c_and_k_pe_cache.ndim == 3
     ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
-    B_q, H, D_q = q_nope_and_q_pe.shape
+    B_q, H, D_q_nope = q_nope.shape
+    B_q_2, H_2, D_q_pe = q_pe.shape
+    assert (B_q == B_q_2) and (H == H_2)
     _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

     D_latent = 512
     D_rope = 64
-    assert D_q == D_ckv and D_q == D_latent + D_rope, (
-        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
-        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
-    )
+    assert D_q_nope == D_latent
+    assert D_q_pe == D_rope
+    assert D_ckv == D_latent + D_rope

     MAX_HEADS = 128
     assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
     if H < MAX_HEADS:
-        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
-        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
-        q_nope_and_q_pe = q_nope_and_q_pe_padded
+        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
+        q_nope_padded[:, :H] = q_nope
+        q_nope = q_nope_padded
+
+        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
+        q_pe_padded[:, :H] = q_pe
+        q_pe = q_pe_padded

     assert len(page_table.shape) == 2
     B_block_table, block_num = page_table.shape
@@ -88,14 +96,11 @@ def cutlass_mla_decode(
     assert block_num % (128 / PAGE_SIZE) == 0

     # TODO(kaixih@nvidia): support fp8
-    assert q_nope_and_q_pe.dtype in (
+    assert q_nope.dtype in (
         torch.float16,
         torch.bfloat16,
-    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
-    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
-        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
-        f"but got {kv_c_and_k_pe_cache.dtype}."
-    )
+    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
+    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
     assert (
         seq_lens.dtype == torch.int32
     ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
@@ -103,11 +108,12 @@ def cutlass_mla_decode(
         page_table.dtype == torch.int32
     ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

-    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
+    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
     torch.ops.sgl_kernel.cutlass_mla_decode.default(
         out,
-        q_nope_and_q_pe,
+        q_nope,
+        q_pe,
         kv_c_and_k_pe_cache,
         seq_lens,
         page_table,
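With the wrapper updated, an end-to-end call under the new signature looks roughly like this. This is a sketch only: every size is illustrative, and the workspace size is hard-coded here instead of being queried via cutlass_mla_get_workspace_size as the test does.

import torch
from sgl_kernel import cutlass_mla_decode

bs, h_q, page_size, num_pages, blocks_per_seq = 4, 128, 128, 256, 8
D_latent, D_rope = 512, 64

q_nope = torch.randn(bs, h_q, D_latent, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(bs, h_q, D_rope, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(
    num_pages, page_size, D_latent + D_rope, dtype=torch.bfloat16, device="cuda"
)
seq_lens = torch.full((bs,), 1000, dtype=torch.int32, device="cuda")
page_table = torch.randint(
    0, num_pages, (bs, blocks_per_seq), dtype=torch.int32, device="cuda"
)
# Placeholder size; obtain the real value from cutlass_mla_get_workspace_size.
workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)

# out carries D_latent channels per head; num_kv_splits keeps its default (-1).
out = cutlass_mla_decode(q_nope, q_pe, kv_cache, seq_lens, page_table, workspace)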
sgl-kernel/tests/test_cutlass_mla.py

@@ -86,10 +86,14 @@ def test_cutlass_mla_decode(
     )
     workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

+    q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
+    q_nope.copy_(q[:, :, :dv])
+    q_pe = q[:, :, dv:].clone()
+
     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)

     out = cutlass_mla_decode(
-        q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
     )

     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
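The test's new inputs are worth reading closely: q_nope is deliberately built as a non-contiguous transposed view so the stride-based argument handling above actually gets exercised, while q_pe is a contiguous clone; together they carry exactly the data of the old fused q. The construction in isolation, runnable on CPU with sizes shrunk for illustration:

import torch

bs, h_q, d, dv = 2, 16, 576, 512
q = torch.randn(bs, h_q, d)

q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)  # (bs, h_q, dv) view
q_nope.copy_(q[:, :, :dv])
q_pe = q[:, :, dv:].clone()

assert not q_nope.is_contiguous() and q_pe.is_contiguous()
assert torch.equal(torch.cat([q_nope, q_pe], dim=-1), q)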