Unverified commit 79a5d9c1, authored by Dan Yao and committed by GitHub

[CK_TILE] FA bwd kernels optimization (#1397)



* tmp save

* fix batch deterministic bugs

* fix group deterministic bugs

* codegen update

* reorder files

* bias support

* hd256 bias support

* bwd smoke test update

* simplify convert dq

* fix hd256 dropout scratch

* do{}while() -> while(){}

* comments

* remove FmhaBwdTilePartitioner

* save clear_tile

* refactor dropout

* code cleanup

* code cleanup

* comments

* fix epilogue problem

* fix fwd dropout

* group convert_dq opt

* fix dq alignment

* Do not store storerandval in bwd for flash attention integration

* fix hd32 error and boost performance

* revert

* Remove duplicated WarpGemm definitions in the policy file

* dropout patch for mrepeat 16*16

* code sync up

* dq_acc stride

* dq_acc stride stuff

* codegen update

* fwd dropout revert

* fix hd128 scratches and boost performance

* receipt 3 for simplified smoke test

* more strides for fa integration

* fix hd64 scratches and boost performance

* non-iglp pipeline for headdim padding cases

* dpad same as dvpad for flash attention integration

* unpadded lse&d for group mode

* Support unpad layout for group lse

* Support unpad lse layout for splitkv

* Fix stride for splitkv kernel

* fix unpadded lse issue in fwd splitkv

* comment

* solve lds read&write conflicts

* rename

* bias rename

* tile index revert

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
Co-authored-by: Qianfeng Zhang <Qianfeng.Zhang@amd.com>
parent 2581727d
@@ -6,7 +6,7 @@ execute_process(
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
+  --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
 )
 # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -23,7 +23,7 @@ add_custom_command(
 add_custom_command(
   OUTPUT ${FMHA_BWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
 )
 set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
@@ -55,11 +55,10 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
 # ... because they are auto-generated
 if(FMHA_FWD_FAST_EXP2)
   list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
-  list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
 else()
   list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
-  list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
 endif()
+list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
...
@@ -66,6 +66,22 @@ BIAS_CHECK_MAP = {
     "alibi" : "bias_enum::alibi"
 }
+DROPOUT_MAP = {
+    "no" : "ck_tile::BlockDropoutBwd<false, true, false>",
+    "dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
+    "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
+    "dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
+    "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
+}
+DROPOUT_CHECK_MAP = {
+    "no" : "t.has_dropout == false",
+    "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
+    "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
+    "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
+    "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
+}
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
...
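For reference, these two maps pair each generated dropout variant with the policy type it instantiates and the runtime predicate that selects it. A hedged sketch of the predicate side, written out as the C++ the generator splices into the dispatcher (`t` is the `fmha_bwd_traits` extended later in this PR; the wrapper function itself is illustrative, not generated code):

    // Predicate for the "dropout_wg32_storerandval" entry, verbatim from
    // DROPOUT_CHECK_MAP; on a match the generator instantiates the paired
    // type ck_tile::BlockDropoutBwd<true, true, true>.
    inline bool matches_dropout_storerandval(const fmha_bwd_traits& t)
    {
        return t.has_dropout == true && t.is_store_randval == true;
    }

Note that the wg32 and wg16 entries carry identical runtime checks, so the warp-gemm flavor is decided by the tile configuration at code-generation time rather than at runtime, and the "no" entry still names a concrete type (`ck_tile::BlockDropoutBwd<false, true, false>`) so dropout-off kernels reuse the same pipeline interface.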
@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
-        .insert("repeat", "20", "number of iterations to benchmark the kernel");
+        .insert("repeat", "20", "number of iterations to benchmark the kernel")
+        .insert("deterministic",
+                "0",
+                "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
+                "will not be used");
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
     if(hdim_v < 0)
         hdim_v = hdim_q;
-    if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
-    {
-        std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
-        return false;
-    }
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         seed.reset();
     }
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
     bool kname = arg_parser.get_bool("kname");
+    bool deterministic = arg_parser.get_bool("deterministic");
     ck_tile::stream_config stream_config{nullptr,
                                          true,
@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
         (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
     const ck_tile::index_t shape_seqlen_k =
         (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
+    const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
+    const ck_tile::index_t nsplits =
+        deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
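The two added lines above size the deterministic-mode workspace: dq is accumulated per kN0-wide key tile into its own split buffer instead of through atomics. A minimal sketch of the sizing math, assuming names as in the example code (`integer_divide_ceil` is the ceiling division used above):

    #include <cstdint>
    // Workspace elements for dq_acc, laid out [nsplits, batch, nhead, seqlen_q, hdim_q].
    inline std::int64_t dq_acc_element_count(std::int64_t max_seqlen_k, std::int64_t hdim_q,
                                             std::int64_t batch, std::int64_t nhead,
                                             std::int64_t seqlen_q, bool deterministic)
    {
        const std::int64_t kN0     = (hdim_q <= 128) ? 128 : 64; // bwd key-tile width
        const std::int64_t nsplits = deterministic ? (max_seqlen_k + kN0 - 1) / kN0 : 1;
        return nsplits * batch * nhead * seqlen_q * hdim_q;      // AccDataType (fp32) elements
    }

For example, max_seqlen_k = 516 with hdim_q = 64 gives kN0 = 128 and nsplits = 5.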
@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<ODataType> o_host(
         get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
     ck_tile::HostTensor<LSEDataType> lse_host(
-        std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
+        std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
     ck_tile::HostTensor<DDataType> d_host(
-        std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
+        std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
     ck_tile::HostTensor<RandValOutputDataType> randval_host(
         p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                    : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         use_dbias
             ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
             : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
+    ck_tile::HostTensor<AccDataType> dq_acc_host(
+        i_perm
+            ? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
+            : std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
     if(init_method == 0)
     {
@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());
@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
               << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
               << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
-              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
-              << std::flush;
+              << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
+              << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
+    std::size_t workspace_size = dq_acc_host.get_element_space_size_in_bytes() / (1024 * 1024);
+    if(deterministic == 1)
+    {
+        std::cout << "\nDeterministic mode ON: " << workspace_size
+                  << " MByte memory workspace allocated" << std::endl;
+    }
     auto fmha_traits = fmha_bwd_traits{hdim_q,
                                        hdim_v,
@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                        mask.type,
                                        bias.type,
                                        use_dbias,
-                                       p_drop > 0.0f};
+                                       p_drop > 0.0f,
+                                       s_randval,
+                                       deterministic};
     auto fmha_args = [&]() {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
         const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
         const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
-        const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
+        const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
         const ck_tile::index_t nhead_stride_dbias =
             (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
         // setup batch_stride_* arguments
@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
         const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
         const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
-        const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
+        const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
         const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
         const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
         const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t split_stride_dq_acc =
+            (shape_batch * nhead * shape_seqlen_q * hdim_q);
         return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                              k_buf.GetDeviceBuffer(),
@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              dk_buf.GetDeviceBuffer(),
                              dv_buf.GetDeviceBuffer(),
                              dbias_buf.GetDeviceBuffer(),
+                             dq_acc_buf.GetDeviceBuffer(),
                              seqstart_q.GetDeviceBuffer(),
                              seqstart_k.GetDeviceBuffer(),
                              nullptr,
@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              stride_o,
                              stride_randval,
                              stride_do,
+                             stride_q, // stride_dq_acc
+                             stride_q, // stride_dq
                              stride_dk,
                              stride_dv,
                              stride_dbias,
@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              nhead_stride_randval,
                              nhead_stride_do,
                              nhead_stride_lsed,
+                             nhead_stride_q, // nhead_stride_dq_acc
+                             nhead_stride_q, // nhead_stride_dq
+                             nhead_stride_k, // nhead_stride_dk
+                             nhead_stride_v, // nhead_stride_dv
                              nhead_stride_dbias,
                              batch_stride_q,
                              batch_stride_k,
@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              batch_stride_randval,
                              batch_stride_do,
                              batch_stride_lsed,
+                             batch_stride_q, // batch_stride_dq_acc
+                             batch_stride_q, // batch_stride_dq
                              batch_stride_dk,
                              batch_stride_dv,
                              batch_stride_dbias,
+                             split_stride_dq_acc,
                              mask.left,
                              mask.right,
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
                              p_undrop,
+                             s_randval,
                              {drop_seed, drop_offset}};
     }();
@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
         else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
-        lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
+        lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
         // clang-format on
         q_host_refs.push_back(q_host_ref);
@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     lse_buf.ToDevice(lse_host.data());
     dq_buf.SetZero();
     dbias_buf.SetZero();
+    dq_acc_buf.SetZero();
     ck_tile::stream_config stream_config_v{
         nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
...
@@ -77,6 +77,7 @@ struct fmha_bwd_args
     void* dk_ptr;
     void* dv_ptr;
     void* dbias_ptr;
+    void* dq_acc_ptr;
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
@@ -97,6 +98,8 @@ struct fmha_bwd_args
     ck_tile::index_t stride_o;
     ck_tile::index_t stride_randval;
     ck_tile::index_t stride_do;
+    ck_tile::index_t stride_dq_acc;
+    ck_tile::index_t stride_dq;
     ck_tile::index_t stride_dk;
     ck_tile::index_t stride_dv;
     ck_tile::index_t stride_dbias;
@@ -108,6 +111,10 @@ struct fmha_bwd_args
     ck_tile::index_t nhead_stride_randval;
     ck_tile::index_t nhead_stride_do;
     ck_tile::index_t nhead_stride_lsed;
+    ck_tile::index_t nhead_stride_dq_acc;
+    ck_tile::index_t nhead_stride_dq;
+    ck_tile::index_t nhead_stride_dk;
+    ck_tile::index_t nhead_stride_dv;
     ck_tile::index_t nhead_stride_dbias;
     ck_tile::index_t batch_stride_q;
     ck_tile::index_t batch_stride_k;
@@ -117,15 +124,17 @@ struct fmha_bwd_args
     ck_tile::index_t batch_stride_randval;
     ck_tile::index_t batch_stride_do;
     ck_tile::index_t batch_stride_lsed;
+    ck_tile::index_t batch_stride_dq_acc;
+    ck_tile::index_t batch_stride_dq;
     ck_tile::index_t batch_stride_dk;
     ck_tile::index_t batch_stride_dv;
     ck_tile::index_t batch_stride_dbias;
+    ck_tile::index_t split_stride_dq_acc;
     ck_tile::index_t window_size_left;
     ck_tile::index_t window_size_right;
     ck_tile::index_t mask_type;
     float p_drop;
     float p_undrop;
+    bool s_randval;
     std::tuple<uint64_t, uint64_t> drop_seed_offset;
 };
@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.do_ptr,
                 args.d_ptr,
                 args.rand_val_ptr,
-                args.dq_ptr,
                 args.dk_ptr,
                 args.dv_ptr,
                 args.dbias_ptr,
+                args.dq_acc_ptr,
                 args.seqstart_q_ptr,
                 args.seqstart_k_ptr,
                 args.seqlen_k_ptr,
@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.stride_bias,
                 args.stride_randval,
                 args.stride_do,
+                args.stride_dq_acc,
                 args.stride_dk,
                 args.stride_dv,
                 args.stride_dbias,
@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.nhead_stride_randval,
                 args.nhead_stride_do,
                 args.nhead_stride_lsed,
+                args.nhead_stride_dq_acc,
+                args.nhead_stride_dk,
+                args.nhead_stride_dv,
                 args.nhead_stride_dbias,
-                args.batch_stride_lsed,
+                args.split_stride_dq_acc,
                 args.window_size_left,
                 args.window_size_right,
                 args.mask_type,
                 args.p_drop,
+                args.s_randval,
                 args.drop_seed_offset);
     }
     else
@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.do_ptr,
                 args.d_ptr,
                 args.rand_val_ptr,
-                args.dq_ptr,
                 args.dk_ptr,
                 args.dv_ptr,
                 args.dbias_ptr,
+                args.dq_acc_ptr,
                 args.seqlen_q,
                 args.seqlen_k,
                 args.hdim_q,
@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.stride_bias,
                 args.stride_randval,
                 args.stride_do,
+                args.stride_dq_acc,
                 args.stride_dk,
                 args.stride_dv,
                 args.stride_dbias,
@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.nhead_stride_randval,
                 args.nhead_stride_do,
                 args.nhead_stride_lsed,
+                args.nhead_stride_dq_acc,
+                args.nhead_stride_dk,
+                args.nhead_stride_dv,
                 args.nhead_stride_dbias,
                 args.batch_stride_q,
                 args.batch_stride_k,
@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
                 args.batch_stride_randval,
                 args.batch_stride_do,
                 args.batch_stride_lsed,
+                args.batch_stride_dq_acc,
                 args.batch_stride_dk,
                 args.batch_stride_dv,
                 args.batch_stride_dbias,
+                args.split_stride_dq_acc,
                 args.window_size_left,
                 args.window_size_right,
                 args.mask_type,
                 args.p_drop,
+                args.s_randval,
                 args.drop_seed_offset);
     }
 }();
@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
                 args.stride_o,
                 args.nhead_stride_do,
                 args.nhead_stride_o,
-                args.nhead_stride_lsed,
-                args.batch_stride_lsed);
+                args.nhead_stride_lsed);
     }
     else
     { // create batch mode kernel arguments
@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
     return ck_tile::make_tuple(kargs, grids);
 }
+
+template <typename FmhaBwdConvertQGradKernel>
+auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
+{
+    auto kargs = [&] {
+        // create group mode kernel arguments
+        if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
+        {
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqstart_q_ptr,
+                                                        args.seqstart_k_ptr,
+                                                        args.hdim_q,
+                                                        args.stride_dq,
+                                                        args.stride_dq_acc,
+                                                        args.nhead_stride_dq,
+                                                        args.nhead_stride_dq_acc,
+                                                        args.split_stride_dq_acc);
+        }
+        else
+        { // create batch mode kernel arguments
+            return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
+                                                        args.dq_ptr,
+                                                        args.seqlen_q,
+                                                        args.seqlen_k,
+                                                        args.hdim_q,
+                                                        args.stride_dq,
+                                                        args.stride_dq_acc,
+                                                        args.nhead_stride_dq,
+                                                        args.nhead_stride_dq_acc,
+                                                        args.batch_stride_dq,
+                                                        args.batch_stride_dq_acc,
+                                                        args.split_stride_dq_acc);
+        }
+    }();
+
+    dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
+    return ck_tile::make_tuple(kargs, grids);
+}
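The convert-dq pass launched with these kargs turns the per-split fp32 partials into the final dq tensor. A hedged host-side sketch of the reduction it performs, inferred from the argument list above rather than from the kernel source (assuming a contiguous layout; in the example, elems_per_split equals split_stride_dq_acc = batch * nhead * seqlen_q * hdim_q):

    // Reference semantics of the convert-dq reduction (sketch, not library code).
    template <typename QGradDataType, typename AccDataType>
    void convert_dq_reference(QGradDataType* dq, const AccDataType* dq_acc,
                              long nsplits, long elems_per_split)
    {
        for(long e = 0; e < elems_per_split; ++e)
        {
            AccDataType sum = 0;
            for(long i = 0; i < nsplits; ++i)           // fixed split order -> bit-exact reruns
                sum += dq_acc[i * elems_per_split + e]; // stride between splits
            dq[e] = static_cast<QGradDataType>(sum);
        }
    }

With deterministic mode off there is a single split, so the pass reduces to a plain AccDataType-to-QGradDataType conversion, hence the name.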
 // this is used to pattern-match internal kernel implementation, not to instantiate kernel
 template <ck_tile::index_t HDim_,
           typename DataType_,
           bool kIsGroupMode_,
           ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
           typename FmhaMask_,
+          typename FmhaDropout_,
           ck_tile::BlockAttentionBiasEnum BiasEnum_,
           bool kHasBiasGrad_,
-          bool kHasDropout_,
           bool kPadS_,
           bool kPadSK_,
           bool kPadD_,
-          bool kPadDv_>
+          bool kPadDv_,
+          bool kIsDeterministic_>
 struct fmha_bwd_dq_dk_dv_traits_
 {
     static constexpr ck_tile::index_t HDim = HDim_;
@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
     static constexpr bool kIsGroupMode = kIsGroupMode_;
     static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
     using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
+    using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
     static constexpr auto BiasEnum = BiasEnum_;
     static constexpr bool kHasBiasGrad = kHasBiasGrad_;
-    static constexpr bool kHasDropout = kHasDropout_;
     static constexpr bool kPadS = kPadS_;
     static constexpr bool kPadSK = kPadSK_;
     static constexpr bool kPadD = kPadD_;
     static constexpr bool kPadDv = kPadDv_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
 };
 template <typename Traits_>
@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
 template <typename Traits_>
 std::string fmha_bwd_dot_do_o_get_name_();
+
+template <ck_tile::index_t HDim_,
+          typename DataType_,
+          bool kIsGroupMode_,
+          bool kPadS_,
+          bool kPadD_,
+          bool kIsDeterministic_>
+struct fmha_bwd_convert_dq_traits_
+{
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType = ck_tile::remove_cvref_t<DataType_>;
+    static constexpr bool kIsGroupMode = kIsGroupMode_;
+    static constexpr bool kPadS = kPadS_;
+    static constexpr bool kPadD = kPadD_;
+    static constexpr bool kIsDeterministic = kIsDeterministic_;
+};
+
+template <typename Traits_>
+float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
+
+template <typename Traits_>
+std::string fmha_bwd_convert_dq_get_name_();
+
 // This is the public API, will be generated by script
 struct fmha_bwd_traits
 {
@@ -354,6 +436,8 @@ struct fmha_bwd_traits
     bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
     bool has_dbias;
     bool has_dropout;
+    bool is_store_randval;
+    bool is_deterministic;
     // TODO: padding check is inside this api
 };
 float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
@@ -479,16 +479,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
                    : std::array<ck_tile::index_t, 2>{1, 1});
     ck_tile::HostTensor<LSEDataType> lse_acc_host(
-        1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
-                       : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
+        1 < num_splits
+            ? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
+            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
     ck_tile::HostTensor<OaccDataType> o_acc_host(
         1 < num_splits
             ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
            : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
-    // self define lse data layout as [batch, nhead, max_seqlen_q]
+    // batch mode of lse data layout is [batch, nhead, seqlen_q]
+    // group mode of lse data layout is [nhead, total_seqlen_q]
     ck_tile::HostTensor<LSEDataType> lse_host(
-        lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
+        lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
             : std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor<ODataType> o_host(
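The comment above describes the new unpadded lse layouts. A small sketch of the resulting index math (the helper names are illustrative, not library API):

    // batch mode: lse is [batch, nhead, seqlen_q]
    inline long lse_offset_batch(long b, long h, long s, long nhead, long seqlen_q)
    {
        return (b * nhead + h) * seqlen_q + s;
    }
    // group mode: lse is [nhead, total_seqlen_q]; batch b starts at seqstart_q[b]
    inline long lse_offset_group(long h, long s, long b, const int* seqstart_q,
                                 long total_seqlen_q)
    {
        return h * total_seqlen_q + seqstart_q[b] + s;
    }

This is why the lse batch offset in the kernels below becomes query_start in group mode instead of i_batch * batch_stride_lse.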
@@ -669,8 +671,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t nhead_stride_bias =
             (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
         const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
-        const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
-        const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
+        const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
+        const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
         const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
         const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
         // setup batch_stride_* arguments
@@ -679,12 +681,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
         const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
         const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
-        const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
-        const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
+        const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
+        const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
         const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
         const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
         // setup split_stride_* arguments (only used in split-kv kernel)
-        const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
+        const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
         const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
         return fmha_fwd_args{q_buf.GetDeviceBuffer(),
@@ -996,8 +998,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
             if(lse)
             {
                 ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
-                lse_host_result.ForEach(
-                    [&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
+                lse_host_result.ForEach([&](auto& self, auto idx) {
+                    self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
+                });
                 cur_pass = ck_tile::check_err(lse_host_result,
                                               lse_host_ref,
...
@@ -185,7 +185,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
             args.nhead_stride_randval,
             args.nhead_stride_lse,
             args.nhead_stride_o,
-            args.batch_stride_lse,
             args.window_size_left,
             args.window_size_right,
             args.mask_type,
@@ -284,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
             args.nhead_stride_randval,
             args.nhead_stride_lse_acc,
             args.nhead_stride_o_acc,
-            args.batch_stride_lse_acc,
             args.batch_stride_o_acc,
             args.split_stride_lse_acc,
             args.split_stride_o_acc,
@@ -376,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
             args.nhead_stride_o_acc,
             args.nhead_stride_lse,
             args.nhead_stride_o,
-            args.batch_stride_lse_acc,
             args.batch_stride_o_acc,
-            args.batch_stride_lse,
             args.split_stride_lse_acc,
             args.split_stride_o_acc);
     }
...
@@ -11,18 +11,19 @@ COMMON_ARGS='-v=1'
 set -x
 for prec in "fp16" "bf16" ; do
 for perm in 0 1 ; do
-for hdim in 32 64 128 ; do
+for hdim in 32 64 128 256 ; do
 for mode in 0 1 ; do
-for bias in "n" "e" "a"; do
+for bias in "n" "a" ; do
-for dbias in 0 1 ; do
+for dbias in 0 ; do
-for p_drop in 0.0 0.2; do
+for p_drop in 0.0 0.2 ; do
+for deterministic in 0 ; do
-$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
-$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
 done
 done
@@ -31,4 +32,5 @@ done
 done
 done
 done
+done
 set +x
@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
 };
 // 2D XOR, NOTE: "xor" is a keyword
-template <typename LowLengths, typename RightShift>
+template <typename LowLengths>
 struct xor_t : public base_transform<2, 2>
 {
     static constexpr auto type_enum = coord_transform_enum::xor_t;
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
     using UpLengths = LowLengths;
     UpLengths up_lengths_;
-    RightShift right_shift_;
-    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
+    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
-    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
-                                        const RightShift& right_shift)
-        : up_lengths_{low_lengths}, right_shift_{right_shift}
-    {
-    }
+    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
     CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
     {
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
         idx_low(number<0>{}) = idx_up[number<0>{}];
-        const auto idx_low_1_tmp =
-            (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
-        const auto idx_low_1 =
-            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
-        idx_low(number<1>{}) = idx_low_1;
+        idx_low(number<1>{}) =
+            idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
     }
     template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
     CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
     {
-        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
-               ck_tile::is_known_at_compile_time<RightShift>::value;
+        return ck_tile::is_known_at_compile_time<UpLengths>::value;
     }
     // MUST be static function
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
         array<index_t, 2> up_vector_lengths = low_vector_lengths;
         array<index_t, 2> up_vector_strides = low_vector_strides;
-        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
-        {
-            if(low_vector_lengths[1] != -1)
-            {
-                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
-            }
-        }
         return make_tuple(up_vector_lengths, up_vector_strides);
     }
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
         print(up_lengths_);
         printf(", ");
-        //
-        printf("right_shift_: ");
-        print(right_shift_);
         printf("}");
     }
 };
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
     return modulo<Modulus, UpLength>{modulus, up_length};
 }
-template <typename LowLengths, typename RightShift>
-CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
-                                                      const RightShift& right_shift)
+template <typename LowLengths>
+CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
 {
-    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
+    return xor_t<LowLengths>{low_lengths};
 }
 template <typename LowLength, typename OffsetLength>
...
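The change above replaces the parameterized right-shift swizzle with a plain XOR. A standalone sketch of the new mapping, matching calculate_lower_index above (len stands for up_lengths_[number<1>{}]):

    // Column swizzle within a row: (row, col) maps to (row, col ^ (row % len)).
    inline int xor_low_column(int row, int col, int len)
    {
        return col ^ (row % len);
    }

Because x ^ k ^ k == x, the map is its own inverse and merely permutes columns within each row; this is the property used to spread LDS reads and writes across banks without changing the tensor footprint (the "solve lds read&write conflicts" item in the commit list).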
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
 using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
 using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
+// u32
+// using uint32_t = ...
+using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
+using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));
+using uint32x8_t = uint32_t __attribute__((ext_vector_type(8)));
+using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
+using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
+using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
 // i16
 // using int16_t = ...
 using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
...
@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
     return make_tuple(
         make_static_tile_distribution(
             tile_distribution_encoding<typename Encoding::RsLengths,
-                                       decltype(sliced_h_lengths), // only need to change the
-                                                                   // h_lengths type
+                                       remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
+                                                                                   // change the
+                                                                                   // h_lengths type
                                        typename Encoding::Ps2RHssMajor,
                                        typename Encoding::Ps2RHssMinor,
                                        typename Encoding::Ys2RHsMajor,
...
@@ -53,6 +53,39 @@ class philox
         out_tmp[3] = tmp_ph.w;
     }
+
+    CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0] = tmp[start_idx];
+        out_tmp[1] = tmp[start_idx + 2];
+    }
+
+    CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0] = tmp[start_idx];
+    }
+
 private:
     struct ull2
     {
...
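Each philox block is 4 x 32 bits, i.e. 16 random bytes; the new helpers copy two lanes (8 bytes) or one lane (4 bytes) of that block into a byte buffer. A hedged usage sketch, assuming a seeded ck_tile::philox instance `ph` as constructed elsewhere in this header:

    uint8_t rand_bytes[8];
    // lanes start_idx and start_idx + 2 of the 4x32 block land in rand_bytes[0..7]
    ph.get_random_8x8(rand_bytes, /*subsequence=*/0, /*start_idx=*/0);

    uint8_t rand4[4];
    // a single uint32 lane lands in rand4[0..3]
    ph.get_random_4x8(rand4, /*subsequence=*/0, /*start_idx=*/1);

Presumably these serve the paths that consume fewer than 16 random bytes per call, such as the 16x16-repeat dropout patch mentioned in the commit list.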
@@ -8,21 +8,16 @@
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
-#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
...
--- a/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
+++ /dev/null
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
-
-#pragma once
-
-#include "ck_tile/core.hpp"
-
-namespace ck_tile {
-
-template <typename BlockFmhaShape_>
-struct FmhaBwdTilePartitioner
-{
-    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
-
-    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
-
-    CK_TILE_HOST static constexpr auto
-    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
-    {
-        // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
-    }
-
-    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
-    {
-        const index_t i_block = blockIdx.x;
-        const index_t i_nhead = blockIdx.y;
-        const index_t i_batch = blockIdx.z;
-
-        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
-    }
-};
-
-template <ck_tile::index_t kBlockSize>
-struct FmhaBwdOGradDotOTilePartitioner
-{
-    CK_TILE_HOST static constexpr auto
-    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
-    {
-        // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
-    }
-
-    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
-    {
-        const index_t i_block = blockIdx.x;
-        const index_t i_nhead = blockIdx.y;
-        const index_t i_batch = blockIdx.z;
-
-        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
-    }
-};
-
-} // namespace ck_tile
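With FmhaBwdTilePartitioner removed (its include is dropped above), the bwd kernels presumably resolve their tile coordinates inline from the launch grid. For reference, a sketch of the mapping the deleted struct provided (HIP/CUDA blockIdx semantics, as in the code above; the free function below is illustrative):

    // grid = (ceil(seqlen_k / kN0), nhead, batch), exactly as in the removed GridSize().
    __device__ auto bwd_tile_index()
    {
        return ck_tile::make_tuple(static_cast<ck_tile::index_t>(blockIdx.x),  // key tile
                                   static_cast<ck_tile::index_t>(blockIdx.y),  // head
                                   static_cast<ck_tile::index_t>(blockIdx.z)); // batch
    }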
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
         "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
         (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
         "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
         (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
         (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
 #undef _SS_
 #undef _TS_
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
                    ck_tile::index_t nhead_stride_randval,
                    ck_tile::index_t nhead_stride_lse,
                    ck_tile::index_t nhead_stride_o,
-                   ck_tile::index_t batch_stride_lse,
                    ck_tile::index_t window_size_left,
                    ck_tile::index_t window_size_right,
                    ck_tile::index_t mask_type,
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
         {
             kargs.lse_ptr = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
-            kargs.batch_stride_lse = batch_stride_lse;
         }
         if constexpr(kDoFp8StaticQuant)
         {
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
         }
         if constexpr(kStoreLSE)
         {
-            batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
+            batch_offset_lse = query_start;
         }
         if constexpr(kHasDropout)
         {
...
@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
             (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
             _SS_(FmhaPipeline::name) +
             (pn.empty() ? "" : "_" + pn) +
             (kStoreLSE ? "_lse" : "" ) +
             (kDoFp8StaticQuant ? "_squant" : "" );
 #undef _SS_
 #undef _TS_
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
         ck_tile::index_t nhead_stride_o_acc;
         ck_tile::index_t nhead_stride_o;
-        ck_tile::index_t batch_stride_lse_acc;
         ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
     {
         ck_tile::index_t batch_stride_o;
+        ck_tile::index_t batch_stride_lse_acc;
     };

     struct GroupModeKargs
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
-                    batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {}, // placeholder for lse
                    {}, // placeholder for fp8_static_quant args
-                   batch_stride_o};
+                   batch_stride_o,
+                   batch_stride_lse_acc};

         if constexpr(kStoreLSE)
         {
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
               ck_tile::index_t nhead_stride_o_acc,
               ck_tile::index_t nhead_stride_lse,
               ck_tile::index_t nhead_stride_o,
-              ck_tile::index_t batch_stride_lse_acc,
               ck_tile::index_t batch_stride_o_acc,
-              ck_tile::index_t batch_stride_lse,
               ck_tile::index_t split_stride_lse_acc,
               ck_tile::index_t split_stride_o_acc)
     {
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
-                    batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
         {
             kargs.lse_ptr = lse_ptr;
             kargs.nhead_stride_lse = nhead_stride_lse;
-            kargs.batch_stride_lse = batch_stride_lse;
         }
         if constexpr(kDoFp8StaticQuant)
         {
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

-        const long_index_t batch_offset_lse_acc =
-            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
         const long_index_t batch_offset_o_acc =
             static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
-        long_index_t batch_offset_lse = 0;
-        long_index_t batch_offset_o = 0;
-        if constexpr(kStoreLSE)
-        {
-            batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
-        }
+        long_index_t batch_offset_lse_acc = 0;
+        long_index_t batch_offset_lse = 0;
+        long_index_t batch_offset_o = 0;

         if constexpr(kIsGroupMode)
         {
             // get starting offset for each batch
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             batch_offset_o = query_start * kargs.row_stride_o;
+            batch_offset_lse_acc = query_start;
+            if constexpr(kStoreLSE)
+            {
+                batch_offset_lse = query_start;
+            }

             // get real # queries & # keys under group mode
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
         }
         else
         {
             batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
+            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+            if constexpr(kStoreLSE)
+            {
+                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
+            }
         }

         // for simplicity, batch stride we just modify the pointer
...
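For context on what this kernel consumes: each split of the split-KV forward writes a per-row log-sum-exp (lse_acc) and a split-local output (o_acc), and the combine kernel merges them. A scalar sketch of that merge for one (head, row) pair, assuming each split's o_acc row is already normalized within its split; names and loop bounds are illustrative, not the tiled implementation:

    #include <cmath>
    #include <cstddef>

    // Hedged scalar sketch of split-kv combination for a single output row.
    void combine_row(const float* lse_acc,      // [nsplits] per-split logsumexp
                     const float* const* o_acc, // [nsplits][headdim] per-split rows
                     std::size_t nsplits, std::size_t headdim,
                     float* o_out, float* lse_out)
    {
        // merged lse = log(sum_s exp(lse_acc[s])), accumulated stably pairwise
        float lse = -INFINITY;
        for(std::size_t s = 0; s < nsplits; ++s)
            lse = std::fmax(lse, lse_acc[s]) +
                  std::log1p(std::exp(-std::fabs(lse - lse_acc[s])));
        // each split's row is rescaled by its share of the total softmax mass
        for(std::size_t d = 0; d < headdim; ++d)
        {
            float acc = 0.f;
            for(std::size_t s = 0; s < nsplits; ++s)
                acc += std::exp(lse_acc[s] - lse) * o_acc[s][d];
            o_out[d] = acc;
        }
        *lse_out = lse;
    }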
@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
             "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
             (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
             "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
             (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
 #undef _SS_
 #undef _TS_
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t nhead_stride_lse_acc;
         ck_tile::index_t nhead_stride_o_acc;
-        ck_tile::index_t batch_stride_lse_acc;
         ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
+        ck_tile::index_t batch_stride_lse_acc;
     };

     struct GroupModeKargs
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
-                    batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
                    {}, // placeholder for dropout
                    batch_stride_q,
                    batch_stride_k,
-                   batch_stride_v};
+                   batch_stride_v,
+                   batch_stride_lse_acc};

         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
               ck_tile::index_t nhead_stride_randval,
               ck_tile::index_t nhead_stride_lse_acc,
               ck_tile::index_t nhead_stride_o_acc,
-              ck_tile::index_t batch_stride_lse_acc,
               ck_tile::index_t batch_stride_o_acc,
               ck_tile::index_t split_stride_lse_acc,
               ck_tile::index_t split_stride_o_acc,
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
-                    batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
         long_index_t batch_offset_v = 0;
         long_index_t batch_offset_bias = 0;
         long_index_t batch_offset_randval = 0;
-        const long_index_t batch_offset_lse_acc =
-            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+        long_index_t batch_offset_lse_acc = 0;
         const long_index_t batch_offset_o_acc =
             static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];

             batch_offset_q = query_start * kargs.stride_q;
             batch_offset_k = key_start * kargs.stride_k;
+            batch_offset_lse_acc = query_start;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
                 batch_offset_v = key_start * kargs.stride_v;
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
         }
         else
         {
             batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
             batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
             batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
+            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
...
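Taken together with the combine-kernel change above, lse_acc now carries a batch stride only in batch mode; in group mode every batch's rows are packed along the seqlen axis of a [num_splits, nhead, total_seqlen_q] buffer. A sketch of the resulting element indexing, mirroring the split_stride_lse_acc / nhead_stride_lse_acc / batch_stride_lse_acc kernel arguments (function name hypothetical):

    #include <cstdint>

    // Hedged sketch of lse_acc addressing; the row axis has unit stride.
    int64_t lse_acc_index(bool is_group_mode, int64_t i_split, int64_t i_nhead,
                          int64_t i_batch, int64_t i_row,
                          int64_t split_stride, int64_t nhead_stride,
                          int64_t batch_stride, const int32_t* seqstart_q_ptr)
    {
        const int64_t batch_offset = is_group_mode
                                         ? static_cast<int64_t>(seqstart_q_ptr[i_batch]) // packed rows
                                         : i_batch * batch_stride;                       // strided rows
        return i_split * split_stride + i_nhead * nhead_stride + batch_offset + i_row;
    }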
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
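// note: with a padded head dim the tail of each row may be only partially valid,
// so (we assume) the policy's vectorized alignment cannot be used and access
// falls back to element-wise loads/stores (alignment 1)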
static constexpr index_t kAlignmentQGradAcc =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
// Convert only: dq_acc already holds the final fp32 dQ (no split axis), so a
// single load -> cast -> store pass suffices
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
{
static_assert(
std::is_same_v<AccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradDramTileDistribution<Problem>());
auto dq_acc = load_tile(dq_acc_dram_window);
const auto dq = cast_tile<QGradDataType>(dq_acc);
store_tile(dq_dram_block_window_tmp, dq);
}
// Reduce + Convert: dq_acc carries a leading split axis (deterministic path);
// sum the nsplits partials in fp32 first, then cast once
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
index_t nsplits) const
{
static_assert(
std::is_same_v<AccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
// accumulator tile, zero-initialized
auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
clear_tile(dq_acc);
constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
// pipelined reduction over the split axis: fold in the previously loaded tile,
// then fetch the next one; the final split is folded in after the loop, so no
// tile beyond split nsplits - 1 is ever loaded (a do{}while() here would read
// past the last split and fold in a garbage tile when nsplits == 1)
index_t i_total_loops = 0;
auto dq_acc_buf = load_tile(dq_acc_dram_window);
move_tile_window(dq_acc_dram_window, {1, 0, 0});
while(i_total_loops < (nsplits - 1))
{
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
});
});
});
dq_acc_buf = load_tile(dq_acc_dram_window);
move_tile_window(dq_acc_dram_window, {1, 0, 0});
i_total_loops += 1;
}
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
});
});
});
// cast the reduced fp32 tile to QGradDataType
constexpr auto dq_converted_dstr =
Policy::template MakePostQGradAccDramTileDistribution<Problem>();
auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
});
});
});
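// re-wrap the converted elements with the store-side tile distribution; the two
// distributions are assumed to order each thread's elements identically, so a
// raw thread-buffer copy is sufficient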
constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
dq.get_thread_buffer() = dq_converted.get_thread_buffer();
store_tile(dq_dram_block_window_tmp, dq);
}
};
} // namespace ck_tile
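Functionally, the reduce-and-convert path above sums the per-split fp32 dQ partials written by the deterministic backward pass and performs a single cast at the end. A scalar reference of that contract, with illustrative types and a flat row-major layout (not the tiled implementation above):

    #include <cstddef>

    using AccT   = float; // stands in for AccDataType
    using QGradT = float; // stands in for QGradDataType (fp16/bf16 in practice)

    // dq[m][d] = convert(sum over splits s of dq_acc[s][m][d])
    void convert_dq(const AccT* dq_acc, QGradT* dq,
                    std::size_t nsplits, std::size_t seqlen_q, std::size_t hdim)
    {
        for(std::size_t m = 0; m < seqlen_q; ++m)
            for(std::size_t d = 0; d < hdim; ++d)
            {
                AccT sum = 0;
                for(std::size_t s = 0; s < nsplits; ++s) // reduce split axis in fp32
                    sum += dq_acc[(s * seqlen_q + m) * hdim + d];
                dq[m * hdim + d] = static_cast<QGradT>(sum); // single final cast
            }
    }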