Commit 72c9f129 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 241c261f ded0d83d
......@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
......@@ -180,6 +179,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr,
true,
......@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
......@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
......@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0)
{
......@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
......@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
<< std::flush;
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
std::size_t workspace_size =
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
if(deterministic == 1)
{
std::cout << "\nDeterministic mode ON: " << workspace_size
<< " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v,
......@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f};
p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
......@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
......@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
......@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
......@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
stride_dbias,
......@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
......@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
s_randval,
{drop_seed, drop_offset}};
}();
......@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on
q_host_refs.push_back(q_host_ref);
......@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
......
......@@ -77,6 +77,7 @@ struct fmha_bwd_args
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
......@@ -97,6 +98,8 @@ struct fmha_bwd_args
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
......@@ -108,6 +111,10 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
......@@ -117,15 +124,17 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
......@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
......@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
......@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
......@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
......@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
......@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
......@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
......@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
......@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
......@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
......@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script
struct fmha_bwd_traits
{
......@@ -354,6 +436,8 @@ struct fmha_bwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
......@@ -4,6 +4,7 @@
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include "utils.hpp"
#include <array>
......@@ -16,6 +17,10 @@
#include <utility>
#include <vector>
#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()"
#endif
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
......@@ -50,7 +55,11 @@ auto create_args(int argc, char* argv[])
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s")
.insert("s_knew",
"0",
"seqlen_k for new key/value, 0 means not to use this at all; "
"-1 to choose s_knew in [1, s] randomly.")
.insert("s_kpad",
"-1",
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
......@@ -114,9 +123,14 @@ auto create_args(int argc, char* argv[])
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert(
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
.insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
.insert("num_splits",
"1",
"# of splits for key/value. 0 to determine actual number by heuristic")
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
.insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
......@@ -244,20 +258,6 @@ int override_num_splits_if_necessary(
return num_splits;
}
float fmha_fwd_dispatch(fmha_fwd_traits traits,
fmha_fwd_args args,
const ck_tile::stream_config& config)
{
if(1 < args.num_splits)
{
return fmha_fwd_splitkv(traits, args, config);
}
else
{
return fmha_fwd(traits, args, config);
}
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
......@@ -276,11 +276,114 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false;
}
auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
if(*seed == 0)
{
seed.reset();
}
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
#if !CK_TILE_FMHA_FWD_APPENDKV_API
if(seqlen_knew != 0)
{
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
seqlen_knew = 0;
}
#endif
if(seqlen_knew < 0)
{
seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
}
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
std::is_same_v<DataType, ck_tile::bf16_t>))
{
if(0 < rotary_dim)
{
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
return false;
}
}
#if !CK_TILE_FMHA_FWD_APPENDKV_API
else if(0 < rotary_dim)
{
std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
<< std::endl;
rotary_dim = 0;
}
#endif
if(!(rotary_dim <= hdim_q))
{
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
return false;
}
else if(!(rotary_dim % 16 == 0))
{
std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
return false;
}
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size)
{
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
<< std::endl;
page_block_size = 0;
}
#endif
if(!(page_block_size % 128 == 0))
{
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
<< std::endl;
return false;
}
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
if(use_cache_batch_idx)
{
std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
<< std::endl;
use_cache_batch_idx = false;
}
#endif
if(0 < page_block_size && use_cache_batch_idx)
{
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
"'cache_batch_idx' option"
<< std::endl;
use_cache_batch_idx = false;
}
// the input tensor layout for kvcache is same as batch mode
const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
if(use_kvcache && mode != mode_enum::batch)
{
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
mode = mode_enum::batch;
}
auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
decode_seqlen(mode,
batch,
arg_parser.get_str("s"),
arg_parser.get_str("s_k"),
arg_parser.get_str("s_kpad"));
arg_parser.get_str("s_kpad"),
/*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
use_kvcache);
// compute kvcache seqlen_k (before appending knew/vnew)
auto cache_seqlen_ks = seqlen_ks;
std::transform(cache_seqlen_ks.begin(),
cache_seqlen_ks.end(),
cache_seqlen_ks.begin(),
[&](auto seqlen_k) { return seqlen_k - seqlen_knew; });
#if 0
// clang-format off
......@@ -290,11 +393,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
#endif
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
......@@ -357,13 +455,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::string init_method = arg_parser.get_str("init");
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
if(*seed == 0)
const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
#if !CK_TILE_FMHA_FWD_SPLITKV_API
if(num_splits != 1)
{
seed.reset();
std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
num_splits = 1;
}
int num_splits = arg_parser.get_int("num_splits");
#endif
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
......@@ -425,6 +527,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
const ck_tile::index_t max_num_page_blocks =
(0 < page_block_size
? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
: 0);
// legalize num_splits according to other options
if(num_splits < 1)
{
......@@ -436,6 +543,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
return false;
}
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < p_drop && (1 < num_splits || use_kvcache))
{
std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
<< std::endl;
p_drop = 0.0f;
}
#endif
auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/,
......@@ -462,11 +577,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host(
0 < seqlen_knew
? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<VDataType> v_host(
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
0 < page_block_size
? (is_v_rowmajor
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
ck_tile::HostTensor<VDataType> vnew_host(
0 < seqlen_knew
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
: get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<BiasDataType> bias_host(
bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
......@@ -478,17 +608,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits
1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// self define lse data layout as [batch, nhead, max_seqlen_q]
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<ODataType> o_host(
......@@ -498,39 +633,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<int32_t> block_table_host(
0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
? std::array<ck_tile::index_t, 1>{batch}
: std::array<ck_tile::index_t, 1>{1});
if(init_method == "ui" || init_method == "0")
{
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
}
else if(init_method == "uf" || init_method == "1")
{
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
}
else if(init_method == "nf")
{
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
}
else if(init_method == "tf" || init_method == "2")
{
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<KDataType>{}(knew_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
}
else if(init_method == "ufq" || init_method == "uf:q" ||
......@@ -538,7 +691,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
// bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
......@@ -548,7 +703,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
assert(slopes.size() == nhead);
assert(slopes.size() == static_cast<std::size_t>(nhead));
if(bias.rank_info == 0)
{
// alibi in 1*h
......@@ -563,10 +718,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
......@@ -574,27 +733,41 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
knew_buf.ToDevice(knew_host.data());
v_buf.ToDevice(v_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data());
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data());
block_table_buf.ToDevice(block_table_host.data());
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
// clang-format off
auto layout_str = [&](bool permute){
if (permute) return std::string("bhsd");
if(permute) return std::string("bhsd");
else return std::string("bshd");
};
auto io_layout = [&](bool iperm_, bool operm_) {
if (iperm_ == operm_) return layout_str(iperm_);
if(iperm_ == operm_) return layout_str(iperm_);
else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
};
// clang-format on
......@@ -607,39 +780,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout;
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(0 < rotary_dim)
{
std::cout << ", rotary_dim:" << rotary_dim << "("
<< (is_rotary_interleaved ? "inter" : "half") << ")";
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits)
{
std::cout << ", num_splits:" << num_splits;
}
if(0 < page_block_size)
{
std::cout << ", page_block_size:" << page_block_size;
}
if(use_cache_batch_idx)
{
std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
}
#endif
std::cout << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
bias.type,
lse,
p_drop > 0.0f,
squant};
const auto init_traits = [&](auto& traits) {
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_v_rowmajor = is_v_rowmajor;
auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::scales{scale_p};
else
return ck_tile::identity{};
}();
if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
{
traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
: rope_enum::half_rotated)
: rope_enum::none);
}
else // fmha_fwd_traits or fmha_splitkv_traits
{
traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
traits.do_fp8_static_quant = squant;
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else
return ck_tile::identity{};
}();
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
{
traits.has_dropout = (p_drop > 0.0f);
}
}
};
auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
......@@ -647,11 +838,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
......@@ -659,103 +858,220 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k =
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? seqlen_knew * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_k =
(0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
: (nhead_k * shape_seqlen_k * hdim_q));
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v =
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(),
lse_acc_buf.GetDeviceBuffer(),
o_acc_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
num_splits,
scale_s,
scale_p,
scale_o,
stride_q,
stride_k,
stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_randval,
stride_o_acc,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_randval,
batch_stride_lse,
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o,
split_stride_lse_acc,
split_stride_o_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_drop,
s_randval,
{drop_seed, drop_offset}};
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = shape_seqlen_q; // unused in group mode
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
{
args.knew_ptr = knew_buf.GetDeviceBuffer();
args.vnew_ptr = vnew_buf.GetDeviceBuffer();
args.seqlen_knew = seqlen_knew;
args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr);
args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr);
args.rotary_dim = rotary_dim;
args.has_mask = (mask.type != mask_enum::no_mask);
args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.stride_knew = stride_knew;
args.stride_vnew = stride_vnew;
args.nhead_stride_knew = nhead_stride_knew;
args.nhead_stride_vnew = nhead_stride_vnew;
args.batch_stride_knew = batch_stride_knew;
args.batch_stride_vnew = batch_stride_vnew;
}
else // fmha_fwd_args or fmha_fwd_splitkv_args
{
args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer();
args.lse_ptr = lse_buf.GetDeviceBuffer();
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqstart_q_ptr =
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.scale_p = scale_p;
args.scale_o = scale_o;
args.stride_bias =
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
args.stride_o = stride_o;
args.nhead_stride_bias = nhead_stride_bias;
args.nhead_stride_lse = nhead_stride_lse;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_bias = batch_stride_bias;
args.batch_stride_lse = batch_stride_lse;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.stride_randval = stride_randval;
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.p_drop = p_drop;
args.s_randval = s_randval;
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.num_splits = num_splits;
args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
}
}
};
const float appendkv_ave_time = [&] {
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(need_append_kvcache)
{
fmha_fwd_appendkv_traits fwd_appendkv_traits;
init_traits(fwd_appendkv_traits);
fmha_fwd_appendkv_args fwd_appendkv_args;
init_args(fwd_appendkv_args);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
}
#endif
return 0.0f;
}();
float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || use_kvcache)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);
fmha_fwd_splitkv_args fmha_splitkv_args;
init_args(fmha_splitkv_args);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
}
#endif
fmha_fwd_traits fmha_traits;
init_traits(fmha_traits);
fmha_fwd_args fmha_args;
init_args(fmha_args);
return fmha_fwd(fmha_traits, fmha_args, stream_config);
}();
if(ave_time < 0)
if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;
return false;
}
const float ave_time = (appendkv_ave_time + fwd_ave_time);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
......@@ -773,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::scales{scale_p};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else
return ck_tile::identity{};
}();
float p_undrop = 1.0 - p_drop;
uint8_t p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
float rp_undrop = 1.0 / p_undrop;
bool pass = true;
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t cache_b_idx =
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset =
(mode == mode_enum::batch
? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
const auto v_host_ref_lengths =
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
const auto v_host_ref_strides =
is_v_rowmajor
? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
......@@ -813,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format off
// permute
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
#if CK_TILE_FMHA_FWD_APPENDKV_API
// optionally apply RoPE to the q_host_ref
if(0 < rotary_dim)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
ck_tile::reference_batched_rotary_position_embedding(
q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro,
/*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask);
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size) {
if(i_perm) {
k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
});
} else {
k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
});
}
} else
#endif
{
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Knew to the end of K
if(0 < seqlen_knew)
{
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
// optionally apply RoPE to the knew_host_ref
auto* real_knew_host_ref = &knew_host_ref;
std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
if(0 < rotary_dim)
{
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
if (is_v_rowmajor) {
ck_tile::reference_batched_rotary_position_embedding(
knew_host_ref,
rotary_cos_slice,
rotary_sin_slice,
is_rotary_interleaved,
knew_host_ref_ro.value());
real_knew_host_ref = &knew_host_ref_ro.value();
}
(*real_knew_host_ref).ForEach([&](auto& self, auto i) {
k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i);
});
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size) {
if(is_v_rowmajor) {
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
});
} else {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
});
}
}
else
{
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
});
} else {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size);
});
}
}
} else
#endif
{
if(is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
}
else
{
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
}
else {
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Vnew to the end of V
if(0 < seqlen_knew)
{
ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
if(is_v_rowmajor)
{
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
}
else
{
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
}
vnew_host_ref.ForEach([&](auto& self, auto i) {
v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i);
});
}
#endif
// clang-format on
// reference
......@@ -957,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
});
ck_tile::reference_batched_dropout(
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
......@@ -974,8 +1416,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off
// permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
......@@ -996,8 +1438,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(lse)
{
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
});
cur_pass = ck_tile::check_err(lse_host_result,
lse_host_ref,
......
......@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits>
template <typename DataType>
......@@ -93,13 +96,86 @@ struct fmha_fwd_args
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
struct fmha_fwd_splitkv_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr;
void* o_ptr;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
......@@ -109,21 +185,21 @@ struct fmha_fwd_args
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
......@@ -132,19 +208,62 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
struct fmha_fwd_appendkv_args
{
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
const void* vnew_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_knew;
ck_tile::index_t batch;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
ck_tile::index_t rotary_dim;
bool has_mask;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
ck_tile::index_t stride_v;
ck_tile::index_t stride_vnew;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_knew;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
template <typename FmhaKernel>
......@@ -185,7 +304,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse,
args.window_size_left,
args.window_size_right,
args.mask_type,
......@@ -245,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
......@@ -256,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
......@@ -275,25 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.mask_type);
}
else
{ // create batch mode kernel arguments
......@@ -301,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.mask_type);
}
}();
......@@ -353,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
}
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
......@@ -376,9 +486,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
......@@ -414,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.has_mask,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
......@@ -462,8 +615,52 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
......@@ -491,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
};
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kTileSizeS_,
ck_tile::index_t kTileSizeSk_,
ck_tile::index_t kTileSizeD_,
ck_tile::index_t kTileSizeDv_,
bool kIsVLayoutRowMajor_,
bool kPadS_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
......@@ -512,4 +743,32 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
struct fmha_fwd_splitkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
fmha_fwd_splitkv_args,
const ck_tile::stream_config&);
struct fmha_fwd_appendkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_v_rowmajor;
rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
......@@ -5,25 +5,30 @@
import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
import sys
from typing import List, Optional
import codegen.ops
from codegen.cmake_config import *
from codegen.ops import (
fmha_fwd,
fmha_fwd_splitkv,
fmha_bwd
)
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
handlers = {
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
}
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
if full_module_name not in sys.modules:
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None:
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
VALID=0
for prec in "fp16" "bf16" ; do
......
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_fwd
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
VALID=0
for prec in "fp16" "bf16" ; do
......
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_bwd
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
......@@ -11,18 +10,19 @@ COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 ; do
for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "e" "a"; do
for dbias in 0 1 ; do
for p_drop in 0.0 0.2; do
for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
......@@ -31,4 +31,5 @@ done
done
done
done
done
set +x
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_fmha_fwd
#!/bin/bash
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
......@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
set -x
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
done
TEST_SPLITKV=0
TEST_APPENDKV=0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done
run_fp16_bf16_tests() {
local NUM_SPLITS=(1)
local PAGE_BLOCK_SIZE=(0)
local CACHE_BATCH_IDX=(0)
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS+=(2 3)
PAGE_BLOCK_SIZE+=(128)
CACHE_BATCH_IDX+=(1)
fi
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in "${NUM_SPLITS[@]}" ; do
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp16_appendkv_tests() {
for s in $(seq 63 1 65) ; do
for s_k in 65 129 ; do
for s_knew in 0 64 $s_k ; do
for hdim in 32 64 128 256 ; do
for ri in 0 1 ; do
for rdim in 0 16 32 $hdim ; do
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
set -x
run_fp16_bf16_tests
run_fp8_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done
done
done
done
set +x
\ No newline at end of file
......@@ -3,15 +3,17 @@
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <functional>
#include <string>
#include "ck_tile/core/container/span.hpp"
......@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt)
{
assert(0 < count);
std::vector<int32_t> seqlens(
count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(mode == mode_enum::group && 1 < count)
{
......@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{
const size_type to_decrease = next_idx();
// make sure each elements of seqlens is always greater than 0
if(seqlens[to_decrease] == 1)
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
if(seqlens[to_decrease] == seqlen_min)
{
continue;
}
const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
if(seqlens[to_increase] >= seqlen_max)
{
continue;
}
......@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std::vector<int32_t> generate_seqstarts(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1,
int32_t seqlen_max = -1,
std::optional<unsigned> seed = std::nullopt)
{
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}
// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
return dist(engine);
}
// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(engine); });
}
/*
......@@ -112,6 +144,8 @@ decode_seqlen(mode_enum mode,
std::string q_val,
std::string k_val,
std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0,
bool use_kvcache = false,
std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
......@@ -119,9 +153,36 @@ decode_seqlen(mode_enum mode,
{
ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val);
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
auto s_k = [&] {
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && use_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints(std::next(seqlen_ks.begin()),
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
seed);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
else
......@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
idx++;
if(found_q == std::string::npos || idx >= batch)
{
......@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
}
if(idx < batch)
{
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k =
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
......@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
r = std::atoi(v);
return r;
}
template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
std::optional<unsigned> seed = std::nullopt)
{
std::iota(first, last, value);
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::shuffle(first, last, engine);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set stochastic rounding as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 1
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef GUARD_CK_FILESYSTEM_HPP_
#define GUARD_CK_FILESYSTEM_HPP_
#include <string>
#include <string_view>
// clang-format off
#if defined(CPPCHECK)
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define CK_HAS_FILESYSTEM 1
#else
#define CK_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM_TS 0
#endif
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
// clang-format on
#if CK_HAS_FILESYSTEM
#include <filesystem>
#elif CK_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif
namespace CK {
#if CK_HAS_FILESYSTEM
namespace fs = ::std::filesystem;
#elif CK_HAS_FILESYSTEM_TS
namespace fs = ::std::experimental::filesystem;
#endif
} // namespace CK
inline std::string operator+(const std::string_view s, const CK::fs::path& path)
{
return path.string().insert(0, s);
}
inline std::string operator+(const CK::fs::path& path, const std::string_view s)
{
return path.string().append(s);
}
#define FS_ENUM_PERMS_ALL fs::perms::all
#if CK_HAS_FILESYSTEM_TS
#ifdef __linux__
#include <linux/limits.h>
namespace CK {
inline fs::path weakly_canonical(const fs::path& path)
{
std::string result(PATH_MAX, '\0');
std::string p{path.is_relative() ? (fs::current_path() / path).string() : path.string()};
char* retval = realpath(p.c_str(), &result[0]);
return (retval == nullptr) ? path : fs::path{result};
}
} // namespace CK
#else
#error "Not implmeneted!"
#endif
#else
namespace CK {
inline fs::path weakly_canonical(const fs::path& path) { return fs::weakly_canonical(path); }
} // namespace CK
#endif
namespace CK {
#ifdef _WIN32
constexpr std::string_view executable_postfix{".exe"};
constexpr std::string_view library_prefix{""};
constexpr std::string_view dynamic_library_postfix{".dll"};
constexpr std::string_view static_library_postfix{".lib"};
constexpr std::string_view object_file_postfix{".obj"};
#else
constexpr std::string_view executable_postfix{""};
constexpr std::string_view library_prefix{"lib"};
constexpr std::string_view dynamic_library_postfix{".so"};
constexpr std::string_view static_library_postfix{".a"};
constexpr std::string_view object_file_postfix{".o"};
#endif
inline fs::path make_executable_name(const fs::path& path)
{
return path.parent_path() / (path.filename() + executable_postfix);
}
inline fs::path make_dynamic_library_name(const fs::path& path)
{
return path.parent_path() / (library_prefix + path.filename() + dynamic_library_postfix);
}
inline fs::path make_object_file_name(const fs::path& path)
{
return path.parent_path() / (path.filename() + object_file_postfix);
}
inline fs::path make_static_library_name(const fs::path& path)
{
return path.parent_path() / (library_prefix + path.filename() + static_library_postfix);
}
struct FsPathHash
{
std::size_t operator()(const fs::path& path) const { return fs::hash_value(path); }
};
} // namespace CK
#endif // GUARD_CK_FILESYSTEM_HPP_
......@@ -65,6 +65,12 @@ inline bool is_lds_direct_load_supported()
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942";
}
inline bool is_gfx101_supported()
{
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
......
......@@ -53,6 +53,49 @@ struct DeviceGemmMultipleD : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -69,7 +69,7 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
CLayout,
......@@ -192,15 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// rotating mem
rotating_mem.Next();
// clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg_.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......@@ -234,9 +230,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
......@@ -246,11 +252,124 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -261,7 +380,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -273,8 +392,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -288,8 +407,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -303,8 +422,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -318,8 +437,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -332,8 +451,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -347,8 +466,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -361,12 +480,35 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -375,8 +517,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -387,11 +529,35 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
}
else
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -401,7 +567,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -416,9 +582,19 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
......@@ -451,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......@@ -479,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
......@@ -494,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB,
StrideDs,
StrideC,
1,
KBatch,
a_element_op,
b_element_op,
c_element_op};
......@@ -514,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
......@@ -529,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB,
StrideDs,
StrideC,
1,
KBatch,
a_element_op,
b_element_op,
c_element_op);
......
......@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
rotating_mem.Next();
// clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg_.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......@@ -189,15 +185,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
arg_);
}
else
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
......@@ -214,8 +207,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
......@@ -224,7 +215,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
......@@ -239,13 +229,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -255,8 +243,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -266,8 +254,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
......@@ -326,8 +313,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
......@@ -354,7 +340,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
......@@ -472,8 +457,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
......@@ -496,7 +479,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
......@@ -524,13 +506,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
else
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -539,8 +519,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -548,7 +528,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
......@@ -579,10 +558,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
......@@ -591,7 +567,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
......@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......
......@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
return false;
if constexpr(!((NDimSpatial == 1 &&
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
(is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 2 &&
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 3 &&
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
{
return false;
}
......
......@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
}
if constexpr(NDimSpatial == 1)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
......
......@@ -22,6 +22,7 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t NumGroupsToMerge = 1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
index_t TransposeTransferSrcScalarPerVector = 1,
index_t TransposeTransferDstScalarPerVector = 1>
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
......@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using BDataType = InDataType;
using EDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation;
using BElementwiseOperation = InElementwiseOperation;
using CDEElementwiseOperation = WeiElementwiseOperation;
......@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch)[I2];
}
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& HiStride = g_n_c_wis_strides[3];
const index_t& WiStride = g_n_c_wis_strides[4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& DiStride = g_n_c_wis_strides[3];
const index_t& HiStride = g_n_c_wis_strides[4];
const index_t& WiStride = g_n_c_wis_strides[5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
using InputTransposeDescType =
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
using OutputTransposeDescType =
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
......@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ComputeTypeA,
ComputeTypeB>;
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise =
using GridwiseElementwiseCast =
GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
Tuple<CElementwiseGridDesc_M_N>,
Tuple<const AccDataType*>,
......@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1,
I1>;
using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<InputTransposeDescType>,
Tuple<OutputTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<TransposeTransferSrcScalarPerVector>,
Sequence<TransposeTransferDstScalarPerVector>,
I1,
I0>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
b_g_n_c_wis_strides;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
a_g_n_k_wos_strides;
// NGKHW - transpose needed
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
b_g_n_c_wis_strides_transposed[I2] = I1;
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
a_g_n_k_wos_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] =
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
}
else if constexpr(NDimSpatial == 3)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
input_spatial_lengths_[I2] * Conv_G_ *
Conv_K_;
a_g_n_k_wos_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
}
}
const auto descs =
conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
......@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
b_g_n_c_wis_strides,
b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides,
a_g_n_k_wos_strides,
a_g_n_k_wos_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
......@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
......@@ -553,13 +750,58 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ce_grid_desc_m_n_,
GridwiseGemm::CalculateMBlock(GemmM),
GridwiseGemm::CalculateNBlock(GemmN));
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
a_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
b_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
b_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
}
}
std::size_t GetWorkspaceSizeBytes() const
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
else
{
return GetWorkspaceETensorSizeBytes();
}
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
EDataType* p_e_grid_;
......@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_;
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
......@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
p_b_grid =
type_convert<const BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
}
// nullptr for output, will be set after workspace set
typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_,
arg.p_b_grid_,
p_c_grid,
GemmM,
GemmN,
GemmK,
I0,
I0,
I0,
arg.k_batch_};
typename GridwiseGemm::Argument gemm_arg{
p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
......@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(
gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
};
const auto Run = [&](const auto& kernel) {
......@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float avg_time = 0.f;
auto launch_elementwise_kernel = [&]() {
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
......@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
std::array<index_t, I1> in_out_batch_strides = {
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>,
......@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
in_out_batch_strides);
};
float avg_time = RunGemmV3(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t grid_size_b =
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
BDataType* p_b_out_grid =
type_convert<BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size_a + grid_size_b),
dim3(BlockSize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(arg.p_b_grid_),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
grid_size_a);
}
avg_time += RunGemmV3(arg, stream_config);
avg_time += launch_elementwise_kernel();
return avg_time;
}
......@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
return false;
}
if constexpr(NDimSpatial == 1)
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{
return false;
}
......@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
{
return false;
}
if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
{
return false;
}
if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
{
return false;
}
}
return true;
}
......@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge
<< ">";
<< NumGroupsToMerge;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVector: "
<< TransposeTransferSrcScalarPerVector <<", "
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
}
str << ">";
// clang-format on
return str.str();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment