Commit 3dc5db72 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents b924e330 e547c141
@@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
     # TODO: use async pipeline when compiler is more stable
     if hdim == 256 or hdim in [32, 64, 128]:
     # if True:
-        pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
-        pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+        pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
+        pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
         pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
         pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
...
@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
         .insert("p_drop", "0", "0~1 probability of dropout")
         .insert("drop_seed", "1", "seed for random number generator")
         .insert("drop_offset", "0", "offset for random number generator")
+        .insert("drop_prefs",
+                "0",
+                "seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel")
@@ -99,10 +102,23 @@ auto create_args(int argc, char* argv[])
 // different threshold for different dtype
 template <typename DataType>
-auto get_elimit(int /*init_method*/)
+auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
+{
+    double rtol = 1e-2;
+    double atol = 1e-2;
+    return ck_tile::make_tuple(rtol, atol);
+}
+
+template <>
+auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
 {
     double rtol = 1e-2;
     double atol = 1e-2;
+    if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
+    {
+        rtol = 3.2e-2;
+        atol = 3.2e-2;
+    }
     return ck_tile::make_tuple(rtol, atol);
 }
@@ -145,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     float p_drop = arg_parser.get_float("p_drop");
     uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
     uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
+    bool drop_prefs = arg_parser.get_bool("drop_prefs");
+
     if(use_dbias && bias.type != bias_enum::elementwise_bias)
     {
         std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
@@ -368,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
+    ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
+    ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
@@ -378,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     do_buf.ToDevice(do_host.data());
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqstart_k_host.data());
+    drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
+    drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
     alibi_slope_buf.ToDevice(alibi_slope_host.data());

     // clang-format off
@@ -459,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t split_stride_dq_acc =
             (shape_batch * nhead * shape_seqlen_q * hdim_q);
+
+        const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
+            if(drop_prefs)
+            {
+                return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
+                                      drop_offset_buf.GetDeviceBuffer());
+            }
+            else
+            {
+                return std::make_pair(drop_seed, drop_offset);
+            }
+        }();
+
         return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                              k_buf.GetDeviceBuffer(),
                              v_buf.GetDeviceBuffer(),
@@ -532,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              static_cast<ck_tile::index_t>(mask.type),
                              p_drop,
                              p_undrop,
-                             {drop_seed, drop_offset}};
+                             drop_seed_offset};
     }();

     float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
@@ -899,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
     // clang-format on

-    auto [rtol, atol] = get_elimit<DataType>(init_method);
+    auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
     bool dq_cur_pass = ck_tile::check_err(dq_host_result,
                                           dq_host_ref,
                                           std::string("Error: QGrad Incorrect results!"),
...
@@ -9,7 +9,10 @@
 #include "ck_tile/ops/epilogue.hpp"
 #include "mask.hpp"
 #include "bias.hpp"

 #include <type_traits>
+#include <utility>
+#include <variant>

 template <typename DataType>
 struct FmhaBwdTypeConfig;
@@ -135,7 +138,8 @@ struct fmha_bwd_args
     ck_tile::index_t mask_type;
     float p_drop;
     float p_undrop;
-    std::tuple<uint64_t, uint64_t> drop_seed_offset;
+    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+        drop_seed_offset;
 };

 template <typename FmhaBwdDQDKDVKernel>
...
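Since `fmha_bwd_args::drop_seed_offset` is now a `std::variant`, a caller supplies either immediate host values or pointers into device memory, and downstream code has to branch on the active alternative. A minimal sketch of that dispatch, with illustrative helper names rather than the actual CK API:

```cpp
#include <cstdint>
#include <utility>
#include <variant>

// Same shape as the new drop_seed_offset member.
using DropSeedOffset =
    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>;

// Illustrative helpers (not the actual CK API): decide whether the RNG seed and
// offset are immediate host values or must be read from device memory.
inline bool seed_offset_on_device(const DropSeedOffset& v)
{
    return std::holds_alternative<std::pair<const void*, const void*>>(v);
}

inline std::pair<uint64_t, uint64_t> host_seed_offset(const DropSeedOffset& v)
{
    // Only valid when the host-value alternative is active.
    return std::get<std::pair<uint64_t, uint64_t>>(v);
}
```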
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
         .insert("p_drop", "0", "0~1 probability of dropout")
         .insert("drop_seed", "1", "seed for random number generator")
         .insert("drop_offset", "0", "offset for random number generator")
+        .insert("drop_prefs",
+                "0",
+                "seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert(
             "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     float p_drop = arg_parser.get_float("p_drop");
     uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
     uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
+    bool drop_prefs = arg_parser.get_bool("drop_prefs");
+
     if(p_drop < 0.0f || p_drop > 1.0f)
     {
         std::cerr << "The value of p_drop should be 0~1" << std::endl;
@@ -552,16 +557,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
 #endif
-    auto get_lengths = [&](bool permute,
-                           ck_tile::index_t b /*batch*/,
-                           ck_tile::index_t h /*nhead*/,
-                           ck_tile::index_t s /*seqlen*/,
-                           ck_tile::index_t d /*hdim*/) {
-        if(permute)
-            return std::array<ck_tile::index_t, 4>{b, h, s, d};
-        else
-            return std::array<ck_tile::index_t, 4>{b, s, h, d};
-    };
+    struct
+    {
+        auto operator()(bool permute,
+                        ck_tile::index_t b /*batch*/,
+                        ck_tile::index_t h /*nhead*/,
+                        ck_tile::index_t s /*seqlen*/,
+                        ck_tile::index_t d /*hdim*/)
+        {
+            if(permute)
+                return std::array<ck_tile::index_t, 4>{b, h, s, d};
+            else
+                return std::array<ck_tile::index_t, 4>{b, s, h, d};
+        }
+
+        auto operator()(bool permute,
+                        ck_tile::index_t ns /*num_splits*/,
+                        ck_tile::index_t b /*batch*/,
+                        ck_tile::index_t h /*nhead*/,
+                        ck_tile::index_t s /*seqlen*/,
+                        ck_tile::index_t d /*hdim*/)
+        {
+            if(permute)
+                return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
+            else
+                return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
+        }
+    } get_lengths;

     bool is_v_rowmajor = vlayout == std::string("r");
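The lambda had to become an unnamed local struct because a single lambda cannot be overloaded; the two `operator()` overloads let the same `get_lengths` name produce both the 4-D output shape and the 5-D, num_splits-leading accumulator shape. A standalone sketch of the pattern:

```cpp
#include <array>
#include <iostream>

int main()
{
    // An unnamed local struct can carry several operator() overloads,
    // which a single lambda cannot.
    struct
    {
        auto operator()(int b, int s) { return std::array<int, 2>{b, s}; }
        auto operator()(int ns, int b, int s) { return std::array<int, 3>{ns, b, s}; }
    } get_lengths;

    auto l2 = get_lengths(2, 128);    // picks the 2-argument overload
    auto l3 = get_lengths(8, 2, 128); // picks the 3-argument overload
    std::cout << l2.size() << ' ' << l3.size() << '\n'; // prints "2 3"
}
```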
@@ -617,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                   : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
     ck_tile::HostTensor<OaccDataType> o_acc_host(
         1 < num_splits || use_kvcache
-            ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
+            ? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v)
             : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});

     // batch mode of lse data layout is [batch, nhead, seqlen_q]
...
@@ -739,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
     ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
+    ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
     ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
@@ -757,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
     rotary_cos_buf.ToDevice(rotary_cos_host.data());
     rotary_sin_buf.ToDevice(rotary_sin_host.data());
+    drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
+    drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
     alibi_slope_buf.ToDevice(alibi_slope_host.data());
     block_table_buf.ToDevice(block_table_host.data());
     cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
@@ -854,7 +880,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }();
     const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
     const ck_tile::index_t stride_randval = (max_seqlen_k);
-    const ck_tile::index_t stride_o_acc = hdim_v;
+    const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v);
     const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
     // setup nhead_stride_* arguments
     const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
@@ -881,7 +907,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
     const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
     const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
-    const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
+    const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
     const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
     // setup batch_stride_* arguments
     const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
@@ -897,12 +923,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
     const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
     const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
     const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
-    const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
+    const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v);
     const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
     const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
     // setup split_stride_* arguments (only used in split-kv kernel)
     const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
-    const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
+    const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v);

     args.q_ptr = q_buf.GetDeviceBuffer();
     args.k_ptr = k_buf.GetDeviceBuffer();
@@ -996,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
         args.nhead_stride_randval = nhead_stride_randval;
         args.batch_stride_randval = batch_stride_randval;
         args.p_drop = p_drop;
         args.s_randval = s_randval;
-        args.drop_seed_offset = std::tie(drop_seed, drop_offset);
+        if(drop_prefs)
+        {
+            args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
+                                                   drop_offset_buf.GetDeviceBuffer());
+        }
+        else
+        {
+            args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
+        }
     }
     else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
     {
...
@@ -13,6 +13,8 @@
 #include "rotary.hpp"

 #include <type_traits>
+#include <utility>
+#include <variant>

 template <typename DataType>
 struct FmhaFwdTypeConfig;
@@ -144,7 +146,9 @@ struct fmha_fwd_args
     float p_drop;
     bool s_randval;
-    std::tuple<uint64_t, uint64_t> drop_seed_offset;
+
+    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+        drop_seed_offset;
 };

 struct fmha_fwd_splitkv_args
@@ -398,10 +402,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                                              args.nhead_stride_bias,
                                              args.nhead_stride_lse_acc,
                                              args.nhead_stride_o_acc,
-                                             args.batch_stride_k,
-                                             args.batch_stride_v,
-                                             args.batch_stride_lse_acc,
-                                             args.batch_stride_o_acc,
+                                             args.batch_stride_k, // only used for paged-kvcache
+                                             args.batch_stride_v, // only used for paged-kvcache
                                              args.split_stride_lse_acc,
                                              args.split_stride_o_acc,
                                              args.window_size_left,
@@ -475,7 +477,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                                              args.lse_ptr,
                                              args.o_ptr,
                                              args.batch,
-                                             args.max_seqlen_q,
                                              args.seqstart_q_ptr,
                                              args.hdim_v,
                                              args.num_splits,
@@ -486,7 +487,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                                              args.nhead_stride_o_acc,
                                              args.nhead_stride_lse,
                                              args.nhead_stride_o,
-                                             args.batch_stride_o_acc,
                                              args.split_stride_lse_acc,
                                              args.split_stride_o_acc);
 }
@@ -497,7 +497,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                                              args.lse_ptr,
                                              args.o_ptr,
                                              args.batch,
-                                             args.max_seqlen_q,
                                              args.seqlen_q,
                                              args.hdim_v,
                                              args.num_splits,
...
@@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programming
 ```
 # in the root of ck_tile
 mkdir build && cd build
-sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
+# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
+sh ../script/cmake-ck-dev.sh ../ <arch>
 make tile_example_layernorm2d_fwd -j
 ```
 This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
                                                            YDataType,
                                                            MeanDataType,
                                                            InvStdDataType,
-                                                           Shape>;
+                                                           Shape,
+                                                           true,
+                                                           true>;
     using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
...
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementation
 ```
 # in the root of ck_tile
 mkdir build && cd build
-sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
+# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
+sh ../script/cmake-ck-dev.sh ../ <arch>
 make tile_example_gemm_basic -j
 ```
 This will result in an executable `build/bin/tile_example_gemm_basic`
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
 ## example
 ```
 args:
-          -m        m dimension (default:3328)
-          -n        m dimension (default:4096)
+          -b        batch size (default:1)
+          -m        m dimension (default:1024)
+          -n        n dimension (default:2048)
           -k        k dimension (default:64)
-          -e        epsilon (default:1e-5)
-          -v        cpu validation or not (default:1)
-          -prec     precision (default:fp16)
+          -stride_a Tensor A stride (default:0)
+          -stride_b Tensor B stride (default:0)
+          -stride_c Tensor C stride (default:0)
+          -v        0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
+          -e        Absolute error tolerance (default:1e-5)
+          -prec     data type. fp16/bf16/fp8/bf8 (default:fp16)
+          -warmup   number of iterations before benchmark the kernel (default:10)
+          -repeat   number of iterations to benchmark the kernel (default:100)
+          -timer    gpu:gpu timer, cpu:cpu timer (default:gpu)
 ```
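For example, assuming the ck_tile argument parser's usual `-name=value` syntax, a run might look like `./build/bin/tile_example_gemm_basic -m=1024 -n=2048 -k=64 -prec=fp16 -v=1` (values illustrative).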
@@ -41,18 +41,39 @@ template <typename LayoutA,
 float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 {
     // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadA = true;
     constexpr bool kPadB = true;
+    constexpr bool kTilePermute = false;
     constexpr int kBlockPerCu = 1;

     using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
-    using GemmEpilogue = ck_tile::Default2DEpilogue<
-        ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
+
+    // The rank and permutation will also be generate out by the CodeGen part.
+    constexpr ck_tile::index_t kOutputRank = 2;
+
+    // Whether doing the CShuffle (transpose before the global memory), depending on the output
+    // layout.
+    constexpr bool CShuffleEpilogue =
+        std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
+
+    using GemmEpilogue = std::conditional_t<
+        CShuffleEpilogue,
+        ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
+                                                                   CDataType,
+                                                                   kPadA,
+                                                                   kPadB,
+                                                                   kTilePermute,
+                                                                   kOutputRank,
+                                                                   1,
+                                                                   0,
+                                                                   TilePartitioner::kM,
+                                                                   TilePartitioner::kN>>,
+        ck_tile::Default2DEpilogue<
+            ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
+
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
-    using Kernel =
-        ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
+    using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

     auto kargs = Kernel::MakeKargs(args.p_a,
                                    args.p_b,
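The epilogue is now picked at compile time from the output layout via `std::conditional_t`. A self-contained sketch of the selection pattern, using placeholder types rather than the CK classes:

```cpp
#include <type_traits>

struct ColumnMajor {}; // placeholder layout tags
struct RowMajor {};

struct CShuffleEpilogueStub {};  // placeholder epilogue types
struct Default2DEpilogueStub {};

// Column-major output selects the shuffling epilogue, anything else the default.
template <typename LayoutC>
using SelectEpilogue = std::conditional_t<std::is_same_v<LayoutC, ColumnMajor>,
                                          CShuffleEpilogueStub,
                                          Default2DEpilogueStub>;

static_assert(std::is_same_v<SelectEpilogue<ColumnMajor>, CShuffleEpilogueStub>);
static_assert(std::is_same_v<SelectEpilogue<RowMajor>, Default2DEpilogueStub>);
```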
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
                              ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                              ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

-    using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType,
-                                                                     BDataType,
-                                                                     AccDataType,
-                                                                     CodegenGemmShape,
-                                                                     kPadA,
-                                                                     kPadB,
-                                                                     kPadC>;
+    using CodegenGemmTraits = ck_tile::
+        TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
+
+    using CodegenPipelineProblem = ck_tile::
+        GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;

-    using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
+    using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

     invoke_gemm<ck_tile::half_t,
                 matrix_a_layout,
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
         ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
         ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());

-        ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
+        ck_tile::reference_gemm_gpu<ADataType,
+                                    BDataType,
+                                    AccDataType,
+                                    CDataType,
+                                    matrix_a_layout,
+                                    matrix_b_layout,
+                                    matrix_c_layout>(
             a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
         c_buf.FromDevice(c_host_gpu_ref.data());
...
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp)
# Image to Column
This folder contains example for Image to Column using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
```
This will result in an executable `build/bin/tile_example_img2col`
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstring>
#include "ck_tile/host.hpp"
#include "image_to_column.hpp"
// Host API implementation
template <>
float image_to_column(const image_to_column_traits& traits,
const image_to_column_args<2>& args,
const ck_tile::stream_config& stream_conf)
{
if(traits.data_type.compare("fp16") == 0)
{
constexpr ck_tile::index_t NDimSpatial = 2;
constexpr ck_tile::index_t VectorSize = 8;
using thread_tile = ck_tile::sequence<8, 8>;
using warp_tile = ck_tile::sequence<64, 64>;
using block_tile = ck_tile::sequence<128, 128>;
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;
using InDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
using PipelineProblem = ck_tile::BlockImageToColumnProblem<InDataType,
OutDataType,
Shape,
NDimSpatial,
VectorSize,
VectorSize>;
using Kernel = ck_tile::ImageToColumn<PipelineProblem>;
auto kargs = Kernel::MakeKargs(args.p_in,
args.p_out,
args.G,
args.N,
args.C,
args.input_spatial_lengths,
args.filter_spatial_lengths,
args.output_spatial_lengths,
args.image_g_n_c_wis_strides,
args.gemm_g_m_k_strides,
args.conv_filter_strides,
args.conv_filter_dilations,
args.input_left_pads,
args.input_right_pads);
const dim3 grids = Kernel::GridSize(
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1],
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C,
args.G);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;
float ave_time = ck_tile::launch_kernel(
stream_conf,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
int main(int argc, char* argv[])
{
constexpr ck_tile::index_t NDimSpatial = 2;
ExecutionConfig config;
ck_tile::conv::ConvParam conv_params = DefaultConvParams;
if(!parse_cmd_args(argc, argv, config, conv_params))
{
return EXIT_FAILURE;
}
if(conv_params.num_dim_spatial_ != NDimSpatial)
{
std::cerr << "unsupported # of spatial dimensions" << std::endl;
return EXIT_FAILURE;
}
using InDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
const auto G = conv_params.G_;
const auto N = conv_params.N_;
const auto C = conv_params.C_;
const ck_tile::long_index_t NHoWo =
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const ck_tile::long_index_t CYX =
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params);
const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX});
// host verify
ck_tile::HostTensor<InDataType> in(in_desc);
ck_tile::HostTensor<OutDataType> out_device(out_desc);
ck_tile::HostTensor<OutDataType> out_host(out_desc);
switch(config.init_method)
{
case 0: break;
case 1: ck_tile::FillUniformDistributionIntegerValue<InDataType>{-5.f, 5.f}(in); break;
default: ck_tile::FillUniformDistribution<InDataType>{-0.5, 0.5}(in); break;
}
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
in_device_buf.ToDevice(in.data());
image_to_column_traits traits{"fp16"};
image_to_column_args<NDimSpatial> args{
in_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
G,
N,
C,
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.filter_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.output_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_dilations_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_)};
float ave_time =
image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel});
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(config.do_verification)
{
// reference
ck_tile::reference_im2col<InDataType, OutDataType, NDimSpatial>(in, out_host, conv_params);
out_device_buf.FromDevice(out_device.data());
pass = ck_tile::check_err(out_device, out_host);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
return !pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
#include <string>
#define DefaultConvParams \
ck_tile::conv::ConvParam \
{ \
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
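// The positional fields above presumably follow the leading-argument order noted in
// parse_cmd_args below (num_dim_spatial_, G_, N_, K_, C_), then per-dimension filter
// lengths {4, 4}, image lengths {64, 64}, conv strides {1, 1}, dilations {1, 1} and
// left/right pads {0, 0}, {0, 0}; treat this mapping as an assumption, not a spec.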
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
inline void print_help_msg()
{
std::cerr << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ExecutionConfig& config,
ck_tile::conv::ConvParam& conv_params)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
config = ExecutionConfig{};
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params =
ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
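// Usage example: passing only the three ExecutionConfig arguments keeps the default
// convolution sizes, e.g. `tile_example_img2col 1 2 0` enables verification, uses
// decimal-value init and skips kernel timing; the trailing ConvParam argument format
// is the one printed by get_conv_param_parser_helper_msg().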
struct image_to_column_traits
{
std::string data_type;
};
template <ck_tile::index_t NDimSpatial>
struct image_to_column_args
{
const void* p_in;
void* p_out;
const ck_tile::long_index_t G;
const ck_tile::long_index_t N;
const ck_tile::long_index_t C;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> filter_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> output_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
const ck_tile::array<ck_tile::long_index_t, 3> gemm_g_m_k_strides;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_strides;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_dilations;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_left_pads;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_right_pads;
};
// host API
template <ck_tile::index_t NDimSpatial>
float image_to_column(const image_to_column_traits&,
const image_to_column_args<NDimSpatial>&,
const ck_tile::stream_config&);
@@ -5,3 +5,4 @@ include_directories(AFTER
 add_subdirectory(01_fmha)
 add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
+add_subdirectory(04_img2col)
@@ -97,13 +97,6 @@
 #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
 #endif

-//
-// Instances supports in the current CK build
-//
-#ifndef CK_ENABLE_INSTANCES_ONLY
-#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
-#endif
-
 //
 // CK kernels which support XDL (MI series)
 //
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
             hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+            hip_check_error(hipEventDestroy(start));
+            hip_check_error(hipEventDestroy(stop));
+
             return total_time / nrepeat;
         }
         else
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+            hip_check_error(hipEventDestroy(start));
+            hip_check_error(hipEventDestroy(stop));
+
             return total_time / nrepeat;
         }
         else
...
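The added `hipEventDestroy` calls close a small leak: each timed launch created an event pair that was never released. A minimal sketch of the complete event-timing lifecycle these helpers follow (standard HIP API, error handling elided):

```cpp
#include <hip/hip_runtime.h>

// Measure elapsed GPU time around some work on a stream.
float time_gpu_work(hipStream_t stream)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    // ... launch kernels on `stream` here ...
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop); // wait until `stop` has been reached

    float ms = 0.f;
    hipEventElapsedTime(&ms, start, stop);

    // Without these two calls the events leak, which is what the patch fixes.
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}
```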
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
     }

     template <>
-    __device__ static constexpr auto TailScheduler<1>()
+    __device__ constexpr auto TailScheduler<1>()
     {
         // schedule
         constexpr auto num_ds_read_inst =
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
     }

     template <>
-    __device__ static constexpr auto TailScheduler<2>()
+    __device__ constexpr auto TailScheduler<2>()
     {
         // schedule
         constexpr auto num_ds_read_inst =
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
                         typename vector_type<ComputeDataType,
                                              xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.template Run<>(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
                     using mfma_input_type =
                         typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf_per_scale.GetVectorTypeReference(I0));
+                    xdlops_gemm.template Run<>(
+                        a_thread_vec.template AsType<mfma_input_type>(),
+                        b_thread_vec.template AsType<mfma_input_type>(),
+                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                 });

                 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                     constexpr index_t c_offset =
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
                         typename vector_type<ComputeDataType,
                                              xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.template Run<>(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
                         typename vector_type<ComputeDataType,
                                              xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.template Run<>(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
                     using mfma_input_type =
                         typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf_per_scale.GetVectorTypeReference(I0));
+                    xdlops_gemm.template Run<>(
+                        a_thread_vec.template AsType<mfma_input_type>(),
+                        b_thread_vec.template AsType<mfma_input_type>(),
+                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                 });

                 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                     constexpr index_t c_offset =
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
                     using mfma_input_type =
                         typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf_per_scale.GetVectorTypeReference(I0));
+                    xdlops_gemm.template Run<>(
+                        a_thread_vec.template AsType<mfma_input_type>(),
+                        b_thread_vec.template AsType<mfma_input_type>(),
+                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                 });

                 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                     constexpr index_t c_offset =
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                         typename vector_type<ComputeDataType,
                                              xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(
+                    xdlops_gemm.template Run<>(
                         a_thread_vec.template AsType<mfma_input_type>(),
                         b_thread_vec.template AsType<mfma_input_type>(),
                         c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                     using mfma_input_type =
                         typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-                    xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                             b_thread_vec.template AsType<mfma_input_type>(),
-                                             c_thread_buf_per_scale.GetVectorTypeReference(I0));
+                    xdlops_gemm.template Run<>(
+                        a_thread_vec.template AsType<mfma_input_type>(),
+                        b_thread_vec.template AsType<mfma_input_type>(),
+                        c_thread_buf_per_scale.GetVectorTypeReference(I0));
                 });

                 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                     constexpr index_t c_offset =
...