"symphony/vscode:/vscode.git/clone" did not exist on "c6d1fa2dba5018955fab379cd142f6bd837faa81"
Commit 72c9f129 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 241c261f ded0d83d
...@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[]) ...@@ -87,7 +87,11 @@ auto create_args(int argc, char* argv[])
.insert("drop_offset", "0", "offset for random number generator") .insert("drop_offset", "0", "offset for random number generator")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel"); .insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
...@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -128,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0) if(hdim_v < 0)
hdim_v = hdim_q; hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
...@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -177,9 +176,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset(); seed.reset();
} }
int stream_warmup = arg_parser.get_int("warmup"); int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat"); int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname"); bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr, ck_tile::stream_config stream_config{nullptr,
true, true,
...@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -265,6 +265,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k = const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host( ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
...@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -284,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}); std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host( ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}); std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host( ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
...@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -302,6 +305,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0) if(init_method == 0)
{ {
...@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -362,6 +369,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data()); k_buf.ToDevice(k_host.data());
...@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -387,8 +395,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< std::flush; << ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
std::size_t workspace_size =
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
if(deterministic == 1)
{
std::cout << "\nDeterministic mode ON: " << workspace_size
<< " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q, auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v, hdim_v,
...@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -397,7 +414,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type, mask.type,
bias.type, bias.type,
use_dbias, use_dbias,
p_drop > 0.0f}; p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() { auto fmha_args = [&]() {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
...@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -422,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias = const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments // setup batch_stride_* arguments
...@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -433,10 +452,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
return fmha_bwd_args{q_buf.GetDeviceBuffer(), return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
...@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -452,6 +473,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
dk_buf.GetDeviceBuffer(), dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(), dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(),
nullptr, nullptr,
...@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -473,6 +495,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o, stride_o,
stride_randval, stride_randval,
stride_do, stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk, stride_dk,
stride_dv, stride_dv,
stride_dbias, stride_dbias,
...@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -484,6 +508,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval, nhead_stride_randval,
nhead_stride_do, nhead_stride_do,
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias, nhead_stride_dbias,
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
...@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -493,15 +521,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval, batch_stride_randval,
batch_stride_do, batch_stride_do,
batch_stride_lsed, batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk, batch_stride_dk,
batch_stride_dv, batch_stride_dv,
batch_stride_dbias, batch_stride_dbias,
split_stride_dq_acc,
mask.left, mask.left,
mask.right, mask.right,
static_cast<ck_tile::index_t>(mask.type), static_cast<ck_tile::index_t>(mask.type),
p_drop, p_drop,
p_undrop, p_undrop,
s_randval,
{drop_seed, drop_offset}}; {drop_seed, drop_offset}};
}(); }();
...@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -719,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on // clang-format on
q_host_refs.push_back(q_host_ref); q_host_refs.push_back(q_host_ref);
...@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -738,6 +768,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
lse_buf.ToDevice(lse_host.data()); lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero(); dq_buf.SetZero();
dbias_buf.SetZero(); dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{ ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
......
...@@ -77,6 +77,7 @@ struct fmha_bwd_args ...@@ -77,6 +77,7 @@ struct fmha_bwd_args
void* dk_ptr; void* dk_ptr;
void* dv_ptr; void* dv_ptr;
void* dbias_ptr; void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
const void* seqstart_k_ptr; const void* seqstart_k_ptr;
const void* seqlen_k_ptr; const void* seqlen_k_ptr;
...@@ -97,6 +98,8 @@ struct fmha_bwd_args ...@@ -97,6 +98,8 @@ struct fmha_bwd_args
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t stride_randval; ck_tile::index_t stride_randval;
ck_tile::index_t stride_do; ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk; ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv; ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias; ck_tile::index_t stride_dbias;
...@@ -108,6 +111,10 @@ struct fmha_bwd_args ...@@ -108,6 +111,10 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias; ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
...@@ -117,15 +124,17 @@ struct fmha_bwd_args ...@@ -117,15 +124,17 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed; ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias; ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop; float p_drop;
float p_undrop; float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset; std::tuple<uint64_t, uint64_t> drop_seed_offset;
}; };
...@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -145,10 +154,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
...@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -163,6 +172,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
...@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -173,13 +183,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_lsed, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
...@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -192,10 +204,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
...@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -209,6 +221,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
...@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -219,6 +232,9 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
...@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -227,14 +243,15 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval, args.batch_stride_randval,
args.batch_stride_do, args.batch_stride_do,
args.batch_stride_lsed, args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk, args.batch_stride_dk,
args.batch_stride_dv, args.batch_stride_dv,
args.batch_stride_dbias, args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
...@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) ...@@ -260,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o, args.stride_o,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_o, args.nhead_stride_o,
args.nhead_stride_lsed, args.nhead_stride_lsed);
args.batch_stride_lsed);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
...@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) ...@@ -286,19 +302,59 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel // this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
bool kIsGroupMode_, bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_, ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_, bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_, bool kPadS_,
bool kPadSK_, bool kPadSK_,
bool kPadD_, bool kPadD_,
bool kPadDv_> bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_ struct fmha_bwd_dq_dk_dv_traits_
{ {
static constexpr ck_tile::index_t HDim = HDim_; static constexpr ck_tile::index_t HDim = HDim_;
...@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_ ...@@ -306,13 +362,14 @@ struct fmha_bwd_dq_dk_dv_traits_
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_; static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_; static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
}; };
template <typename Traits_> template <typename Traits_>
...@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); ...@@ -343,6 +400,31 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_(); std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_bwd_traits struct fmha_bwd_traits
{ {
...@@ -354,6 +436,8 @@ struct fmha_bwd_traits ...@@ -354,6 +436,8 @@ struct fmha_bwd_traits
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias; bool has_dbias;
bool has_dropout; bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
This diff is collapsed.
This diff is collapsed.
...@@ -5,25 +5,30 @@ ...@@ -5,25 +5,30 @@
import argparse import argparse
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
import pkgutil
import sys
from typing import List, Optional from typing import List, Optional
import codegen.ops
from codegen.cmake_config import * from codegen.cmake_config import *
from codegen.ops import (
fmha_fwd,
fmha_fwd_splitkv,
fmha_bwd
)
class HandlerId(IntEnum): class HandlerId(IntEnum):
LIST_BLOBS = 0 LIST_BLOBS = 0
WRITE_BLOBS = 1 WRITE_BLOBS = 1
handlers = { # inspect all modules under 'codegen.ops' and register API handlers
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), ops = []
'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
} if full_module_name not in sys.modules:
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None: if output_dir is None:
...@@ -103,4 +108,4 @@ if __name__ == "__main__": ...@@ -103,4 +108,4 @@ if __name__ == "__main__":
if args.list_blobs is not None: if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
else: else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_bwd
VALID=0 VALID=0
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
......
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_fwd
VALID=0 VALID=0
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
......
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_bwd
KNAME=1 KNAME=1
export CK_WARMUP=0 export CK_WARMUP=0
...@@ -11,18 +10,19 @@ COMMON_ARGS='-v=1' ...@@ -11,18 +10,19 @@ COMMON_ARGS='-v=1'
set -x set -x
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 32 64 128 ; do for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do for mode in 0 1 ; do
for bias in "n" "e" "a"; do for bias in "n" "a" ; do
for dbias in 0 1 ; do for dbias in 0 ; do
for p_drop in 0.0 0.2; do for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done done
done done
...@@ -31,4 +31,5 @@ done ...@@ -31,4 +31,5 @@ done
done done
done done
done done
done
set +x set +x
#!/bin/sh #!/bin/bash
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_fwd
KNAME=1 KNAME=1
export CK_WARMUP=0 export CK_WARMUP=0
...@@ -10,44 +9,98 @@ export CK_REPEAT=1 ...@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1' COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0 # mode=0
# export HIP_VISIBLE_DEVICES=4 # export HIP_VISIBLE_DEVICES=4
set -x
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done TEST_SPLITKV=0
done TEST_APPENDKV=0
done # options:
done # -s: run splitkv tests
done # -a: run appendkv tests
done while getopts ":sa" opt; do
done case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done done
run_fp16_bf16_tests() {
local NUM_SPLITS=(1)
local PAGE_BLOCK_SIZE=(0)
local CACHE_BATCH_IDX=(0)
for perm in 0 1 ; do if [ $TEST_SPLITKV -eq 1 ] ; then
for bias in "n" "e" "a" ; do NUM_SPLITS+=(2 3)
for b in 1 2 ; do PAGE_BLOCK_SIZE+=(128)
for hdim in 64 128 256 ; do CACHE_BATCH_IDX+=(1)
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS fi
done
done for prec in "fp16" "bf16" ; do
done for mode in 1 0 ; do
done for perm in 0 1 ; do
set +x for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in "${NUM_SPLITS[@]}" ; do
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp16_appendkv_tests() {
for s in $(seq 63 1 65) ; do
for s_k in 65 129 ; do
for s_knew in 0 64 $s_k ; do
for hdim in 32 64 128 256 ; do
for ri in 0 1 ; do
for rdim in 0 16 32 $hdim ; do
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
set -x
run_fp16_bf16_tests
run_fp8_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
set +x
\ No newline at end of file
This diff is collapsed.
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly // LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set stochastic rounding as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 1 #define CK_USE_SR_F8_CONVERSION 0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef GUARD_CK_FILESYSTEM_HPP_
#define GUARD_CK_FILESYSTEM_HPP_
#include <string>
#include <string_view>
// clang-format off
#if defined(CPPCHECK)
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define CK_HAS_FILESYSTEM 1
#define CK_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L
#define CK_HAS_FILESYSTEM 1
#else
#define CK_HAS_FILESYSTEM 0
#endif
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
#define CK_HAS_FILESYSTEM_TS 1
#else
#define CK_HAS_FILESYSTEM_TS 0
#endif
#else
#define CK_HAS_FILESYSTEM 0
#define CK_HAS_FILESYSTEM_TS 0
#endif
// clang-format on
#if CK_HAS_FILESYSTEM
#include <filesystem>
#elif CK_HAS_FILESYSTEM_TS
#include <experimental/filesystem>
#else
#error "No filesystem include available"
#endif
namespace CK {
#if CK_HAS_FILESYSTEM
namespace fs = ::std::filesystem;
#elif CK_HAS_FILESYSTEM_TS
namespace fs = ::std::experimental::filesystem;
#endif
} // namespace CK
inline std::string operator+(const std::string_view s, const CK::fs::path& path)
{
return path.string().insert(0, s);
}
inline std::string operator+(const CK::fs::path& path, const std::string_view s)
{
return path.string().append(s);
}
#define FS_ENUM_PERMS_ALL fs::perms::all
#if CK_HAS_FILESYSTEM_TS
#ifdef __linux__
#include <linux/limits.h>
namespace CK {
inline fs::path weakly_canonical(const fs::path& path)
{
std::string result(PATH_MAX, '\0');
std::string p{path.is_relative() ? (fs::current_path() / path).string() : path.string()};
char* retval = realpath(p.c_str(), &result[0]);
return (retval == nullptr) ? path : fs::path{result};
}
} // namespace CK
#else
#error "Not implmeneted!"
#endif
#else
namespace CK {
inline fs::path weakly_canonical(const fs::path& path) { return fs::weakly_canonical(path); }
} // namespace CK
#endif
namespace CK {
#ifdef _WIN32
constexpr std::string_view executable_postfix{".exe"};
constexpr std::string_view library_prefix{""};
constexpr std::string_view dynamic_library_postfix{".dll"};
constexpr std::string_view static_library_postfix{".lib"};
constexpr std::string_view object_file_postfix{".obj"};
#else
constexpr std::string_view executable_postfix{""};
constexpr std::string_view library_prefix{"lib"};
constexpr std::string_view dynamic_library_postfix{".so"};
constexpr std::string_view static_library_postfix{".a"};
constexpr std::string_view object_file_postfix{".o"};
#endif
inline fs::path make_executable_name(const fs::path& path)
{
return path.parent_path() / (path.filename() + executable_postfix);
}
inline fs::path make_dynamic_library_name(const fs::path& path)
{
return path.parent_path() / (library_prefix + path.filename() + dynamic_library_postfix);
}
inline fs::path make_object_file_name(const fs::path& path)
{
return path.parent_path() / (path.filename() + object_file_postfix);
}
inline fs::path make_static_library_name(const fs::path& path)
{
return path.parent_path() / (library_prefix + path.filename() + static_library_postfix);
}
struct FsPathHash
{
std::size_t operator()(const fs::path& path) const { return fs::hash_value(path); }
};
} // namespace CK
#endif // GUARD_CK_FILESYSTEM_HPP_
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment