Unverified Commit 0cb2e06d authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] fmha forward split-kv + combine kernels (#1338)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instance for FA required

* format

* fix error in WarpGemm

* Add num_splits option and dummy split-kv api method

* Generate fmha_fwd_splitkv()

* Add SplitKV kernel codegen logics

* Add SplitKV combine kernel codegen logics

* Fix mismatched return type

* Clean-up code

* Replace sentinel value before storing

* Fix wrong layout of LSE/LSEacc/Oacc

* Format codes

* Fix o_acc memory error

* Fix wrong kBlockSize used in policy

* Reduce # of combine kernels

* Fix split-kv combine kernel name

* Fix wrong LDS indexing logics

* Fix wrong loop counter step logic

* Undo vector size changes

* Remove no-longer used field

* Remove in-consistent comment

* Remove debug statements in example

* Remove more debug statements

* Add constness to local variables

* Clean up generate.py

* Fix unstable clang-format comment

* Remove unused include directive

* Use shorter template parameter name

* Enable non-split-kv blobs

* Update license date

* Print num_splits conditionally

* Undo disabling data types

* Remove unnecessary tile size for fp8

* Fix wrong pipeline args for fp8

* Fix example output format

* Remove more debug code in combine pipeline

* Add stride kernel arguments for LSE/O acc workspace

* Re-order split-kv pipeline call operator arguments

* Pass LSE/O strides in kernel argument

* Re-order pipeline call operator arguments

* Use tensor_descriptor to locate LSEacc elements

* Support providing invalid element for tensor view

* Set invalid element value for LSEacc tensor view

* Remove hand-written store_tile() code

* Remove necessary value-overwrite logic

* Add transposed lds descriptor

* Support load_tile() for tile_window_with_static_lengths<>

* Undo removing necessary value-overwrite logic

* Use read descriptor to locate lds elements

* Simplify pipeline source code

* Add constraint to kMaxSplits

* Default use kMaxSplits=64 in generate.py

* Revert "Add constraint to kMaxSplits"

This reverts commit 0a2132d758042e6fb0292f4e354909b8a4d1c118.

* Revert "Default use kMaxSplits=64 in generate.py"

This reverts commit c7d9c80b77320aec6559222bed7d47adcaefe4e3.

* Decide alignment by the padding parameter

* Remove no-longer used utility functions

* Remove not-working code

* Add comment & remove no-longer used code

* Fix computation errors

* Add heuristic to override num_splits option

* Add constraint to kMaxSplits

* Fix compilation error

* Clean up pipeline code

* Wrap pointer access as lambda function

* Rename confusing methods

* Use kLogMaxSplits as template parameter

* Finish splitkv combine kernel codegen

* Update kMaxSplits limit

* Use smaller kM0 for splitkv combine kernel

* Ignore dropout flag in splitkv pipeline

* Unify flag usage

* Add back flag kStoreLSE

* Merge lambda calls in pipeline

* Fix compilation errors

* Avoid all empty splits

* Always check for empty loop in splitkv pipelines

* Re-order parameters

* Remove redundant p_drop option check

* Add traits/problem for fwd splitkv kernel

* Conditionally enable uneven split boundary checks

* Add comment for the splitkv traits field

* Change even split criteria

* Re-order statements

* Refine occupancy value for hdim=128&256

* Refine occupancy value for hdim=32&64

* Remove redundant kernel argument

* Separate fmha bwd codegen logics

* Separate fmha fwd codegen logics

* Remove redundant direction parameter in fwd&bwd codegen logics

* Support generate multiple APIs for an example

* Let 'api' an alias of 'direction' option

* Remove choices for the 'direction' option

* Use dictionary to config all the functions

* Move fmha splitkv codegen logics to other file

* Add fwd_splitkv api for tile_example_fmha_fwd

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 3e9711f0
# generate a list of kernels, but not actually emit files at config stage # generate a list of kernels, but not actually emit files at config stage
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
) )
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
) )
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
...@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) ...@@ -17,13 +17,13 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command( add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS} OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
) )
add_custom_command( add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS} OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
) )
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder
\ No newline at end of file
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
# Short data-type name -> fully qualified ck_tile C++ element type name.
DTYPE_MAP = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
}
# Mask implementation name -> ck_tile C++ mask class template name.
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
# Mask kind -> C++ mask type, used when the "simplified" implementation is selected
# (see get_mask_map).
_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
# Mask kind -> C++ mask type, used when the "generic" implementation is selected
# (see get_mask_map).
_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
}
def get_mask_map(mask : str):
    """Return the mask-kind -> C++ mask type table for a mask implementation.

    `mask` is either "generic" or "simplified"; any other value trips the
    assertion below.
    """
    if mask == "simplified":
        return _MASK_SIMPLIFIED_MAP
    if mask == "generic":
        return _MASK_MAP
    # unknown implementation name: treated as a programming error
    assert False
    return None
# Mask kind -> C++ boolean expression text testing `t.mask_type` against mask_enum
# values, for the "generic" implementation (see get_mask_check_map).
# NOTE(review): presumably pasted into the generated dispatch code - confirm in the generator.
_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
# Same as above for the "simplified" implementation: only distinguishes
# "no mask" from "any mask".
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask : str):
    """Return the mask-kind -> C++ check-expression table for a mask implementation.

    `mask` is either "generic" or "simplified"; any other value trips the
    assertion below.
    """
    if mask == "simplified":
        return _MASK_SIMPLIFIED_CHECK_MAP
    if mask == "generic":
        return _MASK_CHECK_MAP
    # unknown implementation name: treated as a programming error
    assert False
    return None
# Bias kind -> ck_tile C++ BlockAttentionBiasEnum enumerator name.
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
# Bias kind -> `bias_enum` enumerator name (same keys as BIAS_MAP, different C++ enum).
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
# Mode name -> C++ bool literal text.
# NOTE(review): presumably the value of a "is group mode" template flag - confirm in the generator.
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
# Layout name -> C++ bool literal text ("row" is the true case).
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
# Pipeline tag -> ck_tile C++ pipeline class template name.
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
# Pipeline tag -> ck_tile C++ BlockFmhaPipelineEnum enumerator name (same keys as PIPELINE_MAP).
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
# One-letter flag -> C++ bool literal text.
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[]) ...@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[])
.insert("drop_seed", "1", "seed for random number generator") .insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator") .insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("num_splits",
"1",
"# of splits for key/value. 0 to determine actual number by heuristic")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel");
...@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method) ...@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
} }
} }
// Choose how many key/value splits to launch so the grid fills the device well.
//
// batch_nhead_mblocks: number of work-groups launched without splitting
// num_SMs:             number of work-group slots the device offers
// num_n_blocks:        number of blocks available to distribute over the splits
// max_splits:          upper bound on the returned value
//
// Returns the smallest eligible split count whose wave efficiency is within 85% of
// the best achievable one, or 1 when splitting is not worthwhile or not possible.
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
{
    // If we have enough to almost fill the SMs, then just use 1 split
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }

    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    // With at most one candidate there is nothing to search. This also keeps a
    // non-positive bound (e.g. a zero/negative block count) from reaching
    // vector::reserve(), where a negative int would wrap to a huge size and throw.
    if(max_splits <= 1)
    {
        return 1;
    }

    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);

    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };

    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 ||
               ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };

    // Pass 1: record the efficiency of every candidate and track the best one.
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(!is_split_eligible(num_splits))
        {
            efficiency.push_back(0.f);
            continue;
        }

        // Fraction of the last wave that is actually occupied.
        const float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
        const float eff     = n_waves / std::ceil(n_waves);
        max_efficiency      = std::max(max_efficiency, eff);
        efficiency.push_back(eff);
    }

    // Pass 2: take the smallest eligible split count that is close enough to the best.
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(!is_split_eligible(num_splits))
        {
            continue;
        }
        if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
        {
            return num_splits;
        }
    }
    return 1;
}
// Replace a non-positive num_splits by a heuristic value when dropout is disabled.
//
// Queries the current HIP device for its multiprocessor count; if either HIP call
// fails, or num_splits is already >= 1, or p_drop is non-zero, num_splits is
// returned unchanged.
int override_num_splits_if_necessary(
    int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
    int device;
    if(hipGetDevice(&device) != hipSuccess)
    {
        return num_splits;
    }

    hipDeviceProp_t props{};
    if(hipGetDeviceProperties(&props, device) != hipSuccess)
    {
        return num_splits;
    }

    // tile size should match the generate.py
    const int kM0 = 64;
    const int kN1 = hdim_v;

    const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
    // NOTE(review): kN1 == hdim_v, so this looks like it is always 1 for positive hdim_v
    // (assuming integer_divide_ceil is a ceiling division), which would cap the heuristic
    // at a single split - confirm whether the key sequence length was intended here.
    const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);

    const bool wants_heuristic = (num_splits < 1) && (p_drop == 0.0f);
    if(!wants_heuristic)
    {
        return num_splits;
    }

    const int total_m_blocks = batch * nhead * num_m_blocks;
    return num_splits_heuristic(
        total_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
}
// Route a forward request to the split-kv API when more than one split is requested,
// otherwise to the regular forward API. Returns whatever the chosen API returns.
float fmha_fwd_dispatch(fmha_fwd_traits traits,
                        fmha_fwd_args args,
                        const ck_tile::stream_config& config)
{
    const bool use_splitkv = args.num_splits > 1;
    return use_splitkv ? fmha_fwd_splitkv(traits, args, config)
                       : fmha_fwd(traits, args, config);
}
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
...@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset(); seed.reset();
} }
int num_splits = arg_parser.get_int("num_splits");
int stream_warmup = arg_parser.get_int("warmup"); int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat"); int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname"); bool kname = arg_parser.get_bool("kname");
...@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
} }
// legalize num_splits according to other options
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
}
if(128 < num_splits)
{
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
return false;
}
auto get_lengths = [&](bool permute, auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/, ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/, ck_tile::index_t h /*nhead*/,
...@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{batch, nhead}) : std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1}); : std::array<ck_tile::index_t, 2>{1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [batch, nhead, max_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q} lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
...@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
...@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout << std::flush; << ", mask:" << mask << ", v:" << vlayout;
if(1 < num_splits)
{
std::cout << ", num_splits:" << num_splits;
}
std::cout << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q, auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
...@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q; const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
...@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(), return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
...@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(), : bias_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(), randval_buf.GetDeviceBuffer(),
lse_acc_buf.GetDeviceBuffer(),
o_acc_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
...@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
hdim_v, hdim_v,
nhead, nhead,
nhead_k, nhead_k,
num_splits,
scale_s, scale_s,
scale_p, scale_p,
scale_o, scale_o,
...@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias, : stride_bias,
stride_randval, stride_randval,
stride_o_acc,
stride_o, stride_o,
nhead_stride_q, nhead_stride_q,
nhead_stride_k, nhead_stride_k,
...@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_bias, nhead_stride_bias,
nhead_stride_randval, nhead_stride_randval,
nhead_stride_lse, nhead_stride_lse,
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
...@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_bias, batch_stride_bias,
batch_stride_randval, batch_stride_randval,
batch_stride_lse, batch_stride_lse,
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o, batch_stride_o,
split_stride_lse_acc,
split_stride_o_acc,
mask.left, mask.left,
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type), static_cast<ck_tile::index_t>(mask.type),
...@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{drop_seed, drop_offset}}; {drop_seed, drop_offset}};
}(); }();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
if(ave_time < 0) if(ave_time < 0)
{ {
...@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_host_result.ForEach( lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); }); [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
bool lse_pass = ck_tile::check_err(lse_host_result, cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref, lse_host_ref,
"LSE Error: Incorrect results!", "LSE Error: Incorrect results!",
rtol, rtol,
atol, atol,
/* allow_infinity_ref = */ true); /* allow_infinity_ref = */ true);
pass &= lse_pass; pass &= cur_pass;
if(!cur_pass) if(!cur_pass)
{ {
std::cerr << "LSE mismatch found at batch: " << wb << std::endl std::cerr << "LSE mismatch found at batch: " << wb << std::endl
......
...@@ -93,6 +93,8 @@ struct fmha_fwd_args ...@@ -93,6 +93,8 @@ struct fmha_fwd_args
const void* v_ptr; const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr; void* rand_val_ptr;
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr; void* lse_ptr;
void* o_ptr; void* o_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
...@@ -106,6 +108,7 @@ struct fmha_fwd_args ...@@ -106,6 +108,7 @@ struct fmha_fwd_args
ck_tile::index_t hdim_v; ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q; ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k; ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s; float scale_s;
float scale_p; float scale_p;
float scale_o; float scale_o;
...@@ -114,6 +117,7 @@ struct fmha_fwd_args ...@@ -114,6 +117,7 @@ struct fmha_fwd_args
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval; ck_tile::index_t stride_randval;
ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_k;
...@@ -121,6 +125,8 @@ struct fmha_fwd_args ...@@ -121,6 +125,8 @@ struct fmha_fwd_args
ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
...@@ -128,7 +134,11 @@ struct fmha_fwd_args ...@@ -128,7 +134,11 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
...@@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -234,6 +244,176 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
// Build the kernel arguments and launch grid for the split-kv forward kernel `Kernel`.
//
// The arguments are forwarded positionally to Kernel::MakeKargs, so their order is part
// of the contract with the kernel. Group mode passes the seqstart_q/seqstart_k/seqlen_k
// pointers; batch mode passes seqlen_q/seqlen_k plus the q/k/v/bias/randval batch strides.
// Both modes pass the lse_acc/o_acc pointers and their nhead/batch/split strides.
// Returns ck_tile::make_tuple(kargs, grids).
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
{
// the per-kv-head group size passed below is nhead_q / nhead_k, so it must divide evenly
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
// unlike the combine kernel's grid, this one also depends on the number of splits
dim3 grids =
Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids);
}
// Build the kernel arguments and launch grid for the split-kv combine kernel `Kernel`.
//
// The arguments are forwarded positionally to Kernel::MakeKargs: the lse_acc/o_acc
// pointers first, then the lse/o pointers, then sizes, scale_o and strides.
// NOTE(review): presumably the kernel reduces the per-split accumulators into the final
// LSE/O buffers - confirm against the kernel implementation.
// Group mode passes seqstart_q_ptr and no batch_stride_o; batch mode passes seqlen_q
// and batch_stride_o. Returns ck_tile::make_tuple(kargs, grids).
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
{
// NOTE(review): nhead_k is not otherwise used in this function; the check mirrors the
// other create_kargs_and_grids helpers.
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.batch_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
}();
// the grid does not depend on num_splits here
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
...@@ -282,6 +462,40 @@ struct fmha_fwd_traits_ ...@@ -282,6 +462,40 @@ struct fmha_fwd_traits_
template <typename Traits_> template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
// Declarations only; no definitions are visible in this header.
// NOTE(review): presumably specialized per Traits_ by the generated instance files - confirm.
// Takes the stream config and the forward arguments for a split-kv kernel instance.
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
// Returns a name string for the split-kv kernel instance selected by Traits_.
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
// Compile-time descriptor of one split-kv combine kernel instance. Every template
// parameter is re-exported unchanged as a member of the same name (without the
// trailing underscore) so that instances can be selected by pattern-matching.
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN1_,
          bool kStoreLse_,
          bool kDoFp8StaticQuant_,
          bool kPadS_,
          bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
    // element type with cv/reference qualifiers removed
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    // integral parameters
    static constexpr ck_tile::index_t HDim = HDim_;
    static constexpr ck_tile::index_t kM0  = kM0_;
    static constexpr ck_tile::index_t kN1  = kN1_;

    // boolean switches
    static constexpr bool kIsGroupMode      = kIsGroupMode_;
    static constexpr bool kStoreLse         = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr bool kPadS             = kPadS_;
    static constexpr bool kPadDv            = kPadDv_;
};
// Declarations only; no definitions are visible in this header.
// NOTE(review): presumably specialized per Traits_ by the generated instance files - confirm.
// Takes the stream config and the forward arguments for a split-kv combine kernel instance.
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
// Returns a name string for the split-kv combine kernel instance selected by Traits_.
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_fwd_traits struct fmha_fwd_traits
{ {
...@@ -298,3 +512,4 @@ struct fmha_fwd_traits ...@@ -298,3 +512,4 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
// Split-KV entry point; takes the same parameter list and returns the same type as fmha_fwd().
// Declaration only. NOTE(review): the float result is presumably an elapsed time — confirm.
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
This diff is collapsed.
...@@ -10,6 +10,10 @@ ...@@ -10,6 +10,10 @@
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
...@@ -22,6 +26,12 @@ ...@@ -22,6 +26,12 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
......
...@@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask ...@@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask
} }
} }
// Split-aware overload: intersects the X range produced by the 3-argument
// GetTileRangeAlongX() with the slice of [0, x_total) assigned to split `i_split`.
// Every split owns max(1, x_total / num_splits) columns; the last split instead extends
// up to x_total so that the remainder is not lost. Returns a tuple (start, end).
// NOTE(review): when num_splits > x_total the slice start of a trailing split lies beyond
// x_total, so the returned end can be smaller than start; callers presumably treat that as
// an empty range — confirm.
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
                                                      number<TileHeight> height,
                                                      number<TileWidth> width,
                                                      index_t num_splits,
                                                      index_t i_split) const
{
    // range allowed by the mask alone
    auto [mask_begin, mask_end] = GetTileRangeAlongX(i_y, height, width);

    const index_t cols_per_split = ck_tile::max(1, x_total / num_splits);
    const index_t slice_begin    = cols_per_split * i_split;

    // the final split absorbs the remainder; all others take exactly cols_per_split columns
    index_t slice_end = x_total;
    if(i_split != num_splits - 1)
    {
        slice_end = slice_begin + cols_per_split;
    }

    return ck_tile::make_tuple(ck_tile::max(mask_begin, slice_begin),
                               ck_tile::min(mask_end, slice_end));
}
// to get the loop length along Y axis, return index:[start, end), end-start=length // to get the loop length along Y axis, return index:[start, end), end-start=length
// use this if need loop over Y axis tile by tile (like q-seqlen loopover) // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
// TODO: y_end still could be negative, so end-start could be negative(need check) // TODO: y_end still could be negative, so end-start could be negative(need check)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// Kernel functor for the combine step of split-KV FMHA forward.
// operator() reads the per-split accumulators (lse_acc / o_acc), runs FmhaPipeline_ over them
// and hands the resulting tile to EpiloguePipeline_ together with a window on the final O.
// TilePartitioner_ supplies the grid size and the (tile_m, tile_n, head, batch) of a block.
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
// occupancy resolved by the pipeline; must be positive
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
// occupancy as specified on the problem; GetName() compares it against -1 to decide
// whether the occupancy tag is part of the kernel name
static constexpr index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using LSEDataType = remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>;
// compile-time switches forwarded from the pipeline / problem
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
// Type-to-string map used by GetName() to spell the output data type in the kernel name.
// Only the specializations below exist; any other type fails to compile.
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// Builds the kernel name from the compile-time configuration.
// Layout: fmha_fwd_splitkv_combine_d<hdim_v>_<dtype>_<group|batch>_b<kM0>x<kN1>_
//         [o<occupancy>_]<pipeline name>[_p<s><dv>][_lse][_squant]
// The occupancy tag is only present when the problem specified an occupancy (!= -1).
__host__ static std::string GetName()
{
    // sync with generate.py

    // padding tag: "p" followed by one marker per padded dimension, or empty
    std::string pad_tag;
    if(kPadSeqLenQ)
    {
        pad_tag += "s";
    }
    if(kPadHeadDimV)
    {
        pad_tag += "dv";
    }
    if(!pad_tag.empty())
    {
        pad_tag = std::string("p") + pad_tag;
    }

    std::string kernel_name("fmha_fwd_splitkv_combine_d");
    kernel_name += std::to_string(FmhaPipeline::kHeadDimV);
    kernel_name += "_";
    kernel_name += t2s<ODataType>::name;
    kernel_name += "_";
    kernel_name += (kIsGroupMode ? "group" : "batch");
    kernel_name += "_b";
    kernel_name += std::to_string(FmhaPipeline::kM0);
    kernel_name += "x";
    kernel_name += std::to_string(FmhaPipeline::kN1);
    kernel_name += "_";
    if(kBlockPerCuInput != -1)
    {
        kernel_name += "o";
        kernel_name += std::to_string(kBlockPerCu);
        kernel_name += "_";
    }
    kernel_name += FmhaPipeline::name;
    if(!pad_tag.empty())
    {
        kernel_name += "_";
        kernel_name += pad_tag;
    }
    if(kStoreLSE)
    {
        kernel_name += "_lse";
    }
    if(kDoFp8StaticQuant)
    {
        kernel_name += "_squant";
    }
    return kernel_name;
}
// Empty placeholder base used when an optional karg group is disabled.
// The index template argument makes each placeholder a distinct type, which avoids the
// duplicated-base-class problem when several placeholders are inherited at once.
template <ck_tile::index_t I>
struct EmptyKargs
{
};
// Kargs are aggregates (no constructors are provided) and use inheritance to keep the
// argument struct as small as possible. Users should create them through MakeKargs().
//
// Arguments shared by batch mode and group mode. Strides are applied to typed pointers in
// operator(), i.e. they are counted in elements, not bytes.
struct CommonKargs
{
// per-split LSE accumulator (input)
const void* lse_acc_ptr;
// per-split O accumulator (input)
const void* o_acc_ptr;
// final output O
void* o_ptr;
// stored by MakeKargs(); not read by operator() of this kernel
ck_tile::index_t batch;
// row count used for the o_acc view and for stepping between splits in the pipeline
ck_tile::index_t max_seqlen_q;
// actual query length; group mode stores -1 here and overwrites it per batch in operator()
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
ck_tile::index_t row_stride_o_acc;
ck_tile::index_t row_stride_o;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
// distance between consecutive splits in the accumulator buffers
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
// Extra arguments that are only part of Kargs when kStoreLSE is true: destination of the
// combined LSE and its head / batch strides.
struct CommonLSEKargs
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
};
// Extra argument that is only part of Kargs when kDoFp8StaticQuant is true: the scale that
// operator() applies to the combined O before saturating to fp8.
struct Fp8StaticQuantKargs
{
float scale_o;
};
// Batch-mode arguments: common kargs, the optional LSE / fp8 groups (or distinct empty
// placeholders when disabled) and the batch stride of O.
struct BatchModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
};
// Group-mode arguments: same bases as batch mode, but instead of a batch stride it carries
// seqstart_q_ptr, from which operator() derives the row offset and length of each batch.
struct GroupModeKargs
: CommonKargs,
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
};
// Argument type actually used by the kernel, selected by kIsGroupMode.
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
// Batch-mode factory for Kargs (only participates in overload resolution when
// kIsGroupMode is false).
// Copies every argument into the field of the same name. lse_ptr / nhead_stride_lse /
// batch_stride_lse are only stored when kStoreLSE is enabled, and scale_o only when
// kDoFp8StaticQuant is enabled; otherwise those parameters are ignored.
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* lse_acc_ptr,
          const void* o_acc_ptr,
          void* lse_ptr,
          void* o_ptr,
          ck_tile::index_t batch,
          ck_tile::index_t max_seqlen_q,
          ck_tile::index_t seqlen_q,
          ck_tile::index_t hdim_v,
          ck_tile::index_t num_splits,
          float scale_o,
          ck_tile::index_t row_stride_o_acc,
          ck_tile::index_t row_stride_o,
          ck_tile::index_t nhead_stride_lse_acc,
          ck_tile::index_t nhead_stride_o_acc,
          ck_tile::index_t nhead_stride_lse,
          ck_tile::index_t nhead_stride_o,
          ck_tile::index_t batch_stride_lse_acc,
          ck_tile::index_t batch_stride_o_acc,
          ck_tile::index_t batch_stride_lse,
          ck_tile::index_t batch_stride_o,
          ck_tile::index_t split_stride_lse_acc,
          ck_tile::index_t split_stride_o_acc)
{
    // start from a value-initialized aggregate, then fill the fields by name
    Kargs result{};

    // common kargs
    result.lse_acc_ptr          = lse_acc_ptr;
    result.o_acc_ptr            = o_acc_ptr;
    result.o_ptr                = o_ptr;
    result.batch                = batch;
    result.max_seqlen_q         = max_seqlen_q;
    result.seqlen_q             = seqlen_q;
    result.hdim_v               = hdim_v;
    result.num_splits           = num_splits;
    result.row_stride_o_acc     = row_stride_o_acc;
    result.row_stride_o         = row_stride_o;
    result.nhead_stride_lse_acc = nhead_stride_lse_acc;
    result.nhead_stride_o_acc   = nhead_stride_o_acc;
    result.nhead_stride_o       = nhead_stride_o;
    result.batch_stride_lse_acc = batch_stride_lse_acc;
    result.batch_stride_o_acc   = batch_stride_o_acc;
    result.split_stride_lse_acc = split_stride_lse_acc;
    result.split_stride_o_acc   = split_stride_o_acc;

    // batch-mode specific karg
    result.batch_stride_o = batch_stride_o;

    // optional groups
    if constexpr(kStoreLSE)
    {
        result.lse_ptr          = lse_ptr;
        result.nhead_stride_lse = nhead_stride_lse;
        result.batch_stride_lse = batch_stride_lse;
    }
    if constexpr(kDoFp8StaticQuant)
    {
        result.scale_o = scale_o;
    }

    return result;
}
// Group-mode factory for Kargs (only participates in overload resolution when
// kIsGroupMode is true).
// Differs from the batch-mode overload in that it takes seqstart_q_ptr instead of
// seqlen_q / batch_stride_o: seqlen_q is stored as -1 and is recomputed per batch inside
// operator() from seqstart_q_ptr. Optional LSE / fp8 arguments are stored only when the
// corresponding feature is enabled.
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* lse_acc_ptr,
          const void* o_acc_ptr,
          void* lse_ptr,
          void* o_ptr,
          ck_tile::index_t batch,
          ck_tile::index_t max_seqlen_q,
          const void* seqstart_q_ptr,
          ck_tile::index_t hdim_v,
          ck_tile::index_t num_splits,
          float scale_o,
          ck_tile::index_t row_stride_o_acc,
          ck_tile::index_t row_stride_o,
          ck_tile::index_t nhead_stride_lse_acc,
          ck_tile::index_t nhead_stride_o_acc,
          ck_tile::index_t nhead_stride_lse,
          ck_tile::index_t nhead_stride_o,
          ck_tile::index_t batch_stride_lse_acc,
          ck_tile::index_t batch_stride_o_acc,
          ck_tile::index_t batch_stride_lse,
          ck_tile::index_t split_stride_lse_acc,
          ck_tile::index_t split_stride_o_acc)
{
    // start from a value-initialized aggregate, then fill the fields by name
    Kargs result{};

    // common kargs
    result.lse_acc_ptr          = lse_acc_ptr;
    result.o_acc_ptr            = o_acc_ptr;
    result.o_ptr                = o_ptr;
    result.batch                = batch;
    result.max_seqlen_q         = max_seqlen_q;
    result.seqlen_q             = -1; // seqlen will be updated by another pointer
    result.hdim_v               = hdim_v;
    result.num_splits           = num_splits;
    result.row_stride_o_acc     = row_stride_o_acc;
    result.row_stride_o         = row_stride_o;
    result.nhead_stride_lse_acc = nhead_stride_lse_acc;
    result.nhead_stride_o_acc   = nhead_stride_o_acc;
    result.nhead_stride_o       = nhead_stride_o;
    result.batch_stride_lse_acc = batch_stride_lse_acc;
    result.batch_stride_o_acc   = batch_stride_o_acc;
    result.split_stride_lse_acc = split_stride_lse_acc;
    result.split_stride_o_acc   = split_stride_o_acc;

    // group-mode specific karg
    result.seqstart_q_ptr = reinterpret_cast<const int32_t*>(seqstart_q_ptr);

    // optional groups
    if constexpr(kStoreLSE)
    {
        result.lse_ptr          = lse_ptr;
        result.nhead_stride_lse = nhead_stride_lse;
        result.batch_stride_lse = batch_stride_lse;
    }
    if constexpr(kDoFp8StaticQuant)
    {
        result.scale_o = scale_o;
    }

    return result;
}
// Launch grid for the given problem size; forwarded unchanged to the tile partitioner.
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
// Launch block dimensions: a 1-D block of kBlockSize threads.
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
// LDS bytes needed by the kernel: the larger of the pipeline and epilogue requirements
// (operator() hands the same buffer to the pipeline).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
// Kernel body. For the (tile_m, tile_n, head, batch) assigned to this block it
//   1. offsets the lse_acc / o_acc / o (and lse) pointers by head and batch,
//   2. builds padded tensor views and tile windows over the accumulators,
//   3. runs FmhaPipeline to combine the splits into one O tile,
//   4. stores the tile through EpiloguePipeline.
// kargs is taken by value because group mode overwrites kargs.seqlen_q below.
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
// element offsets of this block's tile along seqlen_q (m0) and hdim_v (n1)
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
// accumulator offsets use the batch stride in both modes; computed in 64-bit
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
// get the real # of queries of this batch under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)
{
return;
}
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, head and batch strides are applied by modifying the base pointers
const LSEDataType* lse_acc_ptr =
reinterpret_cast<const LSEDataType*>(kargs.lse_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lse_acc + batch_offset_lse_acc;
const OaccDataType* o_acc_ptr =
reinterpret_cast<const OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + batch_offset_o_acc;
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
// LSEacc/Oacc DRAM and DRAM windows
// lse_acc is viewed as [num_splits, seqlen_q]; the split dimension is always padded to
// kMaxSplits, the seqlen dimension only when kPadSeqLenQ is set
const auto lse_acc_dram = [&]() {
const auto lse_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_acc_ptr,
make_tuple(kargs.num_splits, kargs.seqlen_q),
make_tuple(kargs.split_stride_lse_acc, 1),
number<FmhaPipeline::kAlignmentLSEacc>{},
number<1>{});
return pad_tensor_view(
lse_acc_dram_naive,
make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
sequence<true, kPadSeqLenQ>{});
}();
// o_acc is viewed as [num_splits, max_seqlen_q, hdim_v], padded in the last two
// dimensions as configured, and then flattened to 2-D by merging split and (padded)
// seqlen into one row dimension, so that moving a window down by padded_max_seqlen_q
// rows lands on the same rows of the next split
auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}();
// [kMaxSplits, kM0] window over lse_acc starting at split 0, row i_m0
auto lse_acc_dram_window = make_tile_window(
lse_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
}(),
{0, i_m0});
// [kM0, kN1] window over the flattened o_acc, positioned on the first split
auto o_acc_dram_window = make_tile_window(
o_acc_dram,
[&]() {
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{});
}(),
{i_m0, i_n1});
// LSE DRAM window
// (i_nhead is re-captured by value under a new name because it is a structured binding)
auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
if constexpr(kStoreLSE)
{
LSEDataType* lse_ptr =
reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
const auto lse_dram = [&]() {
const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
lse_ptr,
make_tuple(kargs.seqlen_q),
make_tuple(1),
number<FmhaPipeline::kAlignmentLSE>{},
number<1>{});
return pad_tensor_view(
lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
}();
return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
}
else
{
// LSE is not stored: hand the pipeline a null window of the same shape
return make_null_tile_window(lse_dram_window_lengths);
}
}();
// run the combine pipeline; with fp8 static quantization the combined O is scaled by
// scale_o and saturated to fp8 via the o_acc element function
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.max_seqlen_q,
smem_ptr);
}
else
{
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.max_seqlen_q,
smem_ptr);
}
}();
// O DRAM and DRAM window
auto o_dram = [&]() {
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.row_stride_o, 1),
number<FmhaPipeline::kAlignmentO>{},
number<1>{});
return pad_tensor_view(
o_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
}();
auto o_dram_window =
make_tile_window(o_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
// write the combined tile to O
EpiloguePipeline{}(o_dram_window, o_acc_tile);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Maps thread blocks of the combine kernel to output tiles.
// Grid layout: x = (#tiles along seqlen_q) * (#tiles along hdim_v), y = head, z = batch.
template <index_t kM0_, index_t kN1_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    // Grid dimensions for the given problem size.
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1),
                    nhead_,
                    batch_size_);
    }

    // Decodes the current block index into (i_tile_m, i_tile_n, i_nhead, i_batch).
    // blockIdx.x is split with the n-tile index varying fastest; seqlen_q is unused.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t tiles_along_n = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t linear_tile = blockIdx.x;
        const index_t i_tile_m    = linear_tile / tiles_along_n;
        const index_t i_tile_n    = linear_tile - i_tile_m * tiles_along_n;

        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};
} // namespace ck_tile
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Maps thread blocks of the split-KV forward kernel to (tile, split, head, batch).
// Grid layout: x = (#tiles along seqlen_q) * (#tiles along hdim_v),
//              y = nhead * num_splits, z = batch.
template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    // Grid dimensions for the given problem size and number of splits.
    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
                                            ck_tile::index_t nhead,
                                            ck_tile::index_t seqlen_q,
                                            ck_tile::index_t hdim_v,
                                            ck_tile::index_t num_splits)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, kN1),
                    nhead * num_splits,
                    batch_size);
    }

    // Decodes the current block index into (i_tile_m, i_tile_n, i_split, i_nhead, i_batch).
    // blockIdx.x carries the tile (n-tile fastest), blockIdx.y carries head and split
    // (split fastest); seqlen_q is unused.
    CK_TILE_DEVICE auto
    operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
    {
        const index_t tiles_along_n = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t linear_tile = blockIdx.x;
        const index_t i_tile_m    = linear_tile / tiles_along_n;
        const index_t i_tile_n    = linear_tile - i_tile_m * tiles_along_n;

        const index_t head_and_split = blockIdx.y;
        const index_t i_nhead        = head_and_split / num_splits;
        const index_t i_split        = head_and_split - i_nhead * num_splits;

        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
namespace detail {
// Compile-time base-2 logarithm, defined only for 16, 32, 64 and 128.
// The primary template is intentionally left undefined so that any other N is a compile
// error at the point of use.
template <index_t N>
struct log2;
template <>
struct log2<16> : std::integral_constant<index_t, 4>
{
};
template <>
struct log2<32> : std::integral_constant<index_t, 5>
{
};
template <>
struct log2<64> : std::integral_constant<index_t, 6>
{
};
template <>
struct log2<128> : std::integral_constant<index_t, 7>
{
};
} // namespace detail
// Block-level pipeline that combines the per-split results of split-KV FMHA forward:
// it reduces the per-split LSE values to one log-sum-exp per row and accumulates the
// per-split O tiles weighted by exp(lse_split - lse_total).
// Problem_ supplies data types, tile sizes and feature flags; Policy_ supplies tile
// distributions, LDS descriptors and vector alignments.
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
struct BlockFmhaFwdSplitKVCombinePipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kHeadDimV = Problem::kHeadDimV;
// tile extents: kM0 rows (seqlen_q) by kN1 columns (hdim_v)
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// compile-time upper bound on the number of splits handled by one tile
static constexpr index_t kMaxSplits = Problem::kMaxSplits;
// vector alignments fall back to 1 whenever the corresponding dimension is padded
static constexpr index_t kAlignmentLSE =
kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE<Problem>();
static constexpr index_t kAlignmentLSEacc = kAlignmentLSE;
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
// Occupancy (blocks per CU) used by this pipeline.
// If the problem specifies an occupancy (anything other than -1) it is used as is;
// otherwise a default is looked up from kHeadDimV and kMaxSplits. The lookup tables are
// indexed by log2(kMaxSplits) - 4, so kMaxSplits must be one of 16, 32, 64 or 128.
// Defaults exist only for kHeadDimV <= 256; larger head dims must set
// Problem::kBlockPerCu explicitly (previously this case produced a void-returning lambda
// and an unhelpful conversion error).
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::kBlockPerCu != -1)
    {
        return Problem::kBlockPerCu;
    }
    else
    {
        static_assert(kHeadDimV <= 256,
                      "no default occupancy for kHeadDimV > 256, "
                      "please specify Problem::kBlockPerCu explicitly");

        if constexpr(kHeadDimV <= 32)
        {
            constexpr std::array<int, 4> occupancy{3, 3, 3, 1};
            return occupancy[detail::log2<kMaxSplits>::value - 4];
        }
        else if constexpr(kHeadDimV <= 128)
        {
            constexpr std::array<int, 4> occupancy{3, 3, 2, 1};
            return occupancy[detail::log2<kMaxSplits>::value - 4];
        }
        else // kHeadDimV <= 256, guaranteed by the static_assert above
        {
            constexpr std::array<int, 4> occupancy{2, 2, 2, 1};
            return occupancy[detail::log2<kMaxSplits>::value - 4];
        }
    }
}();
// Pipeline name tag; this pipeline does not define a meaningful one.
static constexpr const char* name = "unused";
// LDS bytes required by this pipeline, as computed by the policy.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// Combines the per-split partial results of one [kM0, kN1] output tile.
//
// Parameters:
//   lse_acc_dram_block_window_tmp  window over the per-split LSE, [kMaxSplits, kM0]
//   o_acc_dram_block_window_tmp    window over the per-split O, positioned on split 0
//   lse_dram_window_tmp            destination window for the combined LSE (used only
//                                  when kStoreLSE is set)
//   lse_element_func               applied element-wise to the combined LSE before storing
//   o_acc_element_func             applied element-wise to the combined O before returning
//   num_splits                     number of valid splits (runtime)
//   max_seqlen_q                   used to compute the row step between consecutive splits
//   smem_ptr                       LDS buffer of at least GetSmemSize() bytes
// Returns the combined O tile as a distributed tensor.
//
// Per row, with lse_i the LSE of split i:
//   m      = max_i lse_i                   (0 is used instead when m is -inf)
//   lse    = log(sum_i exp(lse_i - m)) + m
//   O      = sum_i exp(lse_i - lse) * O_i
template <typename LSEaccDramBlockWindowTmp,
typename OaccDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename LSEElementFunction,
typename OaccElementFunction>
CK_TILE_HOST_DEVICE auto
operator()(const LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp,
const OaccDramBlockWindowTmp& o_acc_dram_block_window_tmp,
LSEDramBlockWindowTmp& lse_dram_window_tmp,
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t max_seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
LSEDataType* lse_acc_lds_ptr =
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
// element accessor into the LDS tile, addressed as (row in [0, kM0), split column)
auto lse_acc_lds = [=, lds_desc = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>()](
index_t row, index_t col) -> LSEDataType& {
return lse_acc_lds_ptr[lds_desc.calculate_offset(make_tuple(row, col))];
};
// the same LDS memory viewed as [kMaxSplits, kM0] for the initial tile store
auto lse_acc_lds_write_window = [&]() {
auto view = make_tensor_view<address_space_enum::lds>(
lse_acc_lds_ptr, Policy::template MakeLSEaccLdsStoreBlockDescriptor<Problem>());
return make_tile_window(view, make_tuple(number<kMaxSplits>{}, number<kM0>{}), {0, 0});
}();
auto lse_acc_dram_window =
make_tile_window(lse_acc_dram_block_window_tmp.get_bottom_tensor_view(),
lse_acc_dram_block_window_tmp.get_window_lengths(),
lse_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakeLSEaccDramTileDistribution<Problem>());
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce. Columns at or beyond num_splits are filled with -inf so that they do not
// contribute to the max / sum below.
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_accum(i_j_idx) = lse_acc_lds(row, col);
}
else
{
lse_accum(i_j_idx) = -numeric<LSEDataType>::infinity();
}
});
});
}
// compute the logsumexp of the LSE along the split dimension.
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// row-wise maximum over the split dimension
// NOTE(review): block_tile_reduce_sync presumably completes the reduction across
// threads — confirm in block_reduce.hpp
auto lse_max = block_tile_reduce<LSEDataType>(
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
// a row whose maximum is -inf (every split is -inf) uses 0 as the shift so that the
// subtraction below does not produce (-inf) - (-inf)
static const auto get_validated_m = [](LSEDataType raw_m) {
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};
// lse_exp = exp(lse - m)
decltype(lse_accum) lse_exp;
{
constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
lse_exp(i_j_idx) =
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
});
});
}
// row-wise sum of the shifted exponentials
auto lse_sum = block_tile_reduce<LSEDataType>(
lse_exp, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
// lse_logsum = log(sum) + m; a zero or NaN sum is replaced by the sentinel +inf, which
// makes every scale exp(lse - lse_logsum) computed below evaluate to exp(-inf)
decltype(lse_max) lse_logsum;
{
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
{
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
else
{
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
});
}
// store the lse scales in shared memory.
// (the LDS tile is overwritten in place: entry (row, split) becomes
// exp(lse_split - lse_logsum) for the valid splits)
{
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto col = x_indices.at(number<1>{});
if(col < num_splits)
{
const auto row = x_indices.at(number<0>{});
lse_acc_lds(row, col) =
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
}
});
});
}
block_sync_lds();
if constexpr(kStoreLSE)
{
// replace the +inf sentinel with -inf before the combined LSE is written out
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
}
// accumulate the per-split O tiles, weighted by the scales stored in LDS
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
// row distance between the same tile of two consecutive splits: max_seqlen_q rounded up
// to a multiple of kM0
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
auto o_tile = load_tile(o_acc_dram_window);
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const LSEDataType lse_scale = lse_acc_lds(row, i_split);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
});
});
}
// advance to the same tile of the next split
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
// Convenience overload without element functions: forwards to the full overload with
// identity functors for both the LSE and the O accumulator post-processing.
template <typename LSEaccDramBlockWindow,
          typename OaccDramBlockWindow,
          typename LSEDramBlockWindow>
