sglang · Commit 18efb5e8 (unverified)
Authored Jun 09, 2025 by JieXin Liang; committed by GitHub, Jun 08, 2025
[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)
Parent: de1350ea
Showing 10 changed files with 2959 additions and 37 deletions:

  sgl-kernel/benchmark/bench_cutlass_mla.py (+133, -0)
  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu (+52, -22)
  sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp (+358, -0)
  sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp (+198, -0)
  sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp (+2018, -0)
  sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp (+160, -0)
  sgl-kernel/csrc/common_extension.cc (+1, -1)
  sgl-kernel/include/sgl_kernel_ops.h (+4, -2)
  sgl-kernel/python/sgl_kernel/attention.py (+18, -8)
  sgl-kernel/tests/test_cutlass_mla.py (+17, -4)
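Before the per-file diffs, a minimal sketch of what the change enables from the Python side: calling the CUTLASS MLA decode kernel with fewer than 128 query heads. The call and workspace signatures mirror bench_cutlass_mla.py from this commit; the concrete sizes (4 sequences, 16 heads, page size 64) and the assumption that the wrapper returns the output tensor are illustrative, not taken from the diff.

import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

batch, heads, block_size, seq_len = 4, 16, 64, 4096  # heads < 128 is what this commit adds
d = 576  # per-head dim: 512 latent + 64 rope

seq_lens = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
blocks_per_seq = (seq_len + block_size - 1) // block_size
q = torch.randn(batch, heads, d, dtype=torch.bfloat16, device="cuda")
block_table = torch.randint(
    0, batch * blocks_per_seq, (batch, blocks_per_seq), dtype=torch.int32, device="cuda"
)
kv_cache = torch.randn(
    block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
)

num_kv_splits = -1  # let the kernel choose split_kv, as in the benchmark defaults
workspace = torch.empty(
    cutlass_mla_get_workspace_size(
        blocks_per_seq * block_size, batch, num_kv_splits=num_kv_splits
    ),
    dtype=torch.uint8,
    device="cuda",
)

# The benchmark below times this same call; whether the wrapper returns the output
# or writes into a preallocated tensor is not shown in this commit, so the
# assignment is an assumption.
out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace, num_kv_splits)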
sgl-kernel/benchmark/bench_cutlass_mla.py (new file, mode 100644)
import argparse
import copy
import itertools

import torch
import triton

from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

configs = list(itertools.product(bs_range, qlen_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        x_log=False,
        line_arg="provider",
        line_vals=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        line_names=[
            "128 heads",
            "64 heads",
            "32 heads",
            "16 heads",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="cutlass mla",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
    d = 576
    dv = 512

    h_q_map = {
        "128": 128,
        "64": 64,
        "32": 32,
        "16": 16,
    }
    parsed_h_q = next(
        (value for key, value in h_q_map.items() if key in provider), None
    )
    if parsed_h_q is None:
        raise ValueError(f"Unknown head configuration in provider: {provider}")
    h_q = parsed_h_q

    seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    max_seq_len = seq_lens.max().item()
    block_num = (max_seq_len + block_size - 1) // block_size

    # Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
    # One 128-wide tile can hold (128 // block_size) small blocks.
    pack_factor = 128 // block_size
    block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

    q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
    block_table = torch.randint(
        0,
        batch_size * block_num,
        (batch_size, block_num),
        dtype=torch.int32,
        device="cuda",
    )

    kv_cache = torch.randn(
        block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
    )

    workspace_size = cutlass_mla_get_workspace_size(
        block_num * block_size, batch_size, num_kv_splits=num_kv_splits
    )
    workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: cutlass_mla_decode(
            q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
        ),
        quantiles=quantiles,
    )
    gbps = (
        lambda ms: (
            q.numel() * q.element_size()
            + q.numel() * q.element_size() * dv / d
            + kv_cache.numel() * kv_cache.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--block-sizes",
        nargs="+",
        type=int,
        default=[1, 32, 64, 128],
        help="List of block sizes",
    )
    parser.add_argument(
        "--num-kv-splits",
        nargs="+",
        type=int,
        default=[-1],
        help="List of KV split counts",
    )
    args = parser.parse_args()

    for block_size in args.block_sizes:
        for kv_split in args.num_kv_splits:
            print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path="bench_blackwell_mla_res",
                block_size=block_size,
                num_kv_splits=kv_split,
            )

    print("Benchmark finished!")
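Once sgl-kernel is installed, the script runs standalone; the two flags feed the loops above, for example: python3 bench_cutlass_mla.py --block-sizes 1 64 128 --num-kv-splits -1 (the flag values here are only an example). One table is printed per (block_size, num_kv_splits) pair and plots are written under bench_blackwell_mla_res.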
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu (modified)
@@ -22,8 +22,9 @@ limitations under the License.
 #include <torch/all.h>
 
 #include <cute/tensor.hpp>
-#include <device/sm100_mla.hpp>
-#include <kernel/sm100_mla_tile_scheduler.hpp>
+
+#include "cutlass_sm100_mla/device/sm100_mla.hpp"
+#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
 
 // clang-format off
 #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
@@ -55,7 +56,7 @@ struct IsPersistent {
   static const bool value = v;
 };
 
-template <typename T, typename PersistenceOption = IsPersistent<true>>
+template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
 struct MlaSm100 {
   using Element = T;
   using ElementAcc = float;
@@ -83,7 +84,7 @@ struct MlaSm100 {
       ElementOut,
       ElementAcc,
       TileScheduler,
-      /*kIsCpAsync=*/true>;
+      /*kIsCpAsync=*/!IsPaged128>;
   using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
 };
 
@@ -93,7 +94,8 @@ typename T::Fmha::Arguments args_from_options(
     at::Tensor const& q_nope_and_q_pe,
     at::Tensor const& kv_c_and_k_pe_cache,
     at::Tensor const& seq_lens,
-    at::Tensor const& page_table) {
+    at::Tensor const& page_table,
+    int64_t num_kv_splits) {
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = q_nope_and_q_pe.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
@@ -154,8 +156,8 @@ typename T::Fmha::Arguments args_from_options(
       // TODO(trevor-m): Change split_kv back to -1 when
       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
       // perform worse with larger context length and smaller batch sizes.
-      1,        // split_kv
-      nullptr,  // is_var_split_kv
+      num_kv_splits,  // split_kv
+      nullptr,        // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
   // split_kv automatically based on batch size and sequence length to balance
@@ -165,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
   return arguments;
 }
 
-template <typename Element>
+template <typename Element, bool IsPaged128, typename PersistenceOption>
 void runMla(
     at::Tensor const& out,
     at::Tensor const& q_nope_and_q_pe,
@@ -173,10 +175,11 @@ void runMla(
     at::Tensor const& seq_lens,
     at::Tensor const& page_table,
     at::Tensor const& workspace,
+    int64_t num_kv_splits,
     cudaStream_t stream) {
-  using MlaSm100Type = MlaSm100<Element>;
+  using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
   typename MlaSm100Type::Fmha fmha;
-  auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
+  auto arguments =
+      args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
 
   CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -185,31 +188,57 @@ void runMla(
   CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
 }
 
+#define DISPATCH_BOOL(expr, const_expr, ...) \
+  [&]() -> bool {                            \
+    if (expr) {                              \
+      constexpr bool const_expr = true;      \
+      return __VA_ARGS__();                  \
+    } else {                                 \
+      constexpr bool const_expr = false;     \
+      return __VA_ARGS__();                  \
+    }                                        \
+  }()
+
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
-    torch::Tensor const& workspace) {
+    torch::Tensor const& workspace,
+    int64_t num_kv_splits) {
   auto in_dtype = q_nope_and_q_pe.dtype();
   at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
-  if (in_dtype == at::ScalarType::Half) {
-    runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
-  } else if (in_dtype == at::ScalarType::BFloat16) {
-    runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
-  } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
-    runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
-  } else {
-    TORCH_CHECK(false, "Unsupported input data type of MLA");
-  }
+  const int page_size = kv_c_and_k_pe_cache.sizes()[1];
+
+  // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
+  // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
+  // Maybe per batch split kv will fix this.
+  DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
+    DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
+      if (in_dtype == at::ScalarType::Half) {
+        runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
+            out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+      } else if (in_dtype == at::ScalarType::BFloat16) {
+        runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
+            out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+      } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
+        runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
+            out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
+      } else {
+        TORCH_CHECK(false, "Unsupported input data type of MLA");
+      }
+      return true;
+    });
+    return true;
+  });
 }
 
-int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
+int64_t cutlass_mla_get_workspace_size(
+    int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
   // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
   // which are float, so Element type here doesn't matter.
-  using MlaSm100Type = MlaSm100<cutlass::half_t>;
+  using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
 
   // Get split kv. Requires problem shape and sm_count only.
   typename MlaSm100Type::Fmha::Arguments arguments;
@@ -220,6 +249,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
   // Assumes device 0 when getting sm_count.
   arguments.hw_info.sm_count =
      sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
+  arguments.split_kv = num_kv_splits;
   MlaSm100Type::Fmha::set_split_kv(arguments);
 
   return MlaSm100Type::Fmha::get_workspace_size(arguments);
sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp (new file, mode 100644)
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
  \brief A universal device layer for cutlass 3.x-style kernels.
*/
// clang-format off
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {

using namespace cute;
using namespace cutlass::fmha::kernel;

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template <class Kernel_>
class MLA {
 public:
  using Kernel = Kernel_;

  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
      typename Kernel::ElementOut,
      typename Kernel::ElementAcc,
      typename Kernel::ElementAcc,
      Kernel::TileShapeH::value,
      Kernel::TileShapeL::value,
      256 /*Max split*/>;

  /// Argument structure: User API
  using KernelArguments = typename Kernel::Arguments;
  using ReductionArguments = typename ReductionKernel::Arguments;

  using Arguments = KernelArguments;

  /// Argument structure: Kernel API
  using KernelParams = typename Kernel::Params;
  using ReductionParams = typename ReductionKernel::Params;
  struct Params {
    KernelParams fmha_params;
    ReductionParams reduction_params;
  };

 private:
  /// Kernel API parameters object
  Params params_;

  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

  static ReductionArguments to_reduction_args(Arguments const& args) {
    auto [H, K, D, B] = args.problem_shape;
    return ReductionArguments{
        nullptr,
        args.epilogue.ptr_o,
        nullptr,
        args.epilogue.ptr_lse,
        args.mainloop.softmax_scale,
        B,
        args.split_kv,
        K,
        args.mainloop.ptr_seq,
        args.ptr_split_kv,
        Kernel::TileShapeS::value};
  }

 public:
  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  static void set_split_kv(KernelArguments& args) {
    if (args.split_kv >= 1) return;
    auto [H, K, D, B] = args.problem_shape;
    int sm_count = args.hw_info.sm_count;
    int max_splits = ceil_div(K, 128);
    int sms_per_batch = max(1, sm_count / B);
    int split_heur = min(max_splits, sms_per_batch);
    int waves = ceil_div(B * split_heur, sm_count);
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const& args) {
    if (!Kernel::can_implement(args)) {
      return Status::kInvalid;
    }
    if (!ReductionKernel::can_implement(to_reduction_args(args))) {
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
    return workspace_bytes;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(device_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError();  // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, device_kernel<Kernel>, Kernel::MaxThreadsPerBlock, smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError();  // to clear the error bit
      CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params{kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed
    // no dynamic smem is needed for reduction kernel
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(device_kernel<Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError();  // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params{fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
  static Status run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(
          cute::size<0>(typename Kernel::ClusterShape{}),
          cute::size<1>(typename Kernel::ClusterShape{}),
          cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result or Status::kSuccess != launch_result) {
      //return Status::kSuccess;
      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // launch reduction kernel
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
      cudaError_t result = cudaGetLastError();
      if (cudaSuccess == result) {
        return Status::kSuccess;
      }
      else {
        CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
        return Status::kErrorInternal;
      }
    }
    else {
      return Status::kSuccess;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////
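The split-kv sizing in MLA::set_split_kv above is easiest to see with concrete numbers. Below is a small Python rendering of the same arithmetic; ceil_div mirrors the device-side helper, and the 148-SM device in the example is just an illustrative figure, not something stated in the diff.

def ceil_div(a, b):
    return (a + b - 1) // b

def set_split_kv(seq_len_k, batch, sm_count, tile=128):
    # Mirrors MLA::set_split_kv: cap the split count by the number of 128-wide
    # K tiles and by how many SMs each batch entry can claim, then round so the
    # splits cover the K tiles in whole waves.
    max_splits = ceil_div(seq_len_k, tile)
    sms_per_batch = max(1, sm_count // batch)
    split_heur = min(max_splits, sms_per_batch)
    k_waves = ceil_div(max_splits, split_heur)
    return ceil_div(max_splits, k_waves)

# seq_len 8192, batch 8, 148 SMs: 64 K tiles, 18 SMs per batch entry,
# initial split 18 -> 4 K-tile waves -> final split_kv = 16
print(set_split_kv(8192, 8, 148))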
sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp (new file, mode 100644)
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {

using namespace cute;

template <
    class ElementOut,
    class ElementAcc,
    class ElementScale,
    size_t kNumHeads,
    size_t kHeadDimLatent,
    int kMaxSplits>
struct Sm100FmhaMlaReductionKernel {
  static const int SharedStorageSize = 0;
  static const int MaxThreadsPerBlock = 128;
  static const int MinBlocksPerMultiprocessor = 1;

  using ArchTag = cutlass::arch::Sm100;

  static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);

  struct Arguments {
    ElementAcc* ptr_oaccum = nullptr;
    ElementOut* ptr_o = nullptr;
    ElementAcc* ptr_lseaccum = nullptr;
    ElementAcc* ptr_lse = nullptr;
    ElementScale scale = 1.f;
    int num_batches = 0;
    int split_kv = -1;
    int dim_k = -1;
    int* ptr_seq = nullptr;
    int* ptr_split_kv = nullptr;
    int tile_shape_s = 128;
  };

  using Params = Arguments;

  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
            args.scale, args.num_batches, args.split_kv, args.dim_k,
            args.ptr_seq, args.ptr_split_kv, args.tile_shape_s};
  }

  static size_t get_workspace_size(Arguments const& /*args*/) {
    return 0;
  }

  static Status initialize_workspace(Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
    return Status::kSuccess;
  }

  static dim3 get_grid_shape(Params const& params) {
    return dim3(kNumHeads, 1, params.num_batches);
  }

  static dim3 get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  static bool can_implement(Arguments const& args) {
    if (args.num_batches <= 0) return false;
    if (args.split_kv <= 0) return false;
    return true;
  }

  CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
    if (params.split_kv <= 1) return;
    auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);

    __shared__ ElementAcc sLseScale[kMaxSplits];
    const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
    const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);

    Tensor gLSEaccum = make_tensor(
        make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
        make_shape(params.split_kv), Stride<Int<kNumHeads>>{});

    Tensor gLSE = make_tensor(
        make_gmem_ptr(params.ptr_lse + offset_lse),
        Shape<_1>{}, Stride<_1>{});

    auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
    auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
    auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
    auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
    local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);

    int warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == 0) {
      constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);

      ElementAcc local_lse[kNLsePerThread];

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + threadIdx.x;
        local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
      }

      ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        lse_max = max(lse_max, local_lse[i]);
      }
      CUTLASS_PRAGMA_UNROLL
      for (int offset = 16; offset >= 1; offset /= 2) {
        lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
      }
      lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max;  // In case all local LSEs are -inf
      lse_max = __shfl_sync(0xffffffff, lse_max, 0);

      ElementAcc sum_lse = 0;
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        sum_lse = sum_lse + expf(local_lse[i] - lse_max);
      }
      CUTLASS_PRAGMA_UNROLL
      for (int offset = 16; offset >= 1; offset /= 2) {
        sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
      }
      sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);

      ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
      if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
        gLSE(0) = global_lse;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + threadIdx.x;
        if (split < local_split_kv) {
          sLseScale[split] = expf(local_lse[i] - global_lse);
        }
      }
    }
    __syncthreads();

    constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
    const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
    Tensor gOaccum = make_tensor(
        make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
        Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
    ElementAcc local_val[Elements] = {0};
    for (int split = 0; split < local_split_kv; ++split) {
      ElementAcc lse_scale = sLseScale[split];
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < Elements; ++i) {
        local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
      }
      gOaccum.data() = gOaccum.data() + kHeadDimLatent;
    }
    auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
    Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Elements; ++i) {
      gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
    }
  }
};

}  // namespace cutlass::fmha::kernel
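Numerically, the reduction kernel above merges per-split partial attention outputs by renormalizing them with the split-wise log-sum-exp values. A NumPy sketch of that merge, with a tiny check against unsplit attention (array shapes and names here are illustrative):

import numpy as np

def merge_splits(o_acc, lse_acc):
    # o_acc:   (num_splits, head_dim) partial outputs, each softmax-normalized
    #          only over its own K/V chunk
    # lse_acc: (num_splits,) log-sum-exp of the attention logits in each chunk
    lse_max = lse_acc.max()
    global_lse = np.log(np.exp(lse_acc - lse_max).sum()) + lse_max
    # rescale each split by exp(lse_i - global_lse) and accumulate, which is
    # what the sLseScale loop in the kernel does
    return np.exp(lse_acc - global_lse) @ o_acc, global_lse

# tiny correctness check against unsplit attention
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 6))     # 2 splits x 6 keys each
values = rng.normal(size=(2, 6, 4))  # matching V rows, head_dim 4

o_acc = np.stack([np.exp(l - l.max()) @ v / np.exp(l - l.max()).sum()
                  for l, v in zip(logits, values)])
lse_acc = np.array([np.log(np.exp(l - l.max()).sum()) + l.max() for l in logits])

merged, _ = merge_splits(o_acc, lse_acc)
w = np.exp(logits.ravel() - logits.max())
assert np.allclose(merged, (w @ values.reshape(-1, 4)) / w.sum())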
sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp (new file, mode 100644)
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "gather_tensor.hpp" // from examples/common
#include "common/pow_2.hpp"
namespace
cutlass
::
fmha
::
kernel
{
using
namespace
cute
;
template
<
class
TileShape
,
class
Element_
,
class
ElementAcc_
,
class
ElementOut_
,
class
ElementLSE_
,
class
TileScheduler
,
#ifdef CPASYNC
bool
kIsCpAsync
=
true
#else
bool
kIsCpAsync
=
false
#endif
>
struct
Sm100FmhaMlaKernelTmaWarpspecialized
{
using
Element
=
Element_
;
using
ElementAcc
=
ElementAcc_
;
using
ElementOut
=
ElementOut_
;
using
ElementLSE
=
ElementLSE_
;
// only 2Sm mode is supported
static
const
bool
kIs2Sm
=
true
;
static
const
int
MaxThreadsPerBlock
=
256
;
static
const
int
MinBlocksPerMultiprocessor
=
1
;
static
const
int
TotalSNum
=
2
;
static
const
int
TotalPNum
=
2
;
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
ClusterShape
=
cute
::
conditional_t
<
kIs2Sm
,
Shape
<
_2
,
_1
,
_1
>
,
Shape
<
_1
,
_1
,
_1
>>
;
using
TileShapeH
=
tuple_element_t
<
0
,
TileShape
>
;
using
TileShapeS
=
tuple_element_t
<
1
,
TileShape
>
;
using
TileShapeD
=
tuple_element_t
<
2
,
TileShape
>
;
using
TileShapeL
=
tuple_element_t
<
0
,
TileShapeD
>
;
using
TileShapeR
=
tuple_element_t
<
1
,
TileShapeD
>
;
static_assert
(
TileShapeL
{}
%
TileShapeR
{}
==
0
,
"Rope head dim must divide latent head dim"
);
using
ProblemShape
=
Shape
<
TileShapeH
,
int
,
TileShapeD
,
int
>
;
using
TensorStride
=
Stride
<
int64_t
,
_1
,
int64_t
>
;
using
TmemAllocator
=
cute
::
conditional_t
<
kIs2Sm
,
cute
::
TMEM
::
Allocator2Sm
,
cute
::
TMEM
::
Allocator1Sm
>
;
static_assert
(
TileShapeH
{}
==
128
);
static
const
int
kWarpsInN
=
kIs2Sm
?
2
:
1
;
static
const
int
kNumComputeWarps
=
4
;
static
const
int
kNumLoadWarps
=
kIsCpAsync
?
2
:
1
;
enum
class
WarpRole
{
kMma
=
0x1
,
kLoad
=
0x2
,
kCompute
=
0x3
,
kLoadPageTable
=
0x4
,
kEmpty
=
0x0
};
static
const
long
long
unsigned
int
kWarpAssignment
=
kIsCpAsync
?
0x4221'3333ull
:
0x0021'3333ull
;
static
CUTLASS_DEVICE
WarpRole
warp_idx_to_role
(
int
warp_idx
)
{
return
static_cast
<
WarpRole
>
((
kWarpAssignment
>>
(
4
*
warp_idx
))
&
0xF
);
}
static
const
int
Alignment
=
128
/
sizeof_bits_v
<
Element
>
;
static
const
int
AlignmentOut
=
128
/
sizeof_bits_v
<
ElementOut
>
;
using
TileShapeQK
=
Shape
<
TileShapeH
,
TileShapeS
,
decltype
(
TileShapeR
{}
/
_1
{})
>
;
static
const
int
StagesQK
=
24
/
sizeof
(
Element
);
// free parameter
static
const
int
IterationsQKLatent
=
decltype
(
TileShapeL
{}
/
get
<
2
>
(
TileShapeQK
{}))
::
value
;
static
const
int
IterationsQKRope
=
decltype
(
TileShapeR
{}
/
get
<
2
>
(
TileShapeQK
{}))
::
value
;
static
const
int
IterationsQK
=
IterationsQKLatent
+
IterationsQKRope
;
using
Schedule
=
cute
::
conditional_t
<
kIs2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecialized2SmSm100
,
cutlass
::
gemm
::
KernelTmaWarpSpecialized1SmSm100
>
;
using
CollectiveMmaQK
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm100
,
cutlass
::
arch
::
OpClassTensorOp
,
Element
,
TensorStride
,
Alignment
,
Element
,
TensorStride
,
Alignment
,
ElementAcc
,
TileShapeQK
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCount
<
StagesQK
>
,
Schedule
>::
CollectiveOp
;
using
TiledMmaQK
=
typename
CollectiveMmaQK
::
TiledMma
;
using
CtaShapeQK
=
typename
CollectiveMmaQK
::
CtaShape_MNK
;
// chosen for unified smem staging between K and V
using
TileShapePV
=
Shape
<
TileShapeH
,
_256
,
_32
>
;
using
TransposeTensorStride
=
decltype
(
select
<
1
,
0
,
2
>
(
TensorStride
{}));
static
const
int
StagesPV
=
StagesQK
;
// not sure why, but must be at least two. check pipes
static
const
int
IterationsPV_K
=
decltype
(
TileShapeS
{}
/
get
<
2
>
(
TileShapePV
{}))
::
value
;
static
const
int
IterationsPV_N
=
decltype
(
TileShapeL
{}
/
get
<
1
>
(
TileShapePV
{}))
::
value
;
using
CollectiveMmaPV
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm100
,
cutlass
::
arch
::
OpClassTensorOp
,
Element
,
TensorStride
,
Alignment
,
Element
,
TransposeTensorStride
,
Alignment
,
ElementAcc
,
TileShapePV
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCount
<
StagesPV
>
,
Schedule
>::
CollectiveOp
;
using
CtaShapePV
=
typename
CollectiveMmaPV
::
CtaShape_MNK
;
static_assert
(
std
::
is_same_v
<
TransposeTensorStride
,
typename
CollectiveMmaPV
::
StrideB
>
);
using
TiledMmaPV
=
typename
CollectiveMmaPV
::
TiledMma
;
using
AtomThrShapeMNK
=
typename
CollectiveMmaQK
::
AtomThrShapeMNK
;
static_assert
(
typename
CollectiveMmaQK
::
AtomThrShapeMNK
{}
==
typename
CollectiveMmaPV
::
AtomThrShapeMNK
{},
"schedule must match"
);
static
const
int
StagesPageTable
=
kIsCpAsync
?
StagesPV
:
1
;
// pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd
// use expect_tx for Q load
using
PipelineLoadQK
=
cute
::
conditional_t
<
kIsCpAsync
,
PipelineUmmaConsumerAsync
<
StagesQK
,
AtomThrShapeMNK
>
,
PipelineTmaUmmaAsync
<
StagesQK
,
ClusterShape
,
AtomThrShapeMNK
>>
;
using
PipelineLoadPV
=
PipelineLoadQK
;
// pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages
using
PipelineS
=
PipelineUmmaAsync
<
TotalSNum
,
AtomThrShapeMNK
>
;
// pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages
using
PipelineP
=
PipelineUmmaConsumerAsync
<
TotalPNum
,
AtomThrShapeMNK
>
;
// pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage
using
PipelineO
=
PipelineUmmaAsync
<
1
,
AtomThrShapeMNK
>
;
using
PipelinePT
=
PipelineAsync
<
StagesPageTable
>
;
struct
PipelineStorage
{
alignas
(
16
)
typename
PipelineLoadQK
::
SharedStorage
load_qk
;
alignas
(
16
)
typename
PipelineS
::
SharedStorage
mma_s
;
alignas
(
16
)
typename
PipelineP
::
SharedStorage
p_mma
;
alignas
(
16
)
typename
PipelineO
::
SharedStorage
mma_o
;
alignas
(
16
)
typename
PipelinePT
::
SharedStorage
load_page_table
;
};
template
<
class
Layout
,
class
Stages
=
_1
>
static
CUTE_DEVICE
constexpr
auto
unstageSmemLayout
(
Layout
const
&
layout
,
Stages
stages
=
{})
{
return
composition
(
layout
,
make_tuple
(
_
,
_
,
_
,
make_layout
(
stages
)));
}
using
SmemLayoutQ
=
decltype
(
unstageSmemLayout
(
typename
CollectiveMmaQK
::
SmemLayoutA
{},
Int
<
IterationsQK
>
{}));
using
SmemLayoutKC
=
typename
CollectiveMmaQK
::
SmemLayoutB
;
using
SmemLayoutVC
=
typename
CollectiveMmaPV
::
SmemLayoutB
;
using
SmemLayoutP
=
decltype
(
unstageSmemLayout
(
typename
CollectiveMmaPV
::
SmemLayoutA
{},
make_shape
(
Int
<
IterationsPV_K
>
{},
_2
{})));
static
const
int
kBytesLoadQ
=
size
(
AtomThrShapeMNK
{})
*
cutlass
::
bits_to_bytes
(
cosize
(
take
<
0
,
3
>
(
SmemLayoutQ
{}))
*
cute
::
sizeof_bits_v
<
Element
>
);
static
const
int
kBytesLoadKC
=
size
(
AtomThrShapeMNK
{})
*
cutlass
::
bits_to_bytes
(
cosize
(
take
<
0
,
3
>
(
SmemLayoutKC
{}))
*
cute
::
sizeof_bits_v
<
Element
>
);
static
const
int
kBytesLoadVC
=
size
(
AtomThrShapeMNK
{})
*
cutlass
::
bits_to_bytes
(
cosize
(
take
<
0
,
3
>
(
SmemLayoutVC
{}))
*
cute
::
sizeof_bits_v
<
Element
>
);
// pre-condition for overlapped smem staging
static_assert
(
kBytesLoadKC
==
kBytesLoadVC
);
static_assert
(
StagesQK
==
StagesPV
);
static
const
int
kTransactionsBytesLoadQK
=
kBytesLoadKC
;
static
const
int
kTransactionsBytesLoadExtraQ
=
kBytesLoadQ
;
static
const
int
kTransactionsBytesLoadPV
=
kBytesLoadVC
;
static
const
int
kNamedBarrierExchange
=
(
int
)
cutlass
::
arch
::
ReservedNamedBarriers
::
TransformBarrier
;
// This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent
// tile scheduler for FP8 MLA.
static
const
int
kNamedBarrierEpilogue
=
(
int
)
cutlass
::
arch
::
ReservedNamedBarriers
::
EpilogueBarrier
;
//
static
const
int
kNamedBarrierTmemDealloc
=
(
int
)
cutlass
::
arch
::
ReservedNamedBarriers
::
TmemAllocBarrier
;
enum
class
TmemAllocation
:
uint32_t
{
kSizeS
=
TileShapeS
::
value
/
kWarpsInN
,
// Overall
kSizeO
=
TileShapeL
::
value
/
kWarpsInN
,
// Between accumulators we loop over
kSizeAccO
=
decltype
(
get
<
1
>
(
TileShapePV
{}))
::
value
/
kWarpsInN
,
kNumS
=
TotalSNum
,
kNumP
=
TotalPNum
,
kNumO
=
1
,
kS0
=
0
,
kS1
=
kS0
+
kSizeS
,
kO0
=
kS1
+
kSizeS
,
kTotal
=
kO0
+
kSizeO
};
static_assert
(
static_cast
<
int
>
(
TmemAllocation
::
kTotal
)
<=
TmemAllocator
::
Sm100TmemCapacityColumns
,
"using too much tmem"
);
struct
TensorStorage
{
// to communicate max and row_sum
cute
::
array
<
ElementAcc
,
kNumComputeWarps
*
cutlass
::
NumThreadsPerWarp
>
smem_exchange
;
cute
::
array
<
int
,
StagesPageTable
*
TileShapeS
::
value
>
smem_page_table
;
alignas
(
2048
)
cute
::
array
<
Element
,
cute
::
cosize_v
<
SmemLayoutQ
>>
smem_q
;
union
{
alignas
(
2048
)
cute
::
array
<
Element
,
cute
::
cosize_v
<
SmemLayoutKC
>>
smem_kc
;
alignas
(
2048
)
cute
::
array
<
Element
,
cute
::
cosize_v
<
SmemLayoutVC
>>
smem_vc
;
};
alignas
(
2048
)
cute
::
array
<
Element
,
cute
::
cosize_v
<
SmemLayoutP
>>
smem_p
;
};
struct
SharedStorage
{
PipelineStorage
pipelines
;
TensorStorage
tensors
;
uint32_t
tmem_base_ptr
;
};
static
const
int
SharedStorageSize
=
sizeof
(
SharedStorage
);
static_assert
(
SharedStorageSize
<=
cutlass
::
arch
::
sm100_smem_capacity_bytes
,
"using too much smem"
);
struct
MainloopArguments
{
ElementAcc
softmax_scale
;
// all tensors strides are (num_heads or seqlen, head_dim, batch)
// head_dim stride is always 1
Element
*
ptr_q_latent
;
TensorStride
stride_q_latent
;
Element
*
ptr_q_rope
;
TensorStride
stride_q_rope
;
Element
*
ptr_c_latent
;
TensorStride
stride_c_latent
;
Element
*
ptr_k_rope
;
TensorStride
stride_k_rope
;
// for paged attention, we interpret what was previously [batch, seqlen]
// as [page_count, page_size], and index according to page_table
int
*
ptr_seq
=
nullptr
;
int
*
ptr_page_table
=
nullptr
;
// page table is [batch, seqlen or similar]
Stride
<
_1
,
int
>
stride_page_table
=
{};
int
page_count
=
0
;
int
page_size
=
TileShapeS
{};
// powers of two if kIsCpAsync, otherwise TileShapeS
};
struct
EpilogueArguments
{
ElementOut
*
ptr_o
=
nullptr
;
TensorStride
stride_o
;
ElementLSE
*
ptr_lse
=
nullptr
;
Stride
<
_1
,
int
>
stride_lse
;
ElementAcc
output_scale
=
1.0
f
;
};
struct
Arguments
{
// (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count)
// for paged attention, seqlen is max seqlen
ProblemShape
problem_shape
;
MainloopArguments
mainloop
;
EpilogueArguments
epilogue
;
KernelHardwareInfo
hw_info
;
int
split_kv
=
-
1
;
int
*
ptr_split_kv
=
nullptr
;
};
using
TmaLoadQLatent
=
typename
CollectiveMmaQK
::
Params
::
TMA_A
;
using
TmaLoadQRope
=
typename
CollectiveMmaQK
::
Params
::
TMA_A
;
using
TmaLoadCLatent
=
typename
CollectiveMmaQK
::
Params
::
TMA_B
;
using
TmaLoadKRope
=
typename
CollectiveMmaQK
::
Params
::
TMA_B
;
using
TmaLoadCLatentTranspose
=
typename
CollectiveMmaPV
::
Params
::
TMA_B
;
struct
MainloopParams
{
TmaLoadQLatent
tma_load_q_latent
;
TmaLoadQRope
tma_load_q_rope
;
TmaLoadCLatent
tma_load_c_latent
;
TmaLoadKRope
tma_load_k_rope
;
TmaLoadCLatentTranspose
tma_load_c_latent_transpose
;
};
struct
EpilogueParams
{
ElementOut
*
ptr_o
=
nullptr
;
ElementAcc
*
ptr_o_acc
=
nullptr
;
TensorStride
stride_o
;
TensorStride
stride_o_acc
;
ElementLSE
*
ptr_lse
=
nullptr
;
ElementLSE
*
ptr_lse_acc
=
nullptr
;
Stride
<
_1
,
int
>
stride_lse
;
Stride
<
_1
,
int
>
stride_lse_acc
;
ElementAcc
output_scale
=
1.0
f
;
};
struct
Params
{
ProblemShape
problem_shape
;
MainloopArguments
mainloop
;
EpilogueParams
epilogue
;
MainloopParams
mainloop_params
;
typename
TileScheduler
::
Params
tile_scheduler
;
int
split_kv
=
-
1
;
int
*
ptr_split_kv
=
nullptr
;
};
static
Params
to_underlying_arguments
(
Arguments
const
&
args
,
void
*
workspace
)
{
//workspace = nullptr; // let's get an error if one of these needs workspace
auto
[
H
,
K
,
D
,
B
]
=
args
.
problem_shape
;
auto
[
L
,
R
]
=
D
;
int
paged_B
=
B
;
int
paged_K
=
K
;
if
(
args
.
mainloop
.
ptr_page_table
!=
nullptr
)
{
paged_B
=
args
.
mainloop
.
page_count
;
paged_K
=
args
.
mainloop
.
page_size
;
}
auto
params_qk_latent
=
CollectiveMmaQK
::
to_underlying_arguments
(
make_shape
(
H
,
K
,
L
,
B
),
typename
CollectiveMmaQK
::
Arguments
{
args
.
mainloop
.
ptr_q_latent
,
args
.
mainloop
.
stride_q_latent
,
args
.
mainloop
.
ptr_c_latent
,
args
.
mainloop
.
stride_c_latent
,
},
nullptr
);
auto
params_qk_latent_paged
=
CollectiveMmaQK
::
to_underlying_arguments
(
make_shape
(
H
,
paged_K
,
L
,
paged_B
),
typename
CollectiveMmaQK
::
Arguments
{
args
.
mainloop
.
ptr_q_latent
,
args
.
mainloop
.
stride_q_latent
,
args
.
mainloop
.
ptr_c_latent
,
args
.
mainloop
.
stride_c_latent
,
},
nullptr
);
auto
params_qk_rope
=
CollectiveMmaQK
::
to_underlying_arguments
(
make_shape
(
H
,
K
,
R
,
B
),
typename
CollectiveMmaQK
::
Arguments
{
args
.
mainloop
.
ptr_q_rope
,
args
.
mainloop
.
stride_q_rope
,
args
.
mainloop
.
ptr_k_rope
,
args
.
mainloop
.
stride_k_rope
,
},
nullptr
);
auto
params_qk_rope_paged
=
CollectiveMmaQK
::
to_underlying_arguments
(
make_shape
(
H
,
paged_K
,
R
,
paged_B
),
typename
CollectiveMmaQK
::
Arguments
{
args
.
mainloop
.
ptr_q_rope
,
args
.
mainloop
.
stride_q_rope
,
args
.
mainloop
.
ptr_k_rope
,
args
.
mainloop
.
stride_k_rope
,
},
nullptr
);
auto
stride_c_latent_transpose
=
select
<
1
,
0
,
2
>
(
args
.
mainloop
.
stride_c_latent
);
auto
params_pv_latent
=
CollectiveMmaPV
::
to_underlying_arguments
(
make_shape
(
H
,
L
,
paged_K
,
paged_B
),
typename
CollectiveMmaPV
::
Arguments
{
args
.
mainloop
.
ptr_q_latent
,
args
.
mainloop
.
stride_q_latent
,
// dummy, never used
args
.
mainloop
.
ptr_c_latent
,
stride_c_latent_transpose
,
},
nullptr
);
MainloopParams
mainloop_params
{
params_qk_latent
.
tma_load_a
,
params_qk_rope
.
tma_load_a
,
params_qk_latent_paged
.
tma_load_b
,
params_qk_rope_paged
.
tma_load_b
,
params_pv_latent
.
tma_load_b
};
EpilogueParams
epilogue_params
;
epilogue_params
.
ptr_o
=
args
.
epilogue
.
ptr_o
;
epilogue_params
.
stride_o
=
args
.
epilogue
.
stride_o
;
epilogue_params
.
ptr_lse
=
args
.
epilogue
.
ptr_lse
;
epilogue_params
.
stride_lse
=
args
.
epilogue
.
stride_lse
;
epilogue_params
.
output_scale
=
args
.
epilogue
.
output_scale
;
if
(
args
.
split_kv
>
1
)
{
ElementAcc
*
ptr_o_acc
=
reinterpret_cast
<
ElementAcc
*>
(
workspace
);
ElementLSE
*
ptr_lse_acc
=
reinterpret_cast
<
ElementLSE
*>
(
ptr_o_acc
+
H
*
L
*
args
.
split_kv
*
B
);
epilogue_params
.
ptr_o_acc
=
ptr_o_acc
;
epilogue_params
.
ptr_lse_acc
=
ptr_lse_acc
;
epilogue_params
.
stride_o_acc
=
make_tuple
(
static_cast
<
int64_t
>
(
0
+
L
)
*
args
.
split_kv
,
_1
{},
static_cast
<
int64_t
>
(
0
+
H
*
L
)
*
args
.
split_kv
);
epilogue_params
.
stride_lse_acc
=
make_tuple
(
_1
{},
(
0
+
H
)
*
args
.
split_kv
);
}
return
{
args
.
problem_shape
,
args
.
mainloop
,
epilogue_params
,
mainloop_params
,
TileScheduler
::
to_underlying_arguments
(
args
.
problem_shape
,
args
.
hw_info
,
ClusterShape
{},
args
.
split_kv
),
args
.
split_kv
,
args
.
ptr_split_kv
};
}
static
size_t
get_workspace_size
(
Arguments
const
&
args
)
{
ProblemShape
problem_shape
=
args
.
problem_shape
;
auto
[
H
,
K
,
D
,
B
]
=
problem_shape
;
auto
[
D_latent
,
D_rope
]
=
D
;
auto
split_kv
=
args
.
split_kv
;
return
(
sizeof
(
ElementAcc
)
*
D_latent
+
sizeof
(
ElementLSE
))
*
H
*
split_kv
*
B
;
}
static
Status
initialize_workspace
(
Arguments
const
&
/*args*/
,
void
*
/*ws*/
,
cudaStream_t
/*stream*/
)
{
return
Status
::
kSuccess
;
}
static
dim3
get_grid_shape
(
Params
const
&
params
)
{
return
TileScheduler
::
get_grid_shape
(
params
.
tile_scheduler
);
}
static
dim3
get_block_shape
()
{
dim3
block
(
MaxThreadsPerBlock
,
1
,
1
);
return
block
;
}
static
bool
can_implement
(
Arguments
const
&
args
)
{
if
(
kIsCpAsync
)
{
if
((
args
.
mainloop
.
page_size
&
(
args
.
mainloop
.
page_size
-
1
))
!=
0
)
{
return
false
;
}
if
(
args
.
mainloop
.
page_size
>
TileShapeS
{})
{
return
false
;
}
}
else
{
if
(
args
.
mainloop
.
ptr_page_table
!=
nullptr
&&
args
.
mainloop
.
page_size
!=
TileShapeS
{})
{
return
false
;
}
}
if
(
get
<
0
>
(
args
.
problem_shape
)
!=
128
)
{
return
false
;
}
if
(
get
<
1
>
(
args
.
problem_shape
)
<=
0
)
{
return
false
;
}
if
(
args
.
split_kv
<=
0
)
{
return
false
;
}
return
true
;
}
CUTLASS_DEVICE
void
operator
()(
Params
const
&
params
,
char
*
smem_raw
)
{
TileScheduler
tile_scheduler
(
params
.
tile_scheduler
);
int
warp_idx
=
cutlass
::
canonical_warp_idx_sync
();
auto
role
=
warp_idx_to_role
(
warp_idx
);
uint32_t
lane_predicate
=
cute
::
elect_one_sync
();
uint32_t
cta_rank_in_cluster
=
cute
::
block_rank_in_cluster
();
int
cta_coord_v
=
cta_rank_in_cluster
%
size
<
0
>
(
AtomThrShapeMNK
{});
bool
is_mma_leader_cta
=
cta_coord_v
==
0
;
if
(
role
==
WarpRole
::
kLoad
&&
lane_predicate
&&
!
kIsCpAsync
)
{
prefetch_tma_descriptor
(
params
.
mainloop_params
.
tma_load_q_latent
.
get_tma_descriptor
());
prefetch_tma_descriptor
(
params
.
mainloop_params
.
tma_load_c_latent
.
get_tma_descriptor
());
prefetch_tma_descriptor
(
params
.
mainloop_params
.
tma_load_q_rope
.
get_tma_descriptor
());
prefetch_tma_descriptor
(
params
.
mainloop_params
.
tma_load_k_rope
.
get_tma_descriptor
());
prefetch_tma_descriptor
(
params
.
mainloop_params
.
tma_load_c_latent_transpose
.
get_tma_descriptor
());
}
SharedStorage
&
shared_storage
=
*
reinterpret_cast
<
SharedStorage
*>
(
smem_raw
);
typename
PipelineLoadQK
::
Params
pipeline_load_qk_params
;
if
(
role
==
WarpRole
::
kLoad
)
{
pipeline_load_qk_params
.
role
=
PipelineLoadQK
::
ThreadCategory
::
Producer
;
}
if
(
role
==
WarpRole
::
kMma
)
{
pipeline_load_qk_params
.
role
=
PipelineLoadQK
::
ThreadCategory
::
Consumer
;
}
if
constexpr
(
kIsCpAsync
)
{
// we can make our life easier by unconditionally loading blocks
// since we know it'll always be legal
pipeline_load_qk_params
.
producer_arv_count
=
kNumLoadWarps
*
cutlass
::
NumThreadsPerWarp
*
size
(
AtomThrShapeMNK
{});
}
else
{
pipeline_load_qk_params
.
is_leader
=
lane_predicate
&&
(
role
==
WarpRole
::
kLoad
)
&&
is_mma_leader_cta
;
pipeline_load_qk_params
.
transaction_bytes
=
kTransactionsBytesLoadQK
;
}
pipeline_load_qk_params
.
initializing_warp
=
0
;
PipelineLoadQK
pipeline_load_qk
(
shared_storage
.
pipelines
.
load_qk
,
pipeline_load_qk_params
,
ClusterShape
{},
/*barrier init*/
cute
::
true_type
{},
/*mask calc*/
cute
::
false_type
{});
typename
PipelineS
::
Params
pipeline_mma_s_params
;
if
(
role
==
WarpRole
::
kMma
)
{
pipeline_mma_s_params
.
role
=
PipelineS
::
ThreadCategory
::
Producer
;
}
if
(
role
==
WarpRole
::
kCompute
)
{
pipeline_mma_s_params
.
role
=
PipelineS
::
ThreadCategory
::
Consumer
;
}
pipeline_mma_s_params
.
consumer_arv_count
=
kNumComputeWarps
*
cutlass
::
NumThreadsPerWarp
*
size
(
AtomThrShapeMNK
{});
pipeline_mma_s_params
.
initializing_warp
=
1
;
PipelineS
pipeline_mma_s
(
shared_storage
.
pipelines
.
mma_s
,
pipeline_mma_s_params
,
ClusterShape
{},
/*barrier init*/
cute
::
true_type
{},
/*mask calc*/
cute
::
false_type
{});
typename
PipelineP
::
Params
pipeline_p_mma_params
;
if
(
role
==
WarpRole
::
kMma
)
{
pipeline_p_mma_params
.
role
=
PipelineP
::
ThreadCategory
::
Consumer
;
}
if
(
role
==
WarpRole
::
kCompute
)
{
pipeline_p_mma_params
.
role
=
PipelineP
::
ThreadCategory
::
Producer
;
}
pipeline_p_mma_params
.
producer_arv_count
=
kNumComputeWarps
*
cutlass
::
NumThreadsPerWarp
*
size
(
AtomThrShapeMNK
{});
pipeline_p_mma_params
.
consumer_arv_count
=
1
;
pipeline_p_mma_params
.
initializing_warp
=
2
;
PipelineP
pipeline_p_mma
(
shared_storage
.
pipelines
.
p_mma
,
pipeline_p_mma_params
,
ClusterShape
{},
/*barrier init*/
cute
::
true_type
{},
/*mask calc*/
cute
::
false_type
{});
typename
PipelineO
::
Params
pipeline_mma_o_params
;
if
(
role
==
WarpRole
::
kMma
)
{
pipeline_mma_o_params
.
role
=
PipelineO
::
ThreadCategory
::
Producer
;
}
if
(
role
==
WarpRole
::
kCompute
)
{
pipeline_mma_o_params
.
role
=
PipelineO
::
ThreadCategory
::
Consumer
;
}
pipeline_mma_o_params
.
consumer_arv_count
=
kNumComputeWarps
*
cutlass
::
NumThreadsPerWarp
*
size
(
AtomThrShapeMNK
{});
pipeline_mma_o_params
.
initializing_warp
=
3
;
PipelineO
pipeline_mma_o
(
shared_storage
.
pipelines
.
mma_o
,
pipeline_mma_o_params
,
ClusterShape
{},
/*barrier init*/
cute
::
true_type
{},
/*mask calc*/
cute
::
false_type
{});
typename
PipelinePT
::
Params
pipeline_pt_params
;
if
(
role
==
WarpRole
::
kLoad
)
{
pipeline_pt_params
.
role
=
PipelinePT
::
ThreadCategory
::
Consumer
;
}
if
(
role
==
WarpRole
::
kLoadPageTable
)
{
pipeline_pt_params
.
role
=
PipelinePT
::
ThreadCategory
::
Producer
;
}
pipeline_pt_params
.
consumer_arv_count
=
kNumLoadWarps
*
cutlass
::
NumThreadsPerWarp
;
pipeline_pt_params
.
producer_arv_count
=
cutlass
::
NumThreadsPerWarp
;
pipeline_pt_params
.
initializing_warp
=
4
;
PipelinePT
pipeline_page_table
(
shared_storage
.
pipelines
.
load_page_table
,
pipeline_pt_params
);
TmemAllocator
tmem_allocator
;
pipeline_init_arrive_relaxed
(
size
(
ClusterShape
{}));
pipeline_load_qk
.
init_masks
(
ClusterShape
{});
// do we need an update here for 2Sm?
pipeline_mma_s
.
init_masks
(
ClusterShape
{});
pipeline_p_mma
.
init_masks
(
ClusterShape
{});
pipeline_mma_o
.
init_masks
(
ClusterShape
{});
typename
PipelineLoadQK
::
PipelineState
pipeline_load_qk_consumer_state
;
typename
PipelineLoadQK
::
PipelineState
pipeline_load_qk_producer_state
=
cutlass
::
make_producer_start_state
<
PipelineLoadQK
>
();
typename
PipelineS
::
PipelineState
pipeline_mma_s_consumer_state
;
typename
PipelineS
::
PipelineState
pipeline_mma_s_producer_state
=
cutlass
::
make_producer_start_state
<
PipelineS
>
();
typename
PipelineP
::
PipelineState
pipeline_p_mma_consumer_state
;
typename
PipelineP
::
PipelineState
pipeline_p_mma_producer_state
=
cutlass
::
make_producer_start_state
<
PipelineP
>
();
typename
PipelineO
::
PipelineState
pipeline_mma_o_consumer_state
;
typename
PipelineO
::
PipelineState
pipeline_mma_o_producer_state
=
cutlass
::
make_producer_start_state
<
PipelineO
>
();
typename
PipelinePT
::
PipelineState
pipeline_pt_consumer_state
;
typename
PipelinePT
::
PipelineState
pipeline_pt_producer_state
=
cutlass
::
make_producer_start_state
<
PipelinePT
>
();
pipeline_init_wait
(
size
(
ClusterShape
{}));
  if (role == WarpRole::kLoadPageTable) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (; tile_scheduler.is_valid(); ++tile_scheduler) {
      auto blk_coord = tile_scheduler.get_block_coord();
      auto problem_shape = params.problem_shape;
      auto local_split_kv = params.split_kv;
      if (params.mainloop.ptr_seq != nullptr) {
        get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
        if (params.ptr_split_kv != nullptr) {
          local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
        }
      }
      if (local_split_kv <= get<3>(blk_coord)) continue;
      load_page_table(
          blk_coord, problem_shape, params.mainloop, shared_storage.tensors,
          pipeline_page_table, pipeline_pt_producer_state, local_split_kv);
    }
  }
  else if (role == WarpRole::kLoad) {
    if constexpr (kIsCpAsync) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto problem_shape = params.problem_shape;
        auto local_split_kv = params.split_kv;
        if (params.mainloop.ptr_seq != nullptr) {
          get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
          if (params.ptr_split_kv != nullptr) {
            local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
          }
        }
        if (local_split_kv <= get<3>(blk_coord)) continue;
        load_cpasync(
            blk_coord, problem_shape, params.mainloop, params.mainloop_params,
            shared_storage.tensors,
            pipeline_load_qk, pipeline_load_qk_producer_state,
            local_split_kv,
            /* must be shared pipe */
            pipeline_page_table, pipeline_pt_consumer_state);
        cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
      }
    }
    else {
      if (params.mainloop.ptr_page_table != nullptr) {
        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          auto problem_shape = params.problem_shape;
          auto local_split_kv = params.split_kv;
          if (params.mainloop.ptr_seq != nullptr) {
            get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
            if (params.ptr_split_kv != nullptr) {
              local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
            }
          }
          if (local_split_kv <= get<3>(blk_coord)) continue;
          load_tma</* paged= */ true>(
              blk_coord, problem_shape, params.mainloop, params.mainloop_params,
              shared_storage.tensors,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              local_split_kv);
          cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
        }
      }
      else {
        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          auto problem_shape = params.problem_shape;
          auto local_split_kv = params.split_kv;
          if (params.mainloop.ptr_seq != nullptr) {
            get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
            if (params.ptr_split_kv != nullptr) {
              local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
            }
          }
          if (local_split_kv <= get<3>(blk_coord)) continue;
          load_tma<false>(
              blk_coord, problem_shape, params.mainloop, params.mainloop_params,
              shared_storage.tensors,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              local_split_kv);
          cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
        }
      }
    }
  }
  else if (role == WarpRole::kMma) {
    tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
    __syncwarp();
    if (is_mma_leader_cta) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto problem_shape = params.problem_shape;
        auto local_split_kv = params.split_kv;
        if (params.mainloop.ptr_seq != nullptr) {
          get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
          if (params.ptr_split_kv != nullptr) {
            local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
          }
        }
        if (local_split_kv <= get<3>(blk_coord)) continue;
        mma(blk_coord, problem_shape, shared_storage.tensors,
            pipeline_load_qk, pipeline_load_qk_consumer_state,
            pipeline_load_qk, pipeline_load_qk_consumer_state,
            pipeline_mma_s, pipeline_mma_s_producer_state,
            pipeline_p_mma, pipeline_p_mma_consumer_state,
            pipeline_mma_o, pipeline_mma_o_producer_state,
            local_split_kv);
      }
    }
    //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait();
    //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
    //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
  }
  else if (role == WarpRole::kCompute) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (; tile_scheduler.is_valid(); ++tile_scheduler) {
      auto blk_coord = tile_scheduler.get_block_coord();
      auto problem_shape = params.problem_shape;
      auto split_kv = params.split_kv;
      auto local_split_kv = split_kv;
      if (params.mainloop.ptr_seq != nullptr) {
        get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
        if (params.ptr_split_kv != nullptr) {
          local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
        }
      }
      if (local_split_kv <= get<3>(blk_coord)) continue;
      compute(blk_coord, problem_shape,
          params.mainloop,  // for softmax_scale
          params.epilogue,
          shared_storage.tensors,  // for smem_comm
          pipeline_mma_s, pipeline_mma_s_consumer_state,
          pipeline_p_mma, pipeline_p_mma_producer_state,
          pipeline_mma_o, pipeline_mma_o_consumer_state,
          local_split_kv);
    }
    //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive();
  }

  cute::cluster_sync();
  cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive();
  if (role == WarpRole::kMma) {
    uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
    tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
  }
}
template<class BlkCoord>
CUTLASS_DEVICE void load_page_table(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    TensorStorage& shared_tensors,
    PipelinePT& pipeline_page_table,
    typename PipelinePT::PipelineState& pipeline_pt_producer_state,
    int const& split_kv) {

  auto [H, K, D, B] = problem_shape;

  int batch_coord = get<2>(blk_coord);
  auto mPT_l = make_tensor(
      make_gmem_ptr(mainloop_args.ptr_page_table),
      make_shape(mainloop_args.page_count, B),
      mainloop_args.stride_page_table);
  auto mPT = mPT_l(_, batch_coord);

  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(blk_coord) * k_tile_per_cta;  // lower limit
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    return;
  }

  auto page_size = Pow2{mainloop_args.page_size};
  auto pages_per_tile = Pow2{TileShapeS{} / page_size};
  int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp;

#if 1
  for (; k_tile_count > 0; ++k_index, --k_tile_count) {
    pipeline_page_table.producer_acquire(pipeline_pt_producer_state);

    // assume a single warp
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) {
      int idx = i + thread_idx;
      bool guard = idx < pages_per_tile;
      int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx;
      int pt_idx = pages_per_tile * k_index + idx;
      cutlass::arch::cp_async_zfill<sizeof(int), cutlass::arch::CacheOperation::Always>(
          &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard);
    }

    pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive);
    ++pipeline_pt_producer_state;
  }
#endif
}
struct Gather {
  int& page_table_stage;
  Pow2 pages_per_tile;
  const int* __restrict__ smem_page_table;

  CUTLASS_DEVICE int operator()(int idx) const {
    return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile];
  }

  CUTLASS_DEVICE friend void print(Gather const&) {
    printf("<gather>");
  }
};
template<class BlkCoord>
CUTLASS_DEVICE void load_cpasync(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    MainloopParams const& mainloop_params,
    TensorStorage& shared_tensors,
    PipelineLoadQK& pipeline_load,
    typename PipelineLoadQK::PipelineState& pipeline_load_producer_state,
    int const& split_kv,
    PipelinePT& pipeline_page_table,
    typename PipelinePT::PipelineState& pipeline_pt_consumer_state) {

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  using X = Underscore;

  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(blk_coord) * k_tile_per_cta;  // lower limit
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    return;
  }

  // partition all tensors
  auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent);
  auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope);

  int paged_B = mainloop_args.page_count;
  auto paged_K = Pow2{mainloop_args.page_size};
  auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table);

  int batch_coord = get<2>(blk_coord);
  auto mPT = mPT_l(_, batch_coord);

  auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
  auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});

  ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
  ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));

  auto tSgQL = cta_mma_qk.partition_A(gQL);
  auto tSgQR = cta_mma_qk.partition_A(gQR);

  Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
  Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
  Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});

  auto make_copy_for = [](auto sT) {
    auto rT_a = sT.layout()(_, _, _, _0{});
    auto rT = make_ordered_layout(shape(rT_a), stride(rT_a));
    auto threads = Int<kNumLoadWarps * cutlass::NumThreadsPerWarp>{};
    auto values = Int<sizeof(uint128_t) / sizeof(Element)>{};
    return make_cotiled_copy(
        Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
        make_ordered_layout(make_shape(threads, values), make_stride(_1{}, _0{})),
        rT);
  };

  // like cute::copy, but makes sure we do all page table lookups first
  auto copy_split = [](auto atom, auto src, auto dst) {
    auto src_v = group_modes<1, rank_v<decltype(src)>>(src);
    auto dst_v = group_modes<1, rank_v<decltype(dst)>>(dst);
    auto src_v_ptrs = make_tensor<Element*>(size<1>(src_v));
    for (int i = 0; i < size<1>(src_v); i++) {
      src_v_ptrs(i) = &src_v(_0{}, i);
    }
    for (int i = 0; i < size<1>(src_v); i++) {
      auto src_v_i = make_tensor(
          make_gmem_ptr(src_v_ptrs(i)),
          make_shape(shape<0>(src_v)),
          make_stride(make_stride(_1{}, _0{})));
      atom.call(src_v_i, dst_v(_, i));
    }
  };

  auto tiled_copy_q = make_copy_for(sQ);
  auto tiled_copy_kc = make_copy_for(sKC);
  auto tiled_copy_vc = make_copy_for(sVC);

  auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp));
  auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp));
  auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp));

  auto tQsQ = thr_copy_q.partition_D(sQ);
  auto tQgQL = thr_copy_q.partition_S(tSgQL);
  auto tQgQR = thr_copy_q.partition_S(tSgQR);

  auto tKCsKC = thr_copy_kc.partition_D(sKC);
  auto tVCsVC = thr_copy_vc.partition_D(sVC);

  auto pipeline_pt_release_state = pipeline_pt_consumer_state;

  int page_table_stage = -1;
  Pow2 pages_per_tile{TileShapeS{} / paged_K};
  const int* __restrict__ smem_page_table = shared_tensors.smem_page_table.begin();
  Gather gather{page_table_stage, pages_per_tile, smem_page_table};

  auto mCL = make_tensor(
      make_gmem_ptr(mainloop_args.ptr_c_latent),
      ComposedLayout{
          make_layout(
              make_shape(make_shape(paged_K, paged_B), _1{}),
              make_stride(
                  make_stride(get<0>(mainloop_args.stride_c_latent),
                              example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))),
                  get<1>(mainloop_args.stride_c_latent))),
          make_coord(_0{}, _0{}),
          make_identity_layout(make_shape(paged_K * paged_B, D_latent))});

  auto mKR = make_tensor(
      make_gmem_ptr(mainloop_args.ptr_k_rope),
      ComposedLayout{
          make_layout(
              make_shape(make_shape(paged_K, paged_B), _1{}),
              make_stride(
                  make_stride(get<0>(mainloop_args.stride_k_rope),
                              example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))),
                  get<1>(mainloop_args.stride_k_rope))),
          make_coord(_0{}, _0{}),
          make_identity_layout(make_shape(paged_K * paged_B, D_latent))});

  auto mCLT = make_tensor(
      make_gmem_ptr(mainloop_args.ptr_c_latent),
      ComposedLayout{
          make_layout(
              make_shape(_1{}, make_shape(paged_K, paged_B)),
              make_stride(
                  get<1>(mainloop_args.stride_c_latent),
                  make_stride(get<0>(mainloop_args.stride_c_latent),
                              example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))),
          make_coord(_0{}, _0{}),
          make_identity_layout(make_shape(D_latent, paged_K * paged_B))});

  auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
  auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
  auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});

  auto tSgCL = cta_mma_qk.partition_B(gCL);
  auto tSgKR = cta_mma_qk.partition_B(gKR);
  auto tOgCLT = cta_mma_pv.partition_B(gCLT);

  auto tKCgCL = thr_copy_kc.partition_S(tSgCL);
  auto tKCgKR = thr_copy_kc.partition_S(tSgKR);
  auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT);

  // latent is first in memory, so let's load it first always
  // startup: alternate Q and K, set tx count appropriately, for k_idx = 0

  auto& pipeline_acquire_state = pipeline_load_producer_state;
  auto pipeline_commit_state = pipeline_acquire_state;
  int pipeline_offset = 0;

  for (int i = 0; i < StagesPV; i++) {
    cutlass::arch::cp_async_fence();
  }

  auto load_stage = [&](auto fn) {
    pipeline_load.producer_acquire(pipeline_acquire_state);
    fn(pipeline_acquire_state.index());
    cutlass::arch::cp_async_fence();
    ++pipeline_acquire_state;
    ++pipeline_offset;

    if (pipeline_offset == StagesPV - 1) {
      cutlass::arch::cp_async_wait<StagesPV - 1>();
      pipeline_load.producer_commit(pipeline_commit_state);
      ++pipeline_commit_state;
      --pipeline_offset;
    }
  };

  pipeline_page_table.consumer_wait(pipeline_pt_consumer_state);
  page_table_stage = pipeline_pt_consumer_state.index();
  ++pipeline_pt_consumer_state;

  // each Q/K tile consists of rope and latent
  for (int i = 0; i < IterationsQKLatent; i++) {
    load_stage([&](int index) {
      cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i));
      copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
    });
  }

  for (int i = 0; i < IterationsQKRope; i++) {
    load_stage([&](int index) {
      cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i));
      copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
    });
  }

  k_index += 1;
  k_tile_count -= 1;

  // assume k_tile_count >= 1
  // perform K+Q load here
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    pipeline_page_table.consumer_wait(pipeline_pt_consumer_state);
    page_table_stage = pipeline_pt_consumer_state.index();
    ++pipeline_pt_consumer_state;

    for (int i = 0; i < IterationsQKLatent; i++) {
      load_stage([&](int index) {
        copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
      });
    }

    for (int i = 0; i < IterationsQKRope; i++) {
      load_stage([&](int index) {
        copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
      });
    }

    page_table_stage = pipeline_pt_release_state.index();

    for (int i = 0; i < IterationsPV_K; i++) {
      for (int j = 0; j < IterationsPV_N; j++) {
        load_stage([&](int index) {
          copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index));
        });
      }
    }

    pipeline_page_table.consumer_release(pipeline_pt_release_state);
    ++pipeline_pt_release_state;

    k_index += 1;
    k_tile_count -= 1;
  }

  page_table_stage = pipeline_pt_release_state.index();

  for (int i = 0; i < IterationsPV_K; i++) {
    for (int j = 0; j < IterationsPV_N; j++) {
      load_stage([&](int index) {
        copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index));
      });
    }
  }

  pipeline_page_table.consumer_release(pipeline_pt_release_state);
  ++pipeline_pt_release_state;

  while (pipeline_offset > 0) {
    cutlass::arch::cp_async_fence();
    cutlass::arch::cp_async_wait<StagesPV - 1>();
    pipeline_load.producer_commit(pipeline_commit_state);
    ++pipeline_commit_state;
    --pipeline_offset;
  }

  cutlass::arch::cp_async_wait<0>();
}
template<bool kIsPaged = false, class BlkCoord>
CUTLASS_DEVICE void load_tma(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    MainloopParams const& mainloop_params,
    TensorStorage& shared_tensors,
    PipelineLoadQK& pipeline_load_qk,
    typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state,
    PipelineLoadPV& pipeline_load_pv,
    typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state,
    int const& split_kv) {

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(blk_coord) * k_tile_per_cta;  // lower limit
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    return;
  }

  using X = Underscore;

  // partition all tensors
  auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B));
  auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B));

  int paged_B = B;
  int paged_K = K;
  if constexpr (kIsPaged) {
    paged_B = mainloop_args.page_count;
    paged_K = mainloop_args.page_size;
  }

  auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table);

  auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B));
  auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B));
  auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B));

  auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
  auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
  auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
  auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
  auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});

  ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
  ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));

  auto tSgQL = cta_mma_qk.partition_A(gQL);
  auto tSgQR = cta_mma_qk.partition_A(gQR);
  auto tSgCL = cta_mma_qk.partition_B(gCL);
  auto tSgKR = cta_mma_qk.partition_B(gKR);
  auto tOgCLT = cta_mma_pv.partition_B(gCLT);

  Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
  Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
  Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});

  auto [tQLgQL_mkl, tQsQ] = tma_partition(
      mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}),
      group_modes<0, 3>(sQ), group_modes<0, 3>(tSgQL));
  auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition(
      mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}),
      group_modes<0, 3>(sQ), group_modes<0, 3>(tSgQR));
  auto [tCLgCL_nkl, tKCsKC] = tma_partition(
      mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}),
      group_modes<0, 3>(sKC), group_modes<0, 3>(tSgCL));
  auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition(
      mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}),
      group_modes<0, 3>(sKC), group_modes<0, 3>(tSgKR));
  auto [tCLTgCLT_nkl, tVCsVC] = tma_partition(
      mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}),
      group_modes<0, 3>(sVC), group_modes<0, 3>(tOgCLT));

  uint16_t mcast_mask = 0;

  int batch_coord = get<2>(blk_coord);
  Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord);
  Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord);

  auto mPT = mPT_l(_, batch_coord);

  Tensor tCLgCL = tCLgCL_nkl(_, _, _, _);
  Tensor tKRgKR = tKRgKR_nkl(_, _, _, _);
  // careful: stage and k are swapped here!
  Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _);

  // latent is first in memory, so let's load it first always
  // startup: alternate Q and K, set tx count appropriately, for k_idx = 0
  // each Q/K tile consists of rope and latent
  for (int i = 0; i < IterationsQKLatent; i++) {
    pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ);
    pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
    auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

    if (cute::elect_one_sync()) {
      // expect the extra bytes
      // load_qk ql
      cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i));
      // load_qk cl
      if constexpr (kIsPaged) {
        cute::copy(
            mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
            tCLgCL(_, _0{}, i, mPT(k_index)),
            tKCsKC(_, pipeline_load_qk_producer_state.index()));
      }
      else {
        cute::copy(
            mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
            tCLgCL(_, k_index, i, batch_coord),
            tKCsKC(_, pipeline_load_qk_producer_state.index()));
      }
    }
    ++pipeline_load_qk_producer_state;
  }

  for (int i = 0; i < IterationsQKRope; i++) {
    pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ);
    pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
    auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

    if (cute::elect_one_sync()) {
      // expect the extra bytes
      // load_qk ql
      cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent));
      // load_qk cl
      if constexpr (kIsPaged) {
        cute::copy(
            mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
            tKRgKR(_, _0{}, i, mPT(k_index)),
            tKCsKC(_, pipeline_load_qk_producer_state.index()));
      }
      else {
        cute::copy(
            mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
            tKRgKR(_, k_index, i, batch_coord),
            tKCsKC(_, pipeline_load_qk_producer_state.index()));
      }
    }
    ++pipeline_load_qk_producer_state;
  }

  k_index += 1;
  k_tile_count -= 1;

  // assume k_tile_count >= 1
  // perform K+Q load here
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    // perform K load
    for (int i = 0; i < IterationsQKLatent; i++) {
      pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
      auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

      if (cute::elect_one_sync()) {
        // load_qk cl
        if constexpr (kIsPaged) {
          cute::copy(
              mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
              tCLgCL(_, _0{}, i, mPT(k_index)),
              tKCsKC(_, pipeline_load_qk_producer_state.index()));
        }
        else {
          cute::copy(
              mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
              tCLgCL(_, k_index, i, batch_coord),
              tKCsKC(_, pipeline_load_qk_producer_state.index()));
        }
      }
      ++pipeline_load_qk_producer_state;
    }

    for (int i = 0; i < IterationsQKRope; i++) {
      pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
      auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

      if (cute::elect_one_sync()) {
        // load_qk cl
        if constexpr (kIsPaged) {
          cute::copy(
              mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
              tKRgKR(_, _0{}, i, mPT(k_index)),
              tKCsKC(_, pipeline_load_qk_producer_state.index()));
        }
        else {
          cute::copy(
              mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
              tKRgKR(_, k_index, i, batch_coord),
              tKCsKC(_, pipeline_load_qk_producer_state.index()));
        }
      }
      ++pipeline_load_qk_producer_state;
    }

    // prefetch next K load to keep busy while we transpose-load from cache
    const int kPrefetchDistance = 1;
    for (int i = 0; i < IterationsQKLatent; i++) {
      if (cute::elect_one_sync()) {
        if constexpr (kIsPaged) {
          if (k_tile_count > kPrefetchDistance) {
            cute::prefetch(mainloop_params.tma_load_c_latent, tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)));
          }
        }
        else {
          cute::prefetch(mainloop_params.tma_load_c_latent, tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord));
        }
      }
    }

    for (int i = 0; i < IterationsQKRope; i++) {
      if (cute::elect_one_sync()) {
        if constexpr (kIsPaged) {
          if (k_tile_count > kPrefetchDistance) {
            cute::prefetch(mainloop_params.tma_load_k_rope, tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)));
          }
        }
        else {
          cute::prefetch(mainloop_params.tma_load_k_rope, tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord));
        }
      }
    }

    // perform V load (k_idx - 1)
    for (int i = 0; i < IterationsPV_K; i++) {
      for (int j = 0; j < IterationsPV_N; j++) {
        pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state);
        auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state);

        if (cute::elect_one_sync()) {
          // load_pv cl
          // note the transpose in indices!
          // note we are off-by-one on k_index
          if constexpr (kIsPaged) {
            cute::copy(
                mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
                tCLTgCLT(_, j, i, mPT(k_index - 1)),
                tVCsVC(_, pipeline_load_pv_producer_state.index()));
          }
          else {
            cute::copy(
                mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
                tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord),
                tVCsVC(_, pipeline_load_pv_producer_state.index()));
          }
        }
        ++pipeline_load_pv_producer_state;
      }
    }

    k_index += 1;
    k_tile_count -= 1;
  }

  for (int i = 0; i < IterationsPV_K; i++) {
    for (int j = 0; j < IterationsPV_N; j++) {
      pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state);
      auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state);

      if (cute::elect_one_sync()) {
        // load_pv cl
        // note the transpose in indices
        // note we are off-by-one on k_index
        if constexpr (kIsPaged) {
          cute::copy(
              mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
              tCLTgCLT(_, j, i, mPT(k_index - 1)),
              tVCsVC(_, pipeline_load_pv_producer_state.index()));
        }
        else {
          cute::copy(
              mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
              tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord),
              tVCsVC(_, pipeline_load_pv_producer_state.index()));
        }
      }
      ++pipeline_load_pv_producer_state;
    }
  }
}
template<class BlkCoord>
CUTLASS_DEVICE void mma(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    TensorStorage& shared_tensors,
    PipelineLoadQK& pipeline_load_qk,
    typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state,
    PipelineLoadPV& pipeline_load_pv,
    typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state,
    PipelineS& pipeline_mma_s,
    typename PipelineS::PipelineState& pipeline_mma_s_producer_state,
    PipelineP& pipeline_p_mma,
    typename PipelineP::PipelineState& pipeline_p_mma_consumer_state,
    PipelineO& pipeline_mma_o,
    typename PipelineO::PipelineState& pipeline_mma_o_producer_state,
    int const& split_kv) {

  auto [H, K, D, B] = problem_shape;

  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(blk_coord) * k_tile_per_cta;  // lower limit
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    return;
  }

  // mma init
  Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
  Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
  Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});
  Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{});

  Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ);
  Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC);
  Tensor tOrP = TiledMmaPV::make_fragment_A(sP);
  Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC);

  TiledMmaQK tiled_mma_qk;
  TiledMmaPV tiled_mma_pv;

  Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShapeQK{}));
  Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShapePV{}));

  tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero;

  pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state);

  // Mma S0 S1 O0 S2 O1 ... Sn On-1 On
  // S0 ownership -- ----- -- --
  // S1 ownership    -- ----- ----
  // O ownership     -- -- ---- --

  tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
  for (int i = 0; i < IterationsQK; i++) {
    pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state);
    int read_stage = pipeline_load_qk_consumer_state.index();

    tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1);

    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) {
      cute::gemm(tiled_mma_qk, tSrQ(_, _, k_block, i), tSrKC(_, _, k_block, read_stage), tStS);
      tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
    }

    pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state);
    ++pipeline_load_qk_consumer_state;
  }

  pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state);
  ++pipeline_mma_s_producer_state;

  k_tile_count -= 1;

  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state);

    tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
    for (int i = 0; i < IterationsQK; i++) {
      pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state);
      int read_stage = pipeline_load_qk_consumer_state.index();

      tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1);

      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) {
        cute::gemm(tiled_mma_qk, tSrQ(_, _, k_block, i), tSrKC(_, _, k_block, read_stage), tStS);
        tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
      }

      pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state);
      ++pipeline_load_qk_consumer_state;
    }

    pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state);
    ++pipeline_mma_s_producer_state;

    pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state);
    pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state);

    for (int i = 0; i < IterationsPV_K; i++) {
      auto acc_flag = tiled_mma_pv.accumulate_;
      for (int j = 0; j < IterationsPV_N; j++) {
        pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state);
        int read_stage = pipeline_load_pv_consumer_state.index();

        tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO);
        tiled_mma_pv.accumulate_ = acc_flag;

        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) {
          cute::gemm(tiled_mma_pv,
              tOrP(_, _, k_block, make_coord(i, pipeline_p_mma_consumer_state.index())),
              tOrVC(_, _, k_block, read_stage),
              tOtO);
          tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One;
        }

        pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state);
        ++pipeline_load_pv_consumer_state;
      }
    }

    pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state);
    ++pipeline_p_mma_consumer_state;
    pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state);
    ++pipeline_mma_o_producer_state;

    --k_tile_count;
  }

  pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state);
  pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state);

  for (int i = 0; i < IterationsPV_K; i++) {
    auto acc_flag = tiled_mma_pv.accumulate_;
    for (int j = 0; j < IterationsPV_N; j++) {
      pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state);
      int read_stage = pipeline_load_pv_consumer_state.index();

      tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO);
      tiled_mma_pv.accumulate_ = acc_flag;

      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) {
        cute::gemm(tiled_mma_pv,
            tOrP(_, _, k_block, make_coord(i, pipeline_p_mma_consumer_state.index())),
            tOrVC(_, _, k_block, read_stage),
            tOtO);
        tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One;
      }

      pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state);
      ++pipeline_load_pv_consumer_state;
    }
  }

  pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state);
  ++pipeline_p_mma_consumer_state;
  pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state);
  ++pipeline_mma_o_producer_state;
}
template<class IsLastTile>
CUTLASS_DEVICE void softmax(
    IsLastTile const& is_last_tile,
    ElementAcc& row_max,
    ElementAcc& row_sum,
    ElementAcc& correction_factor,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    TensorStorage& shared_tensors,
    int k_index,
    uint32_t tmem_s,
    int smem_p_index) {

  auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};

  TiledMmaQK tiled_mma_qk;

  Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShapeQK{}));
  tStS.data() = tmem_s;

  CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{});
  CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{});
  Tensor tAcc = tStS(make_coord(_, _), _0{}, _0{});

  Tensor cS = make_identity_tensor(take<0, 2>(CtaShapeQK{}));

  auto tiled_t2r = make_tmem_copy(load_op, tAcc);
  auto thread_idx = threadIdx.x % size(tiled_t2r);

  auto thread_t2r = tiled_t2r.get_slice(thread_idx);
  Tensor tTR_cS = thread_t2r.partition_D(cS);
  Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_cS));

  Tensor tTR_rS_frag = make_tensor<Element>(shape(tTR_rAcc));
  const int AlignmentS = 4;

  Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
  Tensor tTR_rAcc_vec = recast<Array<ElementAcc, AlignmentS>>(tTR_rAcc);
  Tensor tTR_rS_vec = recast<Array<Element, AlignmentS>>(tTR_rS_frag);

  // load s
  copy(tiled_t2r, tTR_tAcc, tTR_rAcc);

  if (is_last_tile) {
    for (int i = 0; i < size(tTR_rAcc); i++) {
      if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) {
        tTR_rAcc(i) = -std::numeric_limits<ElementAcc>::infinity();
      }
    }
  }

  // max
  ElementAcc row_max_new = row_max;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(tTR_rAcc); i += 1) {
    row_max_new = ::fmax(row_max_new, tTR_rAcc(i));
  }

  // for 2x2 dp, reduce here
  if constexpr (kWarpsInN > 1) {
    shared_tensors.smem_exchange[threadIdx.x] = row_max_new;
    cutlass::arch::NamedBarrier(kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange).sync();
    // (64, 2) shape
    int peer_index = (threadIdx.x + 64) % 128;
    row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]);
  }

#ifndef B2B
  // find correction factor
  ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast<ElementAcc>(M_LOG2E);
  correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new));
  row_max = row_max_new;

  // softmax
  ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(tTR_rAcc); i++) {
    tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2);
  }
#endif

  // quantize
  cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(tTR_rAcc_vec); i++) {
    tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i));
  }

  Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index));
  Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS);

  // have a mapping for each thread to coord
  // find identical mapping to coords for the MMA
  auto l = make_ordered_layout(
      make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})),
      make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{})));

  auto sP_ = as_position_independent_swizzle_tensor(sP);
  copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _));

  // sum
  row_sum *= correction_factor;

  static_assert(cute::is_same_v<ElementAcc, float>);
  auto tTR_rAcc_float2 = recast<float2>(tTR_rAcc);
  auto sums = make_tensor<float2>(_4{});
  static_assert(size(tTR_rAcc_float2) % size(sums) == 0);
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(sums); i++) {
    sums(i) = tTR_rAcc_float2(i);
  }
  CUTLASS_PRAGMA_UNROLL
  for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < size(sums); j++) {
      cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j));
    }
  }
  CUTLASS_PRAGMA_UNROLL
  for (int i = 1; i < size(sums); i *= 2) {
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < size(sums); j += 2 * i) {
      cute::add(sums(j), sums(j), sums(j + i));
    }
  }
  row_sum += sums(0).x + sums(0).y;
}
CUTLASS_DEVICE void rescale(ElementAcc correction_factor, uint32_t tmem_o) {
  // for b2b gemm, do nothing
#ifndef B2B
  auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
  auto store_op = TMEM::tmem_load_to_store(load_op);

  TiledMmaPV tiled_mma_pv;

  Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShapePV{}));
  tOtO.data() = tmem_o;

  CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{});
  CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{});
  Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{});

  auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{});
  Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0));

  auto tiled_t2r = make_tmem_copy(load_op, tAcc);
  auto tiled_r2t = make_tmem_copy(store_op, tAcc);
  auto thread_idx = threadIdx.x % size(tiled_t2r);

  auto thread_t2r = tiled_t2r.get_slice(thread_idx);
  auto thread_r2t = tiled_r2t.get_slice(thread_idx);
  Tensor tTR_gO = thread_t2r.partition_D(gO);
  Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));

  Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);

  // load o
  copy(tiled_t2r, tTR_tAcc, tTR_rAcc);

  // multiply by correction factor
  float2 correction_factor_vec = make_float2(correction_factor, correction_factor);
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < size(tTR_rAcc); i += 2) {
    float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1));
    float2 out;
    cute::mul(out, in, correction_factor_vec);
    tTR_rAcc(i + 0) = out.x;
    tTR_rAcc(i + 1) = out.y;
  }

  // store o
  copy(tiled_r2t, tTR_rAcc, tTR_tAcc);
#endif
}
template<class BlkCoord>
CUTLASS_DEVICE void epilogue(
    ElementAcc& row_max,
    ElementAcc& row_sum,
    BlkCoord const& cta_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    EpilogueParams const& epilogue_args,
    TensorStorage& shared_tensors,
    uint32_t tmem_o,
    int const& split_kv) {

  auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};

  TiledMmaPV tiled_mma_pv;

  Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{})));
  tOtO.data() = tmem_o;

  CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{});
  CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{});
  Tensor tAcc = tOtO(make_coord(_, _), _0{}, _0{});

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  if (epilogue_args.ptr_o_acc != nullptr) {
    using ElementOutAcc = ElementAcc;
    constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v<ElementOutAcc>;

    Tensor mO = make_tensor(
        make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent),
        make_shape(H, D_latent, B),
        epilogue_args.stride_o_acc);
    auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{});
    Tensor gO = local_tile(mO, cta_tiler_pv, take<0, 3>(cta_coord));

    auto tiled_t2r = make_tmem_copy(load_op, tAcc);
    auto thread_idx = threadIdx.x % size(tiled_t2r);

    auto thread_t2r = tiled_t2r.get_slice(thread_idx);
    Tensor tTR_gO = thread_t2r.partition_D(gO);
    Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));

    Tensor tTR_rO_frag = make_tensor<ElementOutAcc>(shape(tTR_rAcc));
    Tensor tTR_rO_src = recast<Array<ElementOutAcc, AlignmentOutAcc>>(coalesce(tTR_rO_frag));
    Tensor tR2G_rO_dst = recast<Array<ElementOutAcc, AlignmentOutAcc>>(coalesce(tTR_gO));

    Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);

    copy(tiled_t2r, tTR_tAcc, tTR_rAcc);

    cutlass::epilogue::thread::LinearCombination<
        ElementOutAcc, 1, ElementAcc, ElementAcc,
        cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>
      epilogue_op({epilogue_args.output_scale / row_sum});
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(tTR_rAcc); i++) {
      tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i));
    }

    copy(tTR_rO_src, tR2G_rO_dst);

#ifndef B2B
    // compute LSE
    ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;

    // store LSE
    Tensor mLSE = make_tensor(
        make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)),
        make_shape(H, B),
        epilogue_args.stride_lse_acc);
    Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0, 3>(cta_coord), Step<_1, Underscore, _1>{});

    // for 2x2 dp, this must be conditional and the index is wrong
    if (!kIs2Sm || (threadIdx.x < 64)) {
      gLSE(threadIdx.x) = lse;
    }
#endif
  }
  else {
    Tensor mO = make_tensor(
        make_gmem_ptr(epilogue_args.ptr_o),
        make_shape(H, D_latent, B),
        epilogue_args.stride_o);
    auto cta_tiler_pv = take<0, 2>(typename CollectiveMmaPV::CtaShape_MNK{});
    Tensor gO = local_tile(mO, cta_tiler_pv, take<0, 3>(cta_coord));

    auto tiled_t2r = make_tmem_copy(load_op, tAcc);
    auto thread_idx = threadIdx.x % size(tiled_t2r);

    auto thread_t2r = tiled_t2r.get_slice(thread_idx);
    Tensor tTR_gO = thread_t2r.partition_D(gO);
    Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));

    Tensor tTR_rO_frag = make_tensor<ElementOut>(shape(tTR_rAcc));
    Tensor tTR_rO_src = recast<Array<ElementOut, AlignmentOut>>(coalesce(tTR_rO_frag));
    Tensor tR2G_rO_dst = recast<Array<ElementOut, AlignmentOut>>(coalesce(tTR_gO));

    Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);

    copy(tiled_t2r, tTR_tAcc, tTR_rAcc);

    cutlass::epilogue::thread::LinearCombination<
        ElementOut, 1, ElementAcc, ElementAcc,
        cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>
      epilogue_op({epilogue_args.output_scale / row_sum});
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(tTR_rAcc); i++) {
      tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i));
    }

    copy(tTR_rO_src, tR2G_rO_dst);

#ifndef B2B
    if (epilogue_args.ptr_lse != nullptr) {
      // compute LSE
      ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;

      // store LSE
      Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
      Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0, 3>(cta_coord), Step<_1, Underscore, _1>{});

      // for 2x2 dp, this must be conditional and the index is wrong
      if (!kIs2Sm || (threadIdx.x < 64)) {
        gLSE(threadIdx.x) = lse;
      }
    }
#endif
  }
}
template<class CtaCoord>
CUTLASS_DEVICE void compute(
    CtaCoord const& cta_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    EpilogueParams const& epilogue_args,
    TensorStorage& shared_tensors,
    PipelineS& pipeline_mma_s,
    typename PipelineS::PipelineState& pipeline_mma_s_consumer_state,
    PipelineP& pipeline_p_mma,
    typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
    PipelineO& pipeline_mma_o,
    typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
    int const& split_kv) {

  auto [H, K, D, B] = problem_shape;

  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(cta_coord) * k_tile_per_cta;  // lower limit
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    // if we return early, we have to make sure we release the load warp
    cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
    return;
  }

  int k_index_final = k_tile_total - 1;

  ElementAcc row_max = -std::numeric_limits<ElementAcc>::infinity();
  ElementAcc row_sum = 0;
  ElementAcc correction_factor = 1;

  pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state);
  pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state);

  auto dispatch_bool = [](bool b, auto fn) {
    if (b) {
      fn(cute::true_type{});
    }
    else {
      fn(cute::false_type{});
    }
  };

  // softmax s0 -> p0
  dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) {
    softmax(is_last_tile, row_max, row_sum, correction_factor,
        problem_shape, mainloop_args, shared_tensors, k_index,
        uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1),
        pipeline_p_mma_producer_state.index());
  });
  k_index += 1;

  cutlass::arch::fence_view_async_tmem_load();
  cutlass::arch::fence_view_async_shared();
  pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state);
  ++pipeline_mma_s_consumer_state;
  pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state);
  ++pipeline_p_mma_producer_state;

  k_tile_count -= 1;

  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state);
    pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state);

    // softmax s1 -> p1
    dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) {
      softmax(is_last_tile, row_max, row_sum, correction_factor,
          problem_shape, mainloop_args, shared_tensors, k_index,
          uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1),
          pipeline_p_mma_producer_state.index());
    });

    cutlass::arch::fence_view_async_tmem_load();
    cutlass::arch::fence_view_async_shared();
    pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state);
    ++pipeline_mma_s_consumer_state;
    pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state);
    ++pipeline_p_mma_producer_state;

    pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);

    // rescale
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < IterationsPV_N; j++) {
      rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO));
    }

    cutlass::arch::fence_view_async_tmem_store();
    pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
    ++pipeline_mma_o_consumer_state;

    --k_tile_count;
    k_index += 1;
  }

  pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);

#ifdef B2B
  row_sum = 1;
#else
  if constexpr (kWarpsInN > 1) {
    // reduce row_sum if needed (for 2x2 dp)
    shared_tensors.smem_exchange[threadIdx.x] = row_sum;
    cutlass::arch::NamedBarrier(kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange).sync();
    // (64, 2) shape
    int peer_index = (threadIdx.x + 64) % 128;
    row_sum += shared_tensors.smem_exchange[peer_index];
  }
#endif

  cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();

  // epilogue
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < IterationsPV_N; j++) {
    epilogue(row_max, row_sum, replace<1>(cta_coord, j),
        problem_shape, mainloop_args, epilogue_args, shared_tensors,
        uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO),
        split_kv);
  }

  cutlass::arch::fence_view_async_tmem_load();
  pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
  ++pipeline_mma_o_consumer_state;
}
};

///////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::fmha::kernel
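The load, mma, and compute paths above all derive their per-CTA K range from the same ceil_div arithmetic on k_tile_total, k_tile_per_cta, k_index, and k_tile_count. The Python sketch below is a host-side model of that index math for intuition only; seq_len, tile_shape_s, split_kv, and split_index are illustrative stand-ins for K, TileShapeS, split_kv, and get<3>(blk_coord), not names from this file.

# Minimal host-side model of the per-CTA split-KV range used by the kernel above.
def ceil_div(a, b):
    return (a + b - 1) // b

def k_tile_range(seq_len, tile_shape_s, split_kv, split_index):
    # Mirrors k_tile_total / k_tile_per_cta / k_index / k_tile_count in the kernel.
    k_tile_total = ceil_div(seq_len, tile_shape_s)
    k_tile_per_cta = ceil_div(k_tile_total, split_kv)
    k_index = split_index * k_tile_per_cta  # lower limit
    k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index)
    return k_index, k_tile_count

# Example: seq_len=1000, tile_shape_s=128, split_kv=4 gives 8 tiles total, 2 per split;
# a split whose index is at or past local_split_kv does no work and is skipped.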
sgl-kernel/csrc/attention/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp
0 → 100644
View file @
18efb5e8
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {

////////////////////////////////////////////////////////////////////////////////

struct Sm100MlaIndividualTileScheduler {

  struct Params {
    dim3 grid;
  };

  bool valid_ = true;

  CUTLASS_DEVICE
  Sm100MlaIndividualTileScheduler(Params const&) {}

  template<class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, int const& split_kv) {
    using namespace cute;
    dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /* Maximum Split KV */);
    return Params{ grid };
  }

  static dim3 get_grid_shape(Params const& params) {
    return params.grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return valid_;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
  }

  CUTLASS_DEVICE
  Sm100MlaIndividualTileScheduler& operator++() {
    valid_ = false;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

struct Sm100MlaPersistentTileScheduler {

  struct Params {
    int num_blocks;
    FastDivmod divmod_m_block;
    FastDivmod divmod_b;
    FastDivmod divmod_split_kv;
    KernelHardwareInfo hw_info;
  };

  int block_idx = 0;
  Params params;

  CUTLASS_DEVICE
  Sm100MlaPersistentTileScheduler(Params const& params)
      : block_idx(blockIdx.x), params(params) {}

  template<class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, int const& split_kv) {
    using namespace cute;
    // Get SM count if needed, otherwise use user supplied SM count
    int sm_count = hw_info.sm_count;
    if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
      CUTLASS_TRACE_HOST(
          "  WARNING: Arguments do not include a valid SM count.\n"
          "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    }

    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
    hw_info.sm_count = sm_count;

    int num_m_blocks = size<0>(cluster_shape);
    int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
    num_blocks *= split_kv; /* Maximum Split KV */

    return Params{
        num_blocks,
        { num_m_blocks }, { get<3>(problem_shape) }, { split_kv },
        hw_info
    };
  }

  static dim3 get_grid_shape(Params const& params) {
    dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
    return grid;
  }

  CUTLASS_DEVICE
  bool is_valid() {
    return block_idx < params.num_blocks;
  }

  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    int block_decode = block_idx;
    int m_block, bidb, n_split_kv;
    params.divmod_m_block(block_decode, m_block, block_decode);
    params.divmod_b(block_decode, bidb, block_decode);
    params.divmod_split_kv(block_decode, n_split_kv, block_decode);
    return make_coord(m_block, _0{}, bidb, n_split_kv);
  }

  CUTLASS_DEVICE
  Sm100MlaPersistentTileScheduler& operator++() {
    block_idx += gridDim.x;
    return *this;
  }
};

////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::fmha::kernel
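For intuition, here is a rough host-side Python model of how the persistent scheduler linearizes work and decodes a flat block index back into (m_block, batch, split). It assumes CUTLASS FastDivmod yields a quotient and remainder like Python's divmod, which matches the call order above but is stated here as an assumption, not code from this file.

# Rough model of Sm100MlaPersistentTileScheduler indexing (illustrative only).
def persistent_grid(num_m_blocks, num_batches, split_kv, sm_count):
    num_blocks = num_m_blocks * num_batches * split_kv
    return min(num_blocks, sm_count)  # grid.x; each CTA then strides by gridDim.x

def decode_block_coord(block_idx, num_m_blocks, num_batches, split_kv):
    rest, m_block = divmod(block_idx, num_m_blocks)   # divmod_m_block
    rest, bidb = divmod(rest, num_batches)            # divmod_b
    rest, n_split_kv = divmod(rest, split_kv)         # divmod_split_kv
    return m_block, bidb, n_split_kv

# Example: with 1 m-block, 8 batches, split_kv=2 and 148 SMs, only 16 blocks are launched
# and each persistent CTA processes one (m_block, batch, split) tuple per scheduler step.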
sgl-kernel/csrc/common_extension.cc
View file @
18efb5e8
...
@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);

   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
-      "page_table, Tensor workspace) -> ()");
+      "page_table, Tensor! workspace, int num_kv_splits) -> ()");
   m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);

   m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
...
sgl-kernel/include/sgl_kernel_ops.h
View file @ 18efb5e8
...
@@ -109,8 +109,10 @@ void cutlass_mla_decode(
     torch::Tensor const& kv_c_and_k_pe_cache,
     torch::Tensor const& seq_lens,
     torch::Tensor const& page_table,
-    torch::Tensor const& workspace);
-int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0);
+    torch::Tensor const& workspace,
+    int64_t num_kv_splits = -1);
+int64_t cutlass_mla_get_workspace_size(
+    int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
 /*
  * From csrc/elementwise
  */
...
sgl-kernel/python/sgl_kernel/attention.py
View file @ 18efb5e8
...
@@ -57,6 +57,7 @@ def cutlass_mla_decode(
     seq_lens: torch.Tensor,
     page_table: torch.Tensor,
     workspace: torch.Tensor,
+    num_kv_splits: int = -1,
 ) -> torch.Tensor:
     assert (
         q_nope_and_q_pe.ndim == 3
...
@@ -73,7 +74,12 @@ def cutlass_mla_decode(
         f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
         f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
     )
-    assert H == 128, f"H must be 128, but got {H}"
+    MAX_HEADS = 128
+    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
+    if H < MAX_HEADS:
+        q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
+        q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
+        q_nope_and_q_pe = q_nope_and_q_pe_padded

     assert len(page_table.shape) == 2
     B_block_table, block_num = page_table.shape
...
@@ -97,21 +103,25 @@ def cutlass_mla_decode(
         page_table.dtype == torch.int32
     ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

-    out = torch.empty(
-        (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
-    )
+    out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))

     torch.ops.sgl_kernel.cutlass_mla_decode.default(
-        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
+        out,
+        q_nope_and_q_pe,
+        kv_c_and_k_pe_cache,
+        seq_lens,
+        page_table,
+        workspace,
+        num_kv_splits,
     )
-    return out
+    return out[:, :H].contiguous()


 def cutlass_mla_get_workspace_size(
-    max_seq_len: int, num_batches: int, sm_count: int = 0
+    max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
 ) -> int:
     assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
     assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"

     return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
-        max_seq_len, num_batches, sm_count
+        max_seq_len, num_batches, sm_count, num_kv_splits
     )
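
To make the new Python surface concrete, here is a minimal usage sketch of the updated wrapper, assuming a CUDA device with the kernel built, bfloat16 inputs, and the head/latent dimensions used elsewhere in this change (d = 576, dv = 512). The tensor sizes below are illustrative, not prescriptive.

# Minimal sketch: decode with fewer than 128 query heads and an explicit
# num_kv_splits. Shapes are illustrative; requires a CUDA build of sgl-kernel.
import torch
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs, h_q, d, dv = 4, 64, 576, 512        # h_q < 128 is accepted after this change
block_size, block_num = 64, 16          # block_num already a multiple of 128 // block_size
num_kv_splits = -1                      # -1 keeps the default split selection

seq_lens = torch.full((bs,), 512, dtype=torch.int32, device="cuda")
page_table = torch.randint(
    0, bs * block_num, (bs, block_num), dtype=torch.int32, device="cuda"
)
q = torch.randn(bs, h_q, d, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.randn(
    page_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
)

workspace_size = cutlass_mla_get_workspace_size(
    block_num * block_size, bs, num_kv_splits=num_kv_splits
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

out = cutlass_mla_decode(q, kv_cache, seq_lens, page_table, workspace, num_kv_splits)
# out has shape (bs, h_q, dv): queries are padded to 128 heads internally,
# and the wrapper slices the result back to the caller's head count.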
sgl-kernel/tests/test_cutlass_mla.py
View file @ 18efb5e8
...
@@ -40,15 +40,23 @@ def ref_mla(
 @pytest.mark.parametrize("bs", [1, 2, 4])
 @pytest.mark.parametrize("varlen", [False, True])
 @pytest.mark.parametrize("block_size", [1, 16, 64, 128])
+@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
+@pytest.mark.parametrize("num_kv_splits", [-1, 1])
 def test_cutlass_mla_decode(
-    dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int
+    dtype: torch.dtype,
+    mean_seq_len: int,
+    bs: int,
+    varlen: bool,
+    block_size: int,
+    num_heads: int,
+    num_kv_splits: int,
 ):
     torch.set_default_dtype(dtype)
     torch.set_default_device("cuda")
     torch.manual_seed(42)

     d = 576
-    h_q = 128
+    h_q = num_heads
     dv = 512
     q_nope_dim = 128
...
@@ -67,17 +75,22 @@ def test_cutlass_mla_decode(
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

     # Larger q values to detect split kv error
     q = torch.randn(bs, h_q, d) * 100.0
     block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)

     kv_cache = torch.randn(block_table.numel(), block_size, d)

-    workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, bs)
+    workspace_size = cutlass_mla_get_workspace_size(
+        block_num * block_size, bs, num_kv_splits=num_kv_splits
+    )
     workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

     out_ref = q.new_zeros(bs, h_q, dv)
     ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
-    out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace)
+    out = cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace, num_kv_splits)

     torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
...