Unverified commit 0cb2e06d authored by Po Yen Chen, committed by GitHub

[CK_TILE] fmha forward split-kv + combine kernels (#1338)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: rocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instances required by FA

* format

* fix error in WarpGemm

* Add num_splits option and dummy split-kv api method

* Generate fmha_fwd_splitkv()

* Add SplitKV kernel codegen logics

* Add SplitKV combine kernel codegen logics

* Fix mismatched return type

* Clean-up code

* Replace sentinel value before storing

* Fix wrong layout of LSE/LSEacc/Oacc

* Format codes

* Fix o_acc memory error

* Fix wrong kBlockSize used in policy

* Reduce # of combine kernels

* Fix split-kv combine kernel name

* Fix wrong LDS indexing logics

* Fix wrong loop counter step logic

* Undo vector size changes

* Remove no-longer used field

* Remove inconsistent comment

* Remove debug statements in example

* Remove more debug statements

* Add constness to local variables

* Clean up generate.py

* Fix unstable clang-format comment

* Remove unused include directive

* Use shorter template parameter name

* Enable non-split-kv blobs

* Update license date

* Print num_splits conditionally

* Undo disabling data types

* Remove unnecessary tile size for fp8

* Fix wrong pipeline args for fp8

* Fix example output format

* Remove more debug code in combine pipeline

* Add stride kernel arguments for LSE/O acc workspace

* Re-order split-kv pipeline call operator arguments

* Pass LSE/O strides in kernel argument

* Re-order pipeline call operator arguments

* Use tensor_descriptor to locate LSEacc elements

* Support providing invalid element for tensor view

* Set invalid element value for LSEacc tensor view

* Remove hand-written store_tile() code

* Remove necessary value-overwrite logic

* Add transposed lds descriptor

* Support load_tile() for tile_window_with_static_lengths<>

* Undo removing necessary value-overwrite logic

* Use read descriptor to locate lds elements

* Simplify pipeline source code

* Add constraint to kMaxSplits

* Default use kMaxSplits=64 in generate.py

* Revert "Add constraint to kMaxSplits"

This reverts commit 0a2132d758042e6fb0292f4e354909b8a4d1c118.

* Revert "Default use kMaxSplits=64 in generate.py"

This reverts commit c7d9c80b77320aec6559222bed7d47adcaefe4e3.

* Decide alignment by the padding parameter

* Remove no-longer used utility functions

* Remove non-working code

* Add comment & remove no-longer used code

* Fix computation errors

* Add heuristic to override num_splits option

* Add constraint to kMaxSplits

* Fix compilation error

* Clean up pipeline code

* Wrap pointer access as lambda function

* Rename confusing methods

* Use kLogMaxSplits as template parameter

* Finish splitkv combine kernel codegen

* Update kMaxSplits limit

* Use smaller kM0 for splitkv combine kernel

* Ignore dropout flag in splitkv pipeline

* Unify flag usage

* Add back flag kStoreLSE

* Merge lambda calls in pipeline

* Fix compilation errors

* Avoid all empty splits

* Always check for empty loop in splitkv pipelines

* Re-order parameters

* Remove redundant p_drop option check

* Add traits/problem for fwd splitkv kernel

* Conditionally enable uneven split boundary checks

* Add comment for the splitkv traits field

* Change even split criteria

* Re-order statements

* Refine occupancy value for hdim=128&256

* Refine occupancy value for hdim=32&64

* Remove redundant kernel argument

* Separate fmha bwd codegen logics

* Separate fmha fwd codegen logics

* Remove redundant direction parameter in fwd&bwd codegen logics

* Support generating multiple APIs for an example

* Let 'api' be an alias of the 'direction' option

* Remove choices for the 'direction' option

* Use a dictionary to configure all the functions

* Move fmha splitkv codegen logics to another file

* Add fwd_splitkv api for tile_example_fmha_fwd

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent 3e9711f0
# generate a list of kernels, but do not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
)
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in CMake, we have to generate files in the same folder
\ No newline at end of file
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
DTYPE_MAP = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
}
def get_mask_map(mask : str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None
_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask : str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
return _MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
\ No newline at end of file
@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[])
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("num_splits",
"1",
"# of splits for key/value. 0 to determine actual number by heuristic")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
}
}
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
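// Worked example (hypothetical numbers, not from any measurement): with
// batch_nhead_mblocks = 16, num_SMs = 120, num_n_blocks = 32 and max_splits = 128,
// max_splits is first clamped to min(128, 120, 32) = 32. Efficiency per split count
// is n_waves / ceil(n_waves) with n_waves = 16 * num_splits / 120; it peaks at
// num_splits = 7 (112 / 120 ~ 0.93), and the first eligible split count within 85%
// of that peak is num_splits = 6 (eff = 0.8), so the heuristic returns 6.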
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return num_splits;
}
hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return num_splits;
}
// tile sizes should match generate.py
const int kM0 = 64;
const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
}
return num_splits;
}
float fmha_fwd_dispatch(fmha_fwd_traits traits,
fmha_fwd_args args,
const ck_tile::stream_config& config)
{
if(1 < args.num_splits)
{
return fmha_fwd_splitkv(traits, args, config);
}
else
{
return fmha_fwd(traits, args, config);
}
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset();
}
int num_splits = arg_parser.get_int("num_splits");
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// legalize num_splits according to other options
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
}
if(128 < num_splits)
{
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
return false;
}
auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
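// Layout sketch (illustrative, matching the strides computed further below):
// o_acc element (i_split, i_batch, i_nhead, i_m, i_n) sits at the flat offset
//   i_split * (batch * nhead * max_seqlen_q * hdim_v)  // split_stride_o_acc
// + i_batch * (nhead * max_seqlen_q * hdim_v)          // batch_stride_o_acc
// + i_nhead * (max_seqlen_q * hdim_v)                  // nhead_stride_o_acc
// + i_m * hdim_v + i_n;                                // stride_o_acc = hdim_v
// lse_acc uses the same scheme with the hdim_v factor dropped everywhere.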
// self define lse data layout as [batch, nhead, max_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
<< ", mask:" << mask << ", v:" << vlayout;
if(1 < num_splits)
{
std::cout << ", num_splits:" << num_splits;
}
std::cout << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(),
lse_acc_buf.GetDeviceBuffer(),
o_acc_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
hdim_v,
nhead,
nhead_k,
num_splits,
scale_s,
scale_p,
scale_o,
@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_randval,
stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o,
split_stride_lse_acc,
split_stride_o_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{drop_seed, drop_offset}};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
if(ave_time < 0)
{
@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
bool lse_pass = ck_tile::check_err(lse_host_result,
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
"LSE Error: Incorrect results!",
rtol,
atol,
/* allow_infinity_ref = */ true);
pass &= lse_pass;
pass &= cur_pass;
if(!cur_pass)
{
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
......
@@ -93,6 +93,8 @@ struct fmha_fwd_args
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
@@ -106,6 +108,7 @@
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s;
float scale_p;
float scale_o;
@@ -114,6 +117,7 @@
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
@@ -121,6 +125,8 @@
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -128,7 +134,11 @@
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
@@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
dim3 grids =
Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.batch_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
@@ -282,6 +462,40 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN1_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
@@ -298,3 +512,4 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
@@ -10,6 +10,10 @@
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
@@ -22,6 +26,12 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
......
@@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask
}
}
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
number<TileHeight> height,
number<TileWidth> width,
index_t num_splits,
index_t i_split) const
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
const index_t split_start = x_per_split * i_split;
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));
}
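// Worked example (hypothetical numbers): with x_total = 100 and num_splits = 3,
// x_per_split = 100 / 3 = 33, so split 0 covers [0, 33), split 1 covers [33, 66)
// and the last split covers [66, 100), absorbing the remainder; the returned
// range is that split window intersected with the mask's own
// [origin_start, origin_end).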
// to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using LSEDataType = remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
}
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
// arg
struct EmptyKargs
{
};
// kargs use aggregate initializers, so no constructor is provided
// use inheritance to minimize karg size
// users need to use the MakeKargs() functions to create kargs.
struct CommonKargs
{
const void* lse_acc_ptr;
const void* o_acc_ptr;
void* o_ptr;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
ck_tile::index_t row_stride_o_acc;
ck_tile::index_t row_stride_o;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
struct CommonLSEKargs
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
};
struct Fp8StaticQuantKargs
{
float scale_o;
};
struct BatchModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
};
struct GroupModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* lse_acc_ptr,
const void* o_acc_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
float scale_o,
ck_tile::index_t row_stride_o_acc,
ck_tile::index_t row_stride_o,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
Kargs kargs{{lse_acc_ptr,
o_acc_ptr,
o_ptr,
batch,
max_seqlen_q,
seqlen_q,
hdim_v,
num_splits,
row_stride_o_acc,
row_stride_o,
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o};
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
kargs.scale_o = scale_o;
}
return kargs;
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* lse_acc_ptr,
const void* o_acc_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
float scale_o,
ck_tile::index_t row_stride_o_acc,
ck_tile::index_t row_stride_o,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
Kargs kargs{{lse_acc_ptr,
o_acc_ptr,
o_ptr,
batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer
hdim_v,
num_splits,
row_stride_o_acc,
row_stride_o,
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
kargs.scale_o = scale_o;
}
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// # of required blocks is different in each group, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
{
return;
}
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, we handle the batch stride by just adjusting the base pointers
const LSEDataType* lse_acc_ptr =
reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
const OaccDataType* o_acc_ptr =
reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// LSEacc/Oacc DRAM and DRAM windows
const auto lse_acc_dram = [&]() {
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.num_splits, kargs.seqlen_q),
make_tuple(kargs.split_stride_lse_acc, 1),
number<FmhaPipeline::kAlignmentLSEacc>{},
number<1>{});
return pad_tensor_view(
lse_acc_dram_naive,
make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
sequence<true, kPadSeqLenQ>{});
}();
auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
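// merge the split and padded-seqlen dimensions into a single row axis, so the
// tile for split i_split starts at row i_split * padded_max_seqlen_q -- the same
// step the combine pipeline uses when it advances its O-acc window per split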
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
auto lse_acc_dram_window = make_tile_window(
lse_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
}(),
{0, i_m0});
auto o_acc_dram_window = make_tile_window(
o_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{});
}(),
{i_m0, i_n1});
// LSE DRAM window
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
if constexpr(kStoreLSE)
{
LSEDataType* lse_ptr =
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<FmhaPipeline::kAlignmentLSE>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
return make_null_tile_window(lse_dram_window_lengths);
}
}();
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.max_seqlen_q,
smem_ptr);
}
else
{
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.max_seqlen_q,
smem_ptr);
}
}();
// O DRAM and DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.row_stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kM0_, index_t kN1_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
batch_size);
}
CK_TILE_DEVICE auto
operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
{
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
const auto [i_nhead, i_split] = f(blockIdx.y, num_splits);
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
};
} // namespace ck_tile
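As a concrete illustration of the grid mapping above (hypothetical sizes, assuming kM0 = 128 and kN1 = 128): batch = 2, nhead = 8, seqlen_q = 1024, hdim_v = 128 and num_splits = 4 give grids = dim3(8 * 1, 8 * 4, 2); a block at blockIdx = (3, 13, 1) then decomposes to i_tile_m = 3, i_tile_n = 0, i_nhead = 13 / 4 = 3, i_split = 13 % 4 = 1 and i_batch = 1.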
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
namespace detail {
template <index_t N>
struct log2;
template <>
struct log2<16> : std::integral_constant<index_t, 4>
{
};
template <>
struct log2<32> : std::integral_constant<index_t, 5>
{
};
template <>
struct log2<64> : std::integral_constant<index_t, 6>
{
};
template <>
struct log2<128> : std::integral_constant<index_t, 7>
{
};
} // namespace detail
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
struct BlockFmhaFwdSplitKVCombinePipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kHeadDimV = Problem::kHeadDimV;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr index_t kMaxSplits = Problem::kMaxSplits;
static constexpr index_t kAlignmentLSE =
kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE<Problem>();
static constexpr index_t kAlignmentLSEacc = kAlignmentLSE;
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kHeadDimV <= 32)
{
constexpr std::array<int, 4> occupancy{3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
}
else if constexpr(kHeadDimV <= 128)
{
constexpr std::array<int, 4> occupancy{3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
}
else if constexpr(kHeadDimV <= 256)
{
constexpr std::array<int, 4> occupancy{2, 2, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 4];
}
}
}();
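// note: detail::log2<kMaxSplits>::value - 4 maps kMaxSplits in {16, 32, 64, 128}
// to occupancy-table index {0, 1, 2, 3}; e.g. kMaxSplits = 64 with kHeadDimV <= 128
// selects occupancy[2] = 2 blocks per CU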
static constexpr const char* name = "unused";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename LSEaccDramBlockWindowTmp,
typename OaccDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename LSEElementFunction,
typename OaccElementFunction>
CK_TILE_HOST_DEVICE auto
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp,
const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp,
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t max_seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
LSEDataType* lse_acc_lds_ptr =
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto lse_acc_lds = [=, lds_desc = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>()](
index_t row, index_t col) -> LSEDataType& {
return lse_acc_lds_ptr[lds_desc.calculate_offset(make_tuple(row, col))];
};
auto lse_acc_lds_write_window = [&]() {
auto view = make_tensor_view<address_space_enum::lds>(
lse_acc_lds_ptr, Policy::template MakeLSEaccLdsStoreBlockDescriptor<Problem>());
return make_tile_window(view, make_tuple(number<kMaxSplits>{}, number<kM0>{}), {0, 0});
}();
auto lse_acc_dram_window =
make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(),
lse_acc_dram_block_window_tmp.get_window_lengths(),
lse_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakeLSEaccDramTileDistribution<Problem>());
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in the wave has data
// to reduce.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_accum(i_j_idx) = lse_acc_lds(row, col);
}
else
{
lse_accum(i_j_idx) = -numeric<LSEDataType>::infinity();
}
});
});
}
// compute the logsumexp of the LSE along the split dimension.
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
auto lse_max = block_tile_reduce<LSEDataType>(
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
static const auto get_validated_m = [](LSEDataType raw_m) {
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
decltype(lse_accum) lse_exp;
{
constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) =
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
});
});
}
auto lse_sum = block_tile_reduce<LSEDataType>(
lse_exp, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
decltype(lse_max) lse_logsum;
{
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
{
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
else
{
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
});
}
// store the lse scales in shared memory.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) =
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
}
});
});
}
block_sync_lds();
if constexpr(kStoreLSE)
{
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
auto o_tile = load_tile(o_acc_dram_window);
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const LSEDataType lse_scale = lse_acc_lds(row, i_split);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
});
});
}
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename LSEaccDramBlockWindow,
typename OaccDramBlockWindow,
typename LSEDramBlockWindow>
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
index_t max_seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
o_acc_dram_block_window,
lse_dram_block_window,
identity{},
identity{},
num_splits,
max_seqlen_q,
smem_ptr);
}
};
} // namespace ck_tile
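For reference, the split-combine math implemented tile-wise above, reduced to scalars; a self-contained sketch (not part of the source, names are illustrative):
#include <cmath>
#include <cstddef>
#include <vector>
// combine per-split partial results (lse_i, o_i) the way the pipeline does:
// m = max_i lse_i, lse = log(sum_i exp(lse_i - m)) + m, o = sum_i exp(lse_i - lse) * o_i
float combine_splits_scalar(const std::vector<float>& lse_acc,
                            const std::vector<float>& o_acc,
                            float& lse_out)
{
    float m = -INFINITY;
    for(float l : lse_acc)
        m = std::fmax(m, l);
    const float valid_m = (m == -INFINITY ? 0.f : m); // mirrors get_validated_m()
    float sum = 0.f;
    for(float l : lse_acc)
        sum += std::exp(l - valid_m);
    // +inf sentinel when every split was empty (or the sum became NaN)
    const float lse_logsum = (sum == 0.f || sum != sum) ? INFINITY : std::log(sum) + valid_m;
    float o = 0.f;
    for(std::size_t i = 0; i < lse_acc.size(); ++i)
        o += std::exp(lse_acc[i] - lse_logsum) * o_acc[i]; // per-split lse_scale
    lse_out = (lse_logsum == INFINITY) ? -INFINITY : lse_logsum; // as in the kStoreLSE branch
    return o;
}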
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
return 16 / sizeof(LSEDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return 16 / sizeof(OaccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
return 16 / sizeof(ODataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return sizeof(typename Problem::LSEDataType) *
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t NPerThread = 16 / sizeof(LSEDataType);
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t TotalWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
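// worked example (hypothetical problem sizes, assuming a 64-lane wavefront):
// kBlockSize = 256, kM0 = 32, kMaxSplits = 64 and LSEDataType = float give
// NPerThread = 4, NThreads = 8, MThreadsPerWarp = 8, TotalWarps = 4 and
// MPerThread = 2, satisfying both static_asserts (8 * 4 == 32, 2 * 4 * 8 == 64)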
// 3d + padding, [kMaxSplits, kM0]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
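// note: the outer-dimension stride below is (kMPerBlock + 1) * NPack rather than
// kMPerBlock * NPack; presumably the extra NPack of padding staggers accesses
// across LDS banks to reduce bank conflicts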
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
lse_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lse_acc_lds_block_desc;
}
// 3d + padding, [kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
lse_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return lse_acc_t_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NThreads = get_warp_size();
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MThreads * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2>>,
tuple<sequence<0>, sequence<0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t N1 = 16 / sizeof(OaccDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = get_warp_size() / N0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
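// worked example (hypothetical problem sizes, assuming a 64-lane wavefront):
// kBlockSize = 256, kM0 = 64, kN1 = 32 and OaccDataType = float give
// N1 = 4, N0 = 8, M2 = 64 / 8 = 8, M1 = 256 / 64 = 4 and M0 = 64 / (8 * 4) = 2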
};
} // namespace ck_tile