change / sglang · Commit f65b8d5c (unverified)

Blackwell Cutlass MLA kernel (#5142)

Authored Apr 11, 2025 by Trevor Morris; committed via GitHub on Apr 11, 2025.
Parent: 5ad05719

Showing 7 changed files, with 371 additions and 3 deletions:
sgl-kernel/CMakeLists.txt                        (+4, -1)
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu  (+207, -0)
sgl-kernel/csrc/common_extension.cc              (+5, -0)
sgl-kernel/include/sgl_kernel_ops.h              (+8, -1)
sgl-kernel/python/sgl_kernel/__init__.py         (+5, -1)
sgl-kernel/python/sgl_kernel/attention.py        (+61, -0)
sgl-kernel/tests/test_cutlass_mla.py             (+81, -0)
sgl-kernel/CMakeLists.txt
@@ -33,7 +33,7 @@ include(FetchContent)
 FetchContent_Declare(
   repo-cutlass
   GIT_REPOSITORY https://github.com/NVIDIA/cutlass
-  GIT_TAG        6f4921858b3bb0a82d7cbeb4e499690e9ae60d16
+  GIT_TAG        df8a550d3917b0e97f416b2ed8c2d786f7f686a3
   GIT_SHALLOW    OFF
 )
 FetchContent_Populate(repo-cutlass)
@@ -76,6 +76,8 @@ include_directories(
   ${PROJECT_SOURCE_DIR}/csrc
   ${repo-cutlass_SOURCE_DIR}/include
   ${repo-cutlass_SOURCE_DIR}/tools/util/include
+  ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
+  ${repo-cutlass_SOURCE_DIR}/examples/common
   ${repo-flashinfer_SOURCE_DIR}/include
   ${repo-flashinfer_SOURCE_DIR}/csrc
   ${repo-flash-attention_SOURCE_DIR}/hopper
@@ -158,6 +160,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
+  "csrc/attention/cutlass_mla_kernel.cu"
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/elementwise/activation.cu"
   "csrc/elementwise/fused_add_rms_norm_kernel.cu"
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu  (new file, mode 100644)
/*
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <device/sm100_mla.hpp>
#include <kernel/sm100_mla_tile_scheduler.hpp>

#define CUTLASS_CHECK(status)                                                       \
  {                                                                                 \
    cutlass::Status error = status;                                                 \
    TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
  }

using namespace cute;
using namespace cutlass::fmha::kernel;

template <bool v>
struct IsPersistent {
  static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler =
      std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;

  using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
      TileShape,
      Element,
      ElementAcc,
      ElementOut,
      ElementAcc,
      TileScheduler,
      /*kIsCpAsync=*/true>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

template <typename T>
typename T::Fmha::Arguments args_from_options(
    at::Tensor const& out,
    at::Tensor const& q_nope_and_q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = q_nope_and_q_pe.device().index();
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  int batches = q_nope_and_q_pe.sizes()[0];
  int page_count_per_seq = page_table.sizes()[1];
  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
  int page_size = kv_c_and_k_pe_cache.sizes()[1];
  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  // the scale is based on the non-absorbed sizes, change as appropriate
  // we can't determine this parameter from the info we have, it's an input
  int D_non_latent = 128;
  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  StrideQ stride_Q = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
  StrideK stride_C = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
  typename T::Fmha::Arguments arguments{
      problem_shape,
      {scale,
       Q_ptr,
       stride_Q,
       Q_ptr + D_latent,
       stride_Q,
       C_ptr,
       stride_C,
       C_ptr + D_latent,
       stride_C,
       static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(page_table.data_ptr()),
       stride_PT,
       page_count_total,
       page_size},
      {static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
      -1,       // split_kv
      nullptr,  // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

template <typename Element>
void runMla(
    at::Tensor const& out,
    at::Tensor const& q_nope_and_q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table,
    at::Tensor const& workspace,
    cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element>;
  typename MlaSm100Type::Fmha fmha;
  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);

  CUTLASS_CHECK(fmha.can_implement(arguments));

  CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));

  CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}

void cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope_and_q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace) {
  auto in_dtype = q_nope_and_q_pe.dtype();
  at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
  if (in_dtype == at::ScalarType::Half) {
    runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
  } else if (in_dtype == at::ScalarType::BFloat16) {
    runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
    runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
  } else {
    TORCH_CHECK(false, "Unsupported input data type of MLA");
  }
}

int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
  // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
  // which are float, so Element type here doesn't matter.
  using MlaSm100Type = MlaSm100<cutlass::half_t>;

  // Get split kv. Requires problem shape and sm_count only.
  typename MlaSm100Type::Fmha::Arguments arguments;
  using TileShapeH = typename MlaSm100Type::TileShapeH;
  using TileShapeD = typename MlaSm100Type::TileShapeD;
  arguments.problem_shape =
      cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
  // Assumes device 0 when getting sm_count.
  arguments.hw_info.sm_count =
      sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
  MlaSm100Type::Fmha::set_split_kv(arguments);

  return MlaSm100Type::Fmha::get_workspace_size(arguments);
}
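As a worked check on the hard-coded scale in args_from_options above: the kernel cannot infer the non-absorbed head dimension, so D_non_latent = 128 is baked in, giving scale = 1/sqrt(D_non_latent + D_rope). A one-line sketch of the same arithmetic:

import math

# Same formula as the kernel: scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope))
scale = 1.0 / math.sqrt(128 + 64)  # 1/sqrt(192) ~= 0.0722
# This agrees with the reference test below, which uses
# scale = (q_nope_dim + q_pe_dim) ** (-0.5) with q_nope_dim=128, q_pe_dim=64.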
sgl-kernel/csrc/common_extension.cc
@@ -45,6 +45,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
       "new_kv) -> ()");
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+  m.def(
+      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
+      "page_table, Tensor workspace) -> ()");
+  m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
+  m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);

   /*
    * From csrc/elementwise
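A quick sanity-check sketch of the registration above (assumes the sgl-kernel extension is built and a CUDA device is present; the 1024/1/0 arguments are arbitrary illustrative values):

import torch
import sgl_kernel  # noqa: F401 -- importing the package loads common_ops and registers the ops

# Both new ops resolve under the sgl_kernel namespace once registered.
print(torch.ops.sgl_kernel.cutlass_mla_decode)
print(torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(1024, 1, 0))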
sgl-kernel/include/sgl_kernel_ops.h
@@ -87,7 +87,14 @@ void lightning_attention_decode(
     const torch::Tensor& slope,
     torch::Tensor output,
     torch::Tensor new_kv);
+void cutlass_mla_decode(
+    torch::Tensor const& out,
+    torch::Tensor const& q_nope_and_q_pe,
+    torch::Tensor const& kv_c_and_k_pe_cache,
+    torch::Tensor const& seq_lens,
+    torch::Tensor const& page_table,
+    torch::Tensor const& workspace);
+int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);

 /*
  * From csrc/elementwise
  */
sgl-kernel/python/sgl_kernel/__init__.py
@@ -11,7 +11,11 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
 from sgl_kernel import common_ops
 from sgl_kernel.allreduce import *
-from sgl_kernel.attention import lightning_attention_decode
+from sgl_kernel.attention import (
+    cutlass_mla_decode,
+    cutlass_mla_get_workspace_size,
+    lightning_attention_decode,
+)
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,
     fused_add_rmsnorm,
sgl-kernel/python/sgl_kernel/attention.py
@@ -5,3 +5,64 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     torch.ops.sgl_kernel.lightning_attention_decode.default(
         q, k, v, past_kv, slope, output, new_kv
     )
+
+
+def cutlass_mla_decode(
+    q_nope_and_q_pe: torch.Tensor,
+    kv_c_and_k_pe_cache: torch.Tensor,
+    seq_lens: torch.Tensor,
+    page_table: torch.Tensor,
+    workspace: torch.Tensor,
+) -> torch.Tensor:
+    assert (
+        q_nope_and_q_pe.ndim == 3
+    ), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
+    assert (
+        kv_c_and_k_pe_cache.ndim == 3
+    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
+    B_q, H, D_q = q_nope_and_q_pe.shape
+    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
+
+    D_latent = 512
+    D_rope = 64
+    assert D_q == D_ckv and D_q == D_latent + D_rope, (
+        f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
+        f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
+    )
+    assert H == 128, f"H must be 128, but got {H}"
+
+    # TODO: There is currently an illegal memory access issue with page size !=
+    # 128. Change this when it is fixed.
+    assert PAGE_SIZE == 128, f"PAGE_SIZE must be 128, but got {PAGE_SIZE}"
+
+    # TODO(kaixih@nvidia): support fp8
+    assert q_nope_and_q_pe.dtype in (
+        torch.float16,
+        torch.bfloat16,
+    ), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
+    assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
+        f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
+        f"but got {kv_c_and_k_pe_cache.dtype}."
+    )
+    assert (
+        seq_lens.dtype == torch.int32
+    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
+    assert (
+        page_table.dtype == torch.int32
+    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
+
+    out = torch.empty(
+        (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
+    )
+    torch.ops.sgl_kernel.cutlass_mla_decode(
+        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
+    )
+    return out
+
+
+def cutlass_mla_get_workspace_size(
+    max_seq_len: int, num_batches: int, sm_count: int = 0
+) -> int:
+    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(
+        max_seq_len, num_batches, sm_count
+    )
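A minimal end-to-end sketch of the new Python API (hypothetical sizes; assumes an SM100/Blackwell GPU and follows the shape and dtype assertions above: H = 128 heads, page size 128, head dim 576 = 512 latent + 64 rope):

import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs, h, page_size, pages_per_seq = 2, 128, 128, 8  # hypothetical sizes
d_latent, d_rope = 512, 64                        # fixed by the kernel
q = torch.randn(bs, h, d_latent + d_rope, device="cuda", dtype=torch.bfloat16)
kv_cache = torch.randn(
    bs * pages_per_seq, page_size, d_latent + d_rope, device="cuda", dtype=torch.bfloat16
)
seq_lens = torch.full((bs,), 1000, dtype=torch.int32, device="cuda")
page_table = torch.arange(
    bs * pages_per_seq, dtype=torch.int32, device="cuda"
).reshape(bs, pages_per_seq)

# Size the workspace for the worst-case sequence length, then decode.
workspace_size = cutlass_mla_get_workspace_size(pages_per_seq * page_size, bs)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
out = cutlass_mla_decode(q, kv_cache, seq_lens, page_table, workspace)  # (bs, h, d_latent)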
sgl-kernel/tests/test_cutlass_mla.py  (new file, mode 100644)
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from torch import Tensor

if torch.cuda.get_device_capability() < (10, 0):
    pytest.skip(
        reason="Cutlass MLA Requires compute capability of 10 or above.",
        allow_module_level=True,
    )


def ref_mla(
    out: Tensor,  # (bs, num_heads, v_head_dim)
    query: Tensor,  # (bs, num_heads, head_dim)
    kv_cache: Tensor,  # (num_blocks, block_size, head_dim)
    scale: float,
    block_tables: Tensor,  # (bs, max_num_blocks)
    seq_lens: Tensor,  # (bs,)
):
    bs, num_heads, v_head_dim = out.shape
    head_dim = query.shape[2]

    for i in range(bs):
        # gather and flatten KV-cache
        kv = kv_cache[block_tables[i]]  # (max_num_blocks, block_size, head_dim)
        kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]]  # (1, seq_len, head_dim)
        v = kv[:, :, :v_head_dim]

        q = query[i].view(num_heads, 1, head_dim)
        o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
        out[i] = o.view(num_heads, v_head_dim)

    return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [128])
def test_cutlass_mla_decode(
    dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int
):
    torch.set_default_dtype(dtype)
    torch.set_default_device("cuda")
    torch.manual_seed(42)

    d = 576
    h_q = 128
    dv = 512

    q_nope_dim = 128
    q_pe_dim = 64
    scale = (q_nope_dim + q_pe_dim) ** (-0.5)
    if varlen:
        seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
        seq_lens = seq_lens.clip(2).to(torch.int32)
    else:
        seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    q = torch.randn(bs, h_q, d)
    block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)

    kv_cache = torch.randn(block_table.numel(), block_size, d)

    workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs)
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    out_ref = q.new_zeros(bs, h_q, dv)
    ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
    out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace)

    torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    pytest.main([__file__])
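Since the file ends with pytest.main([__file__]), the suite can be run directly with python sgl-kernel/tests/test_cutlass_mla.py (or through pytest) on a machine with a compute-capability-10.0 GPU and a built sgl-kernel; on older hardware the module-level skip fires before any test runs.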