zhaoyu6 / sglang · commit aa46ed34 (unverified)
Authored Jun 13, 2025 by fzyzcjy · committed by GitHub, Jun 13, 2025
Remove 200us slow concat kernel (part 1: kernel) (#7145)
Parent: 2f4ec752
Changes: 6 changed files with 79 additions and 48 deletions

  sgl-kernel/benchmark/bench_cutlass_mla.py         +16  -5
  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu   +29  -20
  sgl-kernel/csrc/common_extension.cc               +1   -1
  sgl-kernel/include/sgl_kernel_ops.h               +2   -1
  sgl-kernel/python/sgl_kernel/attention.py         +26  -20
  sgl-kernel/tests/test_cutlass_mla.py              +5   -1
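Context for the diffs below: cutlass_mla_decode previously took a single packed query tensor, q_nope_and_q_pe, which forced callers to materialize it with the concat kernel the title attributes about 200us to. After this "part 1", the CUDA kernel, its TORCH_LIBRARY schema, and the Python wrapper all accept q_nope and q_pe as separate tensors; the in-tree benchmark and test callers are updated here as well. A minimal sketch of the packed layout being retired, with illustrative sizes (D_latent = 512 and D_rope = 64 come from the diff; the batch and head counts are assumptions):

import torch

bs, h_q = 64, 128  # assumed batch size and query head count
q_nope = torch.randn(bs, h_q, 512, dtype=torch.bfloat16, device="cuda")  # latent part
q_pe = torch.randn(bs, h_q, 64, dtype=torch.bfloat16, device="cuda")     # RoPE part

# The old signature required one packed (bs, h_q, 576) tensor, so every decode
# step paid for this concat; the new signature takes the two halves as-is.
q_nope_and_q_pe = torch.cat([q_nope, q_pe], dim=-1)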
sgl-kernel/benchmark/bench_cutlass_mla.py

@@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range))
 )
 def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     d = 576
+    dn = 64
     dv = 512
     h_q_map = {
...
@@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
-    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
+    qn = (
+        torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda") * 100.0
+    )
+    qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
     block_table = torch.randint(
         0,
         batch_size * block_num,
...
@@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     quantiles = [0.5, 0.2, 0.8]
     ms, min_ms, max_ms = triton.testing.do_bench(
         lambda: cutlass_mla_decode(
-            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+            qn.transpose(0, 1),
+            qr,
+            kv_cache,
+            seq_lens,
+            block_table,
+            workspace,
+            num_kv_splits,
         ),
         quantiles=quantiles,
     )
+    q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
     gbps = (
         lambda ms: (
-            q.numel() * q.element_size()
-            + q.numel() * q.element_size() * dv / d
-            + kv_cache.numel() * kv_cache.element_size()
+            q_size
+            + q_size * dv / d
+            + kv_cache.numel() * kv_cache.element_size()
         )
         * 1e-9
         / (ms * 1e-3)
...
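Since q is now split into qn and qr, the benchmark sums their bytes into q_size and then applies the same bandwidth formula as before. A standalone restatement of that formula for clarity (the helper and the reading of the dv/d term as approximate output bytes are mine, not part of the diff):

def mla_decode_gbps(q_bytes: float, kv_cache_bytes: float, ms: float, dv: int = 512, d: int = 576) -> float:
    # Estimated GB/s: query bytes read, roughly dv/d of that written as output,
    # plus one pass over the KV cache, divided by the measured time in seconds.
    total_bytes = q_bytes + q_bytes * dv / d + kv_cache_bytes
    return total_bytes * 1e-9 / (ms * 1e-3)

# Example with made-up numbers: 10 MB of query data, 1 GB of KV cache, 0.5 ms.
print(mla_decode_gbps(10e6, 1e9, 0.5))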
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

@@ -22,6 +22,7 @@ limitations under the License.
 #include <torch/all.h>
 #include <cute/tensor.hpp>
+#include <iostream>
 #include "cutlass_sm100_mla/device/sm100_mla.hpp"
 #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
...
@@ -30,7 +31,8 @@ limitations under the License.
 #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
...
@@ -91,16 +93,17 @@ struct MlaSm100 {
 template <typename T>
 typename T::Fmha::Arguments args_from_options(
     at::Tensor const& out,
-    at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& q_nope,
+    at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = q_nope_and_q_pe.device().index();
+  hw_info.device_id = q_nope.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-  int batches = q_nope_and_q_pe.sizes()[0];
+  int batches = q_nope.sizes()[0];
   int page_count_per_seq = page_table.sizes()[1];
   int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
   int page_size = kv_c_and_k_pe_cache.sizes()[1];
...
@@ -122,8 +125,11 @@ typename T::Fmha::Arguments args_from_options(
   using StrideO = typename T::StrideO;
   using StrideLSE = typename T::StrideLSE;
-  StrideQ stride_Q = cute::make_tuple(
-      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
+  StrideQ stride_Q_nope = cute::make_tuple(
+      static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
+  StrideQ stride_Q_pe = cute::make_tuple(
+      static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
   StrideK stride_C = cute::make_tuple(
       static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
   StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
...
@@ -133,15 +139,16 @@ typename T::Fmha::Arguments args_from_options(
   using Element = typename T::Element;
   using ElementOut = typename T::ElementOut;
   using ElementAcc = typename T::ElementAcc;
-  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
+  auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
+  auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
   auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
   typename T::Fmha::Arguments arguments{
       problem_shape,
       {scale,
-       Q_ptr,
-       stride_Q,
-       Q_ptr + D_latent,
-       stride_Q,
+       Q_nope_ptr,
+       stride_Q_nope,
+       Q_pe_ptr,
+       stride_Q_pe,
        C_ptr,
        stride_C,
        C_ptr + D_latent,
...
@@ -170,7 +177,8 @@ typename T::Fmha::Arguments args_from_options(
 template <typename Element, bool IsPaged128, typename PersistenceOption>
 void runMla(
     at::Tensor const& out,
-    at::Tensor const& q_nope_and_q_pe,
+    at::Tensor const& q_nope,
+    at::Tensor const& q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
...
@@ -179,7 +187,7 @@ void runMla(
     cudaStream_t stream) {
   using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
+  auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
   CUTLASS_CHECK(fmha.can_implement(arguments));
...
@@ -201,15 +209,16 @@ void runMla(
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
     torch::Tensor const& workspace,
     int64_t num_kv_splits) {
-  auto in_dtype = q_nope_and_q_pe.dtype();
-  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
+  auto in_dtype = q_nope.dtype();
+  at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
   const int page_size = kv_c_and_k_pe_cache.sizes()[1];
   // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
...
@@ -219,13 +228,13 @@ void cutlass_mla_decode(
   DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
     if (in_dtype == at::ScalarType::Half) {
       runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::BFloat16) {
       runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
       runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
-          out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+          out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
     } else {
       TORCH_CHECK(false, "Unsupported input data type of MLA");
     }
...
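The part of the kernel change that makes the split possible is in args_from_options: stride_Q used to be derived from the packed row width (D_latent + D_rope), and is now read from q_nope.stride() and q_pe.stride(), so each half may be an arbitrary strided 3-D view rather than a contiguous slice of one packed buffer. A small PyTorch illustration of the kind of layout this admits (shapes here are examples, not sglang code):

import torch

bs, h_q, d_latent = 4, 128, 512
# A head-major buffer viewed batch-major, as the updated benchmark does with
# qn.transpose(0, 1): same elements, no copy, but non-contiguous strides.
qn = torch.randn(h_q, bs, d_latent)
q_nope = qn.transpose(0, 1)

print(q_nope.shape)            # torch.Size([4, 128, 512])
print(q_nope.stride())         # (512, 2048, 1)
print(q_nope.is_contiguous())  # False
# The kernel now builds its Q strides from stride(1) and stride(0) directly,
# so no contiguous repack (and no concat with q_pe) is required beforehand.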
sgl-kernel/csrc/common_extension.cc

@@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
-      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor! workspace, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
...
sgl-kernel/include/sgl_kernel_ops.h

@@ -105,7 +105,8 @@ void merge_state_v2(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
-    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& q_nope,
+    torch::Tensor const& q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
...
sgl-kernel/python/sgl_kernel/attention.py

@@ -52,34 +52,42 @@ def merge_state_v2(
 def cutlass_mla_decode(
-    q_nope_and_q_pe: torch.Tensor,
+    q_nope: torch.Tensor,
+    q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
     num_kv_splits: int = -1,
 ) -> torch.Tensor:
-    assert (
-        q_nope_and_q_pe.ndim == 3
-    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
+    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
+    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
     assert (
         kv_c_and_k_pe_cache.ndim == 3
     ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
-    B_q, H, D_q = q_nope_and_q_pe.shape
+    B_q, H, D_q_nope = q_nope.shape
+    B_q_2, H_2, D_q_pe = q_pe.shape
+    assert (B_q == B_q_2) and (H == H_2)
     _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

     D_latent = 512
     D_rope = 64
-    assert D_q == D_ckv and D_q == D_latent + D_rope, (
-        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
-        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
-    )
+    assert D_q_nope == D_latent
+    assert D_q_pe == D_rope
+    assert D_ckv == D_latent + D_rope

     MAX_HEADS = 128
     assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
     if H < MAX_HEADS:
-        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
-        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
-        q_nope_and_q_pe = q_nope_and_q_pe_padded
+        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
+        q_nope_padded[:, :H] = q_nope
+        q_nope = q_nope_padded
+
+        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
+        q_pe_padded[:, :H] = q_pe
+        q_pe = q_pe_padded

     assert len(page_table.shape) == 2
     B_block_table, block_num = page_table.shape
...
@@ -88,14 +96,11 @@ def cutlass_mla_decode(
     assert block_num % (128 / PAGE_SIZE) == 0

     # TODO(kaixih@nvidia): support fp8
-    assert q_nope_and_q_pe.dtype in (
+    assert q_nope.dtype in (
         torch.float16,
         torch.bfloat16,
-    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
-    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
-        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
-        f"but got {kv_c_and_k_pe_cache.dtype}."
-    )
+    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
+    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype

     assert (
         seq_lens.dtype == torch.int32
     ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
...
@@ -103,11 +108,12 @@ def cutlass_mla_decode(
         page_table.dtype == torch.int32
     ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

-    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
+    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
     torch.ops.sgl_kernel.cutlass_mla_decode.default(
         out,
-        q_nope_and_q_pe,
+        q_nope,
+        q_pe,
         kv_c_and_k_pe_cache,
         seq_lens,
         page_table,
...
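One wrapper detail that now applies to both halves: when the model has fewer than MAX_HEADS = 128 query heads, q_nope and q_pe are each padded to 128 heads before the op is dispatched. A standalone restatement of that padding step (it mirrors the wrapper's logic above; the helper name is mine):

import torch

MAX_HEADS = 128

def pad_heads(t: torch.Tensor) -> torch.Tensor:
    # Pad a (B, H, D) query tensor with uninitialized rows up to MAX_HEADS heads.
    B, H, D = t.shape
    if H >= MAX_HEADS:
        return t
    padded = t.new_empty((B, MAX_HEADS, D))
    padded[:, :H] = t
    return padded

q_nope = pad_heads(torch.randn(2, 16, 512))
q_pe = pad_heads(torch.randn(2, 16, 64))
print(q_nope.shape, q_pe.shape)  # torch.Size([2, 128, 512]) torch.Size([2, 128, 64])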
sgl-kernel/tests/test_cutlass_mla.py

@@ -86,10 +86,14 @@ def test_cutlass_mla_decode(
     )
     workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

+    q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
+    q_nope.copy_(q[:, :, :dv])
+    q_pe = q[:, :, dv:].clone()
+
     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
     out = cutlass_mla_decode(
-        q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+        q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
     )

     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
...