Commit 18efb5e8 (unverified), authored by JieXin Liang on Jun 09, 2025; committed via GitHub on Jun 08, 2025

[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)
Parent: de1350ea

Showing 10 changed files with 2959 additions and 37 deletions (+2959, -37):

- sgl-kernel/benchmark/bench_cutlass_mla.py (+133, -0)
- sgl-kernel/csrc/attention/cutlass_mla_kernel.cu (+52, -22)
- sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp (+358, -0)
- sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp (+198, -0)
- sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp (+2018, -0)
- sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp (+160, -0)
- sgl-kernel/csrc/common_extension.cc (+1, -1)
- sgl-kernel/include/sgl_kernel_ops.h (+4, -2)
- sgl-kernel/python/sgl_kernel/attention.py (+18, -8)
- sgl-kernel/tests/test_cutlass_mla.py (+17, -4)
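Below is a minimal usage sketch (not part of the commit) of what this change enables: decoding with fewer than 128 query heads and an explicit `num_kv_splits`. The head dimensions mirror the benchmark and test added in this commit (d = 576, dv = 512); the concrete sizes (batch 4, sequence length 1024, 64 heads, block size 128) and the assumption of a CUDA device supporting the SM100 kernel are illustrative only.

```python
# Illustrative sketch, not part of the commit: cutlass_mla_decode with num_heads < 128.
import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs, seq_len, num_heads = 4, 1024, 64      # num_heads below 128 is now accepted
d, dv, block_size = 576, 512, 128         # MLA head dims used by the benchmark/test in this PR
num_kv_splits = -1                        # -1 lets the kernel choose split_kv automatically

seq_lens = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
block_num = (seq_len + block_size - 1) // block_size
q = torch.randn(bs, num_heads, d, dtype=torch.bfloat16, device="cuda")
block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32, device="cuda")
kv_cache = torch.randn(block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda")

workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs, num_kv_splits=num_kv_splits)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace, num_kv_splits)
print(out.shape)  # (bs, num_heads, dv): the wrapper pads to 128 heads internally and slices back
```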
New file: sgl-kernel/benchmark/bench_cutlass_mla.py (0 → 100644)
```python
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

configs = list(itertools.product(bs_range, qlen_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        line_names=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    d = 576
    dv = 512

    h_q_map = {
        "128": 128,
        "64": 64,
        "32": 32,
        "16": 16,
    }
    parsed_h_q = next(
        (value for key, value in h_q_map.items() if key in provider), None
    )
    if parsed_h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")
    h_q = parsed_h_q

    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
    block_table = torch.randint(
        0,
        batch_size * block_num,
        (batch_size, block_num),
        dtype=torch.int32,
        device="cuda",
    )

    kv_cache = torch.randn(
        block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
    )

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: cutlass_mla_decode(
            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
        ),
        quantiles=quantiles,
    )
    gbps = (
        lambda ms: (
            q.numel() * q.element_size()
            + q.numel() * q.element_size() * dv / d
            + kv_cache.numel() * kv_cache.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        help="List of block sizes",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        help="List of num_kv_splits values",
    )
    args = parser.parse_args()

    for block_size in args.block_sizes:
        for kv_split in args.num_kv_splits:
            print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_blackwell_mla_res",
                block_size=block_size,
                num_kv_splits=kv_split,
            )

    print("Benchmark finished!")
```
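Going by the argparse definitions above, an invocation such as `python bench_cutlass_mla.py --block-sizes 1 32 64 128 --num-kv-splits -1` should sweep every (batch_size, seq_len) point for the four head-count providers at each block size and KV-split setting, print the GB/s tables, and save plots under `bench_blackwell_mla_res/`.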
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu (+52, -22)

```diff
@@ -22,8 +22,9 @@ limitations under the License.
 #include <torch/all.h>
 #include <cute/tensor.hpp>
 
-#include <device/sm100_mla.hpp>
-#include <kernel/sm100_mla_tile_scheduler.hpp>
+#include "cutlass_sm100_mla/device/sm100_mla.hpp"
+#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
 
 // clang-format off
 #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
@@ -55,7 +56,7 @@ struct IsPersistent {
   static const bool value = v;
 };
 
-template <typename T, typename PersistenceOption = IsPersistent<true>>
+template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
 struct MlaSm100 {
   using Element = T;
   using ElementAcc = float;
@@ -83,7 +84,7 @@ struct MlaSm100 {
       ElementOut,
       ElementAcc,
       TileScheduler,
-      /*kIsCpAsync=*/true>;
+      /*kIsCpAsync=*/!IsPaged128>;
   using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
 };
@@ -93,7 +94,8 @@ typename T::Fmha::Arguments args_from_options(
     at::Tensor const& q_nope_and_q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
-    at::Tensor const& page_table) {
+    at::Tensor const& page_table,
+    int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = q_nope_and_q_pe.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
@@ -154,8 +156,8 @@ typename T::Fmha::Arguments args_from_options(
       // TODO(trevor-m): Change split_kv back to -1 when
       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
       // perform worse with larger context length and smaller batch sizes.
-      1,              // split_kv
+      num_kv_splits,  // split_kv
       nullptr,        // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
   // split_kv automatically based on batch size and sequence length to balance
@@ -165,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
   return arguments;
 }
 
-template <typename Element>
+template <typename Element, bool IsPaged128, typename PersistenceOption>
 void runMla(
     at::Tensor const& out,
     at::Tensor const& q_nope_and_q_pe,
@@ -173,10 +175,11 @@ void runMla(
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     at::Tensor const& workspace,
+    int64_t num_kv_splits,
     cudaStream_t stream) {
-  using MlaSm100Type = MlaSm100<Element>;
+  using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
+  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
   CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -185,31 +188,57 @@ void runMla(
   CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
 }
 
+#define DISPATCH_BOOL(expr, const_expr, ...) \
+  [&]() -> bool {                            \
+    if (expr) {                              \
+      constexpr bool const_expr = true;      \
+      return __VA_ARGS__();                  \
+    } else {                                 \
+      constexpr bool const_expr = false;     \
+      return __VA_ARGS__();                  \
+    }                                        \
+  }()
+
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
-    torch::Tensor const& workspace) {
+    torch::Tensor const& workspace,
+    int64_t num_kv_splits) {
   auto in_dtype = q_nope_and_q_pe.dtype();
   at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
-  if (in_dtype == at::ScalarType::Half) {
-    runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
-  } else if (in_dtype == at::ScalarType::BFloat16) {
-    runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
-  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
-    runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported input data type of MLA");
-  }
+  const int page_size = kv_c_and_k_pe_cache.sizes()[1];
+
+  // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
+  // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
+  // Maybe per batch split kv will fix this.
+  DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
+    DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
+      if (in_dtype == at::ScalarType::Half) {
+        runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
+            out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+      } else if (in_dtype == at::ScalarType::BFloat16) {
+        runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
+            out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+      } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
+        runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
+            out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+      } else {
+        TORCH_CHECK(false, "Unsupported input data type of MLA");
+      }
+      return true;
+    });
+    return true;
+  });
 }
 
-int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
+int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
   // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
   // which are float, so Element type here doesn't matter.
-  using MlaSm100Type = MlaSm100<cutlass::half_t>;
+  using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
 
   // Get split kv. Requires problem shape and sm_count only.
   typename MlaSm100Type::Fmha::Arguments arguments;
@@ -220,6 +249,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
   // Assumes device 0 when getting sm_count.
   arguments.hw_info.sm_count =
       sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
+  arguments.split_kv = num_kv_splits;
   MlaSm100Type::Fmha::set_split_kv(arguments);
 
   return MlaSm100Type::Fmha::get_workspace_size(arguments);
```
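A note on the new dispatch above: the nested `DISPATCH_BOOL` calls turn two runtime facts (whether the KV page size is 128, and whether the caller forces a manual KV split) into compile-time template parameters of `MlaSm100`. The Python model below is illustrative only (not part of the commit) and merely mirrors how those two booleans select a kernel configuration.

```python
# Illustrative model of the boolean dispatch in cutlass_mla_decode (not part of the commit).
def select_mla_config(page_size: int, num_kv_splits: int) -> dict:
    is_paged_128 = page_size == 128            # DISPATCH_BOOL(page_size == 128, IsPaged128, ...)
    not_manual_split_kv = num_kv_splits <= 1   # DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, ...)
    return {
        # kIsCpAsync is the negation of IsPaged128 in the MlaSm100 template.
        "kIsCpAsync": not is_paged_128,
        # The persistent scheduler is only kept when split_kv is not forced manually,
        # per the NOTE about IsPersistent hanging with manual split_kv.
        "IsPersistent": not_manual_split_kv,
    }

print(select_mla_config(page_size=64, num_kv_splits=4))
# {'kIsCpAsync': True, 'IsPersistent': False}
```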
New file: sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp (0 → 100644)
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for cutlass 3.x-style kernels.
*/
```cpp
// clang-format off
#pragma once

// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::fmha::device {

using namespace cute;
using namespace cutlass::fmha::kernel;

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template <class Kernel_>
class MLA {
 public:
  using Kernel = Kernel_;

  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
      typename Kernel::ElementOut,
      typename Kernel::ElementAcc,
      typename Kernel::ElementAcc,
      Kernel::TileShapeH::value,
      Kernel::TileShapeL::value,
      256 /*Max split*/>;

  /// Argument structure: User API
  using KernelArguments = typename Kernel::Arguments;
  using ReductionArguments = typename ReductionKernel::Arguments;

  using Arguments = KernelArguments;

  /// Argument structure: Kernel API
  using KernelParams = typename Kernel::Params;
  using ReductionParams = typename ReductionKernel::Params;

  struct Params {
    KernelParams fmha_params;
    ReductionParams reduction_params;
  };

 private:
  /// Kernel API parameters object
  Params params_;

  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

  static ReductionArguments to_reduction_args(Arguments const& args) {
    auto [H, K, D, B] = args.problem_shape;
    return ReductionArguments{
        nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
        args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
        args.ptr_split_kv, Kernel::TileShapeS::value};
  }

 public:
  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  static void set_split_kv(KernelArguments& args) {
    if (args.split_kv >= 1) return;
    auto [H, K, D, B] = args.problem_shape;
    int sm_count = args.hw_info.sm_count;
    int max_splits = ceil_div(K, 128);
    int sms_per_batch = max(1, sm_count / B);
    int split_heur = min(max_splits, sms_per_batch);
    int waves = ceil_div(B * split_heur, sm_count);
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const& args) {
    if (!Kernel::can_implement(args)) {
      return Status::kInvalid;
    }
    if (!ReductionKernel::can_implement(to_reduction_args(args))) {
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
    return workspace_bytes;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(device_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError();  // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, device_kernel<Kernel>, Kernel::MaxThreadsPerBlock, smem_size);
    if (cudaSuccess != result) {
      result = cudaGetLastError();  // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params{kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed
    // no dynamic smem is needed for reduction kernel
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      cudaError_t result =
          cudaFuncSetAttribute(device_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError();  // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params{fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
  static Status run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(
          cute::size<0>(typename Kernel::ClusterShape{}),
          cute::size<1>(typename Kernel::ClusterShape{}),
          cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*)device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    } else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result or Status::kSuccess != launch_result) {
      //return Status::kSuccess;
      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // launch reduction kernel
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
      cudaError_t result = cudaGetLastError();
      if (cudaSuccess == result) {
        return Status::kSuccess;
      } else {
        CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
        return Status::kErrorInternal;
      }
    } else {
      return Status::kSuccess;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////
```
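For reference, `MLA::set_split_kv` above only runs when the caller leaves `split_kv < 1`; it derives a wave-aware split count from the KV length, the batch size, and the SM count. A small Python model of the arithmetic follows (illustrative only, not part of the commit; the SM count of 148 is an assumed example value).

```python
# Illustrative model of MLA::set_split_kv (not part of the commit).
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def set_split_kv(K: int, B: int, sm_count: int) -> int:
    max_splits = ceil_div(K, 128)                # at most one split per 128-token KV tile
    sms_per_batch = max(1, sm_count // B)        # SMs available per batch entry
    split_heur = min(max_splits, sms_per_batch)  # first heuristic guess
    k_waves = ceil_div(max_splits, split_heur)   # KV-tile waves implied by that guess
    return ceil_div(max_splits, k_waves)         # round so the waves are evenly filled

print(set_split_kv(K=8192, B=8, sm_count=148))   # -> 16 splits for an 8K context, batch 8
```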
New file: sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp (0 → 100644)
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
```cpp
// clang-format off
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"

namespace cutlass::fmha::kernel {

using namespace cute;

template <
    class ElementOut,
    class ElementAcc,
    class ElementScale,
    size_t kNumHeads,
    size_t kHeadDimLatent,
    int kMaxSplits>
struct Sm100FmhaMlaReductionKernel {
  static const int SharedStorageSize = 0;
  static const int MaxThreadsPerBlock = 128;
  static const int MinBlocksPerMultiprocessor = 1;

  using ArchTag = cutlass::arch::Sm100;

  static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);

  struct Arguments {
    ElementAcc* ptr_oaccum = nullptr;
    ElementOut* ptr_o = nullptr;
    ElementAcc* ptr_lseaccum = nullptr;
    ElementAcc* ptr_lse = nullptr;
    ElementScale scale = 1.f;
    int num_batches = 0;
    int split_kv = -1;
    int dim_k = -1;
    int* ptr_seq = nullptr;
    int* ptr_split_kv = nullptr;
    int tile_shape_s = 128;
  };

  using Params = Arguments;

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
            args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
            args.ptr_split_kv, args.tile_shape_s};
  }

  static size_t get_workspace_size(Arguments const& /*args*/) {
    return 0;
  }

  static Status initialize_workspace(Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
    return Status::kSuccess;
  }

  static dim3 get_grid_shape(Params const& params) {
    return dim3(kNumHeads, 1, params.num_batches);
  }

  static dim3 get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  static bool can_implement(Arguments const& args) {
    if (args.num_batches <= 0) return false;
    if (args.split_kv <= 0) return false;
    return true;
  }

  CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
    if (params.split_kv <= 1) return;
    auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);

    __shared__ ElementAcc sLseScale[kMaxSplits];
    const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
    const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);

    Tensor gLSEaccum = make_tensor(
        make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
        make_shape(params.split_kv), Stride<Int<kNumHeads>>{});

    Tensor gLSE = make_tensor(
        make_gmem_ptr(params.ptr_lse + offset_lse), Shape<_1>{}, Stride<_1>{});

    auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
    auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
    auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
    auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
    local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);

    int warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == 0) {
      constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);

      ElementAcc local_lse[kNLsePerThread];

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + threadIdx.x;
        local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
      }

      ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        lse_max = max(lse_max, local_lse[i]);
      }
      CUTLASS_PRAGMA_UNROLL
      for (int offset = 16; offset >= 1; offset /= 2) {
        lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
      }
      lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max;  // In case all local LSEs are -inf
      lse_max = __shfl_sync(0xffffffff, lse_max, 0);

      ElementAcc sum_lse = 0;
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        sum_lse = sum_lse + expf(local_lse[i] - lse_max);
      }
      CUTLASS_PRAGMA_UNROLL
      for (int offset = 16; offset >= 1; offset /= 2) {
        sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
      }
      sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);

      ElementAcc global_lse =
          (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
      if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
        gLSE(0) = global_lse;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + threadIdx.x;
        if (split < local_split_kv) {
          sLseScale[split] = expf(local_lse[i] - global_lse);
        }
      }
    }
    __syncthreads();

    constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
    const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
    Tensor gOaccum = make_tensor(
        make_gmem_ptr(params.ptr_oaccum + offset_oaccum), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
    ElementAcc local_val[Elements] = {0};
    for (int split = 0; split < local_split_kv; ++split) {
      ElementAcc lse_scale = sLseScale[split];
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Elements; ++i) {
        local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
      }
      gOaccum.data() = gOaccum.data() + kHeadDimLatent;
    }
    auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
    Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Elements; ++i) {
      gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
    }
  }
};

}  // namespace cutlass::fmha::kernel
```
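The reduction kernel above merges the per-split partial outputs with a numerically stable log-sum-exp combine: it finds the maximum partial LSE, accumulates `exp(lse_i - max)`, derives the global LSE, and then scales each split's partial output by `exp(lse_i - global_lse)` before summing. A NumPy model of that math follows (illustrative only, not part of the commit; the real kernel works per head and per batch entry using warp shuffles and shared memory).

```python
# Illustrative NumPy model of the split-KV merge in Sm100FmhaMlaReductionKernel (not part of the commit).
import numpy as np

def merge_splits(o_acc: np.ndarray, lse_acc: np.ndarray):
    # o_acc: (num_splits, head_dim_latent) partial outputs; lse_acc: (num_splits,) partial LSEs.
    lse_max = np.max(lse_acc)
    lse_max = 0.0 if np.isneginf(lse_max) else lse_max         # guard: all splits empty
    sum_exp = np.sum(np.exp(lse_acc - lse_max))
    global_lse = np.inf if sum_exp == 0.0 else np.log(sum_exp) + lse_max
    scales = np.exp(lse_acc - global_lse)                      # sLseScale[split] in the kernel
    return (scales[:, None] * o_acc).sum(axis=0), global_lse

rng = np.random.default_rng(0)
o = rng.standard_normal((4, 512)).astype(np.float32)           # 4 splits, dv = 512
lse = rng.standard_normal(4).astype(np.float32)
out, lse_all = merge_splits(o, lse)
print(out.shape, lse_all)
```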
New file: sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp (0 → 100644, +2018 lines). This diff is collapsed in the web view and is not reproduced here.
New file: sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp (0 → 100644)
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
```cpp
// clang-format off
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"

namespace cutlass::fmha::kernel {

////////////////////////////////////////////////////////////////////////////////

struct Sm100MlaIndividualTileScheduler {

  struct Params {
    dim3 grid;
  };

  bool valid_ = true;

  CUTLASS_DEVICE
  Sm100MlaIndividualTileScheduler(Params const&) {}

  template <class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape,
      KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape,
      int const& split_kv) {
    using namespace cute;
    dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
    return Params{grid};
  }

  static dim3 get_grid_shape(Params const& params) {
    return params.grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return valid_;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
  }

  CUTLASS_DEVICE
  Sm100MlaIndividualTileScheduler& operator++() {
    valid_ = false;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

struct Sm100MlaPersistentTileScheduler {

  struct Params {
    int num_blocks;
    FastDivmod divmod_m_block;
    FastDivmod divmod_b;
    FastDivmod divmod_split_kv;
    KernelHardwareInfo hw_info;
  };

  int block_idx = 0;
  Params params;

  CUTLASS_DEVICE
  Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}

  template <class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape,
      KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape,
      int const& split_kv) {
    using namespace cute;
    // Get SM count if needed, otherwise use user supplied SM count
    int sm_count = hw_info.sm_count;
    if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
      CUTLASS_TRACE_HOST(
          "  WARNING: Arguments do not include a valid SM count.\n"
          "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    }

    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
    hw_info.sm_count = sm_count;

    int num_m_blocks = size<0>(cluster_shape);
    int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
    num_blocks *= split_kv; /* Maximum Split KV*/

    return Params{num_blocks, {num_m_blocks}, {get<3>(problem_shape)}, {split_kv}, hw_info};
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
    return grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return block_idx < params.num_blocks;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    int block_decode = block_idx;
    int m_block, bidb, n_split_kv;
    params.divmod_m_block(block_decode, m_block, block_decode);
    params.divmod_b(block_decode, bidb, block_decode);
    params.divmod_split_kv(block_decode, n_split_kv, block_decode);
    return make_coord(m_block, _0{}, bidb, n_split_kv);
  }

  CUTLASS_DEVICE
  Sm100MlaPersistentTileScheduler& operator++() {
    block_idx += gridDim.x;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::fmha::kernel
```
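The persistent scheduler above flattens (m_block, batch, split) work units into a single block index and peels the coordinates back out with fast divmods, striding by the grid size until the work runs out. The Python model below is illustrative only (not part of the commit); it assumes the quotient/remainder ordering of the `FastDivmod` calls and is a sketch of the indexing scheme, not a statement about the CUTLASS API.

```python
# Illustrative model of Sm100MlaPersistentTileScheduler (not part of the commit).
def persistent_schedule(num_m_blocks: int, num_batches: int, split_kv: int, grid_size: int):
    num_blocks = num_m_blocks * num_batches * split_kv
    for start in range(grid_size):                    # one generator pass per resident CTA
        block_idx = start
        while block_idx < num_blocks:                 # is_valid()
            rest, m_block = divmod(block_idx, num_m_blocks)
            rest, bidb = divmod(rest, num_batches)
            rest, n_split_kv = divmod(rest, split_kv)
            yield start, (m_block, bidb, n_split_kv)  # get_block_coord()
            block_idx += grid_size                    # operator++ strides by gridDim.x

for cta, coord in persistent_schedule(num_m_blocks=1, num_batches=3, split_kv=2, grid_size=4):
    print(cta, coord)
```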
sgl-kernel/csrc/common_extension.cc (+1, -1)

```diff
@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
 
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
-      "page_table, Tensor workspace) -> ()");
+      "page_table, Tensor! workspace, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
```
sgl-kernel/include/sgl_kernel_ops.h (+4, -2)

```diff
@@ -109,8 +109,10 @@ void cutlass_mla_decode(
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
-    torch::Tensor const& workspace);
-int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);
+    torch::Tensor const& workspace,
+    int64_t num_kv_splits = -1);
+int64_t cutlass_mla_get_workspace_size(
+    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
 
 /*
  * From csrc/elementwise
  */
```
sgl-kernel/python/sgl_kernel/attention.py (+18, -8)

```diff
@@ -57,6 +57,7 @@ def cutlass_mla_decode(
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
+    num_kv_splits: int = -1,
 ) -> torch.Tensor:
     assert (
         q_nope_and_q_pe.ndim == 3
@@ -73,7 +74,12 @@ def cutlass_mla_decode(
         f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
         f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
     )
-    assert H == 128, f"H must be 128, but got {H}"
+    MAX_HEADS = 128
+    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
+    if H < MAX_HEADS:
+        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
+        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
+        q_nope_and_q_pe = q_nope_and_q_pe_padded
 
     assert len(page_table.shape) == 2
     B_block_table, block_num = page_table.shape
@@ -97,21 +103,25 @@ def cutlass_mla_decode(
         page_table.dtype == torch.int32
     ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
 
-    out = torch.empty(
-        (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
-    )
+    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
     torch.ops.sgl_kernel.cutlass_mla_decode.default(
-        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
+        out,
+        q_nope_and_q_pe,
+        kv_c_and_k_pe_cache,
+        seq_lens,
+        page_table,
+        workspace,
+        num_kv_splits,
     )
-    return out
+    return out[:, :H].contiguous()
 
 
 def cutlass_mla_get_workspace_size(
-    max_seq_len: int, num_batches: int, sm_count: int = 0
+    max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
 ) -> int:
     assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
     assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
 
     return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
-        max_seq_len, num_batches, sm_count
+        max_seq_len, num_batches, sm_count, num_kv_splits
     )
```
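The wrapper change above keeps the kernel's 128-head layout internal: when H < 128, the query tensor is padded up to `MAX_HEADS` before the op call, the kernel writes all 128 head rows, and only the first H rows are returned. A shape-only sketch of that flow (illustrative, not part of the commit; the example sizes are assumptions):

```python
# Shape-only sketch of the head padding done by the Python wrapper (not part of the commit).
import torch

MAX_HEADS = 128
B_q, H, D_q, D_latent = 2, 16, 576, 512        # assumed example sizes

q = torch.randn(B_q, H, D_q)
if H < MAX_HEADS:
    q_padded = q.new_empty((B_q, MAX_HEADS, D_q))
    q_padded[:, :H] = q                         # padded head rows stay uninitialized, as in the wrapper
    q = q_padded

out = q.new_empty((B_q, MAX_HEADS, D_latent))   # the kernel fills all 128 head rows
# ... torch.ops.sgl_kernel.cutlass_mla_decode.default(out, q, ...) would run here ...
out = out[:, :H].contiguous()                   # only the real heads are returned
print(q.shape, out.shape)                       # torch.Size([2, 128, 576]) torch.Size([2, 16, 512])
```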
sgl-kernel/tests/test_cutlass_mla.py (+17, -4)

```diff
@@ -40,15 +40,23 @@ def ref_mla(
 @pytest.mark.parametrize("bs", [1, 2, 4])
 @pytest.mark.parametrize("varlen", [False, True])
 @pytest.mark.parametrize("block_size", [1, 16, 64, 128])
+@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
+@pytest.mark.parametrize("num_kv_splits", [-1, 1])
 def test_cutlass_mla_decode(
-    dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int
+    dtype: torch.dtype,
+    mean_seq_len: int,
+    bs: int,
+    varlen: bool,
+    block_size: int,
+    num_heads: int,
+    num_kv_splits: int,
 ):
     torch.set_default_dtype(dtype)
     torch.set_default_device("cuda")
     torch.manual_seed(42)
 
     d = 576
-    h_q = 128
+    h_q = num_heads
     dv = 512
 
     q_nope_dim = 128
@@ -67,17 +75,22 @@ def test_cutlass_mla_decode(
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
 
+    # Larger q values to detect split kv error
     q = torch.randn(bs, h_q, d) * 100.0
     block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
 
     kv_cache = torch.randn(block_table.numel(), block_size, d)
 
-    workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs)
+    workspace_size = cutlass_mla_get_workspace_size(
+        block_num * block_size, bs, num_kv_splits=num_kv_splits
+    )
     workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
 
     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
-    out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace)
+    out = cutlass_mla_decode(
+        q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
+    )
 
     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
```