CK_TILE_HOST_DEVICE auto operator()(const LSEaccDramBlockWindow& lse_acc_dram_block_window,
                                    const OaccDramBlockWindow& o_acc_dram_block_window,
                                    LSEDramBlockWindow& lse_dram_block_window,
                                    index_t num_splits,
                                    index_t max_seqlen_q,
                                    void* smem_ptr) const
{
    // no element-wise post-processing requested
    const identity pass_through{};

    return (*this)(lse_acc_dram_block_window,
                   o_acc_dram_block_window,
                   lse_dram_block_window,
                   pass_through, // lse_element_func
                   pass_through, // o_acc_element_func
                   num_splits,
                   max_seqlen_q,
                   smem_ptr);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
// Number of LSE elements that fit into one 16-byte access.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{
    constexpr int kAccessBytes = 16;
    return kAccessBytes / sizeof(remove_cvref_t<typename Problem::LSEDataType>);
}
// Number of O-accumulator elements that fit into one 16-byte access.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
    constexpr int kAccessBytes = 16;
    return kAccessBytes / sizeof(remove_cvref_t<typename Problem::OaccDataType>);
}
// Number of output (O) elements that fit into one 16-byte access.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
    constexpr int kAccessBytes = 16;
    return kAccessBytes / sizeof(remove_cvref_t<typename Problem::ODataType>);
}
// LDS bytes needed to hold the LSE accumulator tile: element size times the element
// space (including padding) of the LDS block descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const auto lds_desc = MakeLSEaccLdsBlockDescriptor<Problem>();
    return sizeof(typename Problem::LSEDataType) * lds_desc.get_element_space_size();
}
// Tile distribution used to load the [kMaxSplits, kM0] LSE accumulator tile from DRAM.
// Note the local naming: "M" is the split dimension (kMaxSplits) and "N" is the row
// dimension (kM0). Each thread covers NPerThread contiguous N elements (16 bytes worth)
// and MPerThread entries along M; the static_asserts require both dimensions to be tiled
// exactly by the block.
// NOTE(review): the exact thread/warp mapping follows from the encoding below — see the
// tile_distribution_encoding documentation before changing any of the sequences.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
// elements per 16-byte access, and the number of threads needed to cover one row of N
constexpr index_t NPerThread = 16 / sizeof(LSEDataType);
constexpr index_t NThreads = kNPerBlock / NPerThread;
// remaining threads of a warp, and all warps of the block, are spread along M
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t TotalWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp);
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, TotalWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// 3d + padding, [kMaxSplits, kM0]
// LDS descriptor used when storing the LSE accumulator tile.
// The underlying storage is 3-D: [kM0 / NPack, kMaxSplits, NPack] with the outermost stride
// set to (kMaxSplits + 1) * NPack, i.e. one extra NPack-wide row of padding per outer
// slice. It is exposed as 2-D [kMaxSplits, kM0] by passing the split dimension through and
// merging the two row dimensions.
// NOTE(review): the padding is presumably there to avoid LDS bank conflicts — confirm.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
// elements per 16 bytes
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
// output dim 0 = splits (pass-through), output dim 1 = rows (merge of dims 0 and 2)
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
lse_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lse_acc_lds_block_desc;
}
// 3d + padding, [kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor()
{
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = 16 / sizeof(LSEDataType);
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
lse_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return lse_acc_t_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size());
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t NThreads = get_warp_size();
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads;
constexpr index_t MPerThread = kMPerBlock / MThreads;
static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MThreads * MPerThread == kMPerBlock);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<MThreads, MPerThread>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2>>,
tuple<sequence<0>, sequence<0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t N1 = 16 / sizeof(OaccDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = get_warp_size() / N0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment