"git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "c10ae5d2ca158ae506b93198afa4926ea48bb40b"
Unverified commit 5055b3bd, authored by carlushuang, committed by GitHub

[CK_TILE] support group from cmdline (#1295)

* support cmdline seqlen decode

* silent print

* update readme

* update kernel launch 3d

* update tile partitioner

* fix spill for bf16

* modify based on comment

* modify payload_t

* fix bug for alibi mode

* fix alibi test err

* refactor kernel launch, support select timer

* add missing file

* remove useless code

* add some comments
parent 02fa2c29
@@ -34,6 +34,7 @@ args:
           if not equal to h, then this is GQA/MQA case
 -s        seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
           total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
+          also with "-s=s0,s1,s2..." comma-separated ints to set per-batch seqlen (group-mode)
 -s_k      seqlen_k, -1 means equal to s (default:-1)
 -d        head dim for q, k (default:128)
 -d_v      head dim for v, -1 means equal to d (default:-1)
......
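For example, in group mode `-b=3 -s=100,200,300` pins each batch element's seqlen_q explicitly, while `-b=3 -s=100,200` pins the first two and randomizes the remaining one using the last given value as the average (see the decode_seqlen helper later in this diff).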
@@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[])
                 "-1",
                 "num of head, for k/v, -1 means equal to h\n"
                 "if not equal to h, then this is GQA/MQA case")
-        .insert("s",
-                "3328",
-                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
-                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
+        .insert(
+            "s",
+            "3328",
+            "seqlen_q. if group-mode, means the average value of seqlen_q\n"
+            "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
+            "also with \"-s=s0,s1,s2...\" comma-separated ints to set per-batch seqlen (group-mode)")
         .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
+        .insert("s_kpad",
+                "-1",
+                "seqlen_k stride between 2 tokens, currently used in group-mode only\n"
+                "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
+                "along seqlen, instead of packed. same as xformer kv_padding")
         .insert("d", "128", "head dim for q, k")
         .insert("d_v", "-1", "head dim for v, -1 means equal to d")
         .insert("scale_s",

@@ -103,6 +110,7 @@ auto create_args(int argc, char* argv[])
                 "11939",
                 "random seed used for initializing input tensors. 0 for "
                 "non-deterministic seed")
+        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel");

@@ -177,10 +185,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }

-    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
-    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
-    if(seqlen_k < 0)
-        seqlen_k = seqlen_q;
+    auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
+                                                              batch,
+                                                              arg_parser.get_str("s"),
+                                                              arg_parser.get_str("s_k"),
+                                                              arg_parser.get_str("s_kpad"));
+#if 0
+    // clang-format off
+    std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl;
+    std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl;
+    std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl;
+    // clang-format on
+#endif
     ck_tile::index_t hdim_q = arg_parser.get_int("d");
     ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
     if(hdim_v < 0)

@@ -229,7 +247,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     bool lse        = arg_parser.get_bool("lse");
     bias_info bias  = bias_info::decode(arg_parser.get_str("bias"));
-    mask_info mask  = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
+    mask_info mask  = mask_info::decode(
+        arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
     std::string init_method = arg_parser.get_str("init");
     std::optional<uint32_t> seed = arg_parser.get_uint32("seed");

@@ -242,11 +261,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int stream_repeat = arg_parser.get_int("repeat");
     bool kname        = arg_parser.get_bool("kname");

-    ck_tile::stream_config stream_config{
-        nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
+    ck_tile::stream_config stream_config{nullptr,
+                                         true,
+                                         /* log_level = */ (kname ? 1 : 0),
+                                         stream_warmup,
+                                         stream_repeat,
+                                         arg_parser.get_str("timer") == std::string("gpu")};

-    const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
-    const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
+    const auto seqstart_q_host              = to_seqstarts(seqlen_qs);
+    const auto seqstart_k_host              = to_seqstarts(seqlen_ks);
+    const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);

     using TypeConfig = FmhaFwdTypeConfig<DataType>;

@@ -302,9 +326,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // host memory for storing all the tensor elements
     const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
     const ck_tile::index_t shape_seqlen_q =
-        (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
+        (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
     const ck_tile::index_t shape_seqlen_k =
-        (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
+        (mode == mode_enum::batch ? seqlen_ks[0]
+                                  : (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
+                                                         : seqstart_k_with_padding_host.back()));

     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));

@@ -407,6 +433,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
+    ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());

     q_buf.ToDevice(q_host.data());

@@ -414,7 +441,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     v_buf.ToDevice(v_host.data());
     bias_buf.ToDevice(bias_host.data());
     seqstart_q.ToDevice(seqstart_q_host.data());
-    seqstart_k.ToDevice(seqstart_k_host.data());
+    seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
+                                            : seqstart_k_with_padding_host.data());
+    seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
     alibi_slope_buf.ToDevice(alibi_slope_host.data());

     // clang-format off

@@ -430,7 +459,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     const std::string prec = arg_parser.get_str("prec");

     std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
-              << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
+              << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0]
+              << (seqlen_kpads[0] < 0 ? ""
+                                      : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
              << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
              << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
              << std::flush;

@@ -460,7 +491,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return ck_tile::identity{};
     }();

-    auto fmha_args = [&]() {
+    auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
         ///       seqlen_k] in this example, hence both the 'batch_stride_bias' &

@@ -506,7 +537,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              o_buf.GetDeviceBuffer(),
                              seqstart_q.GetDeviceBuffer(),
                              seqstart_k.GetDeviceBuffer(),
-                             nullptr,
+                             k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
                              shape_seqlen_q,
                              shape_seqlen_k,
                              batch,

@@ -576,7 +607,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
         // adjust matrix index according to the mode
         const ck_tile::index_t b            = (mode == mode_enum::batch ? wb : 0);
         const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
-        const ck_tile::index_t key_offset   = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
+        const ck_tile::index_t key_offset =
+            (mode == mode_enum::batch
+                 ? 0
+                 : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));

         const auto v_host_ref_lengths =
             std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};

@@ -661,7 +695,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             else
             {
                 return ck_tile::Alibi<SaccDataType, true>{
-                    0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
+                    0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
             }
         }();

@@ -671,7 +705,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
             for(auto i_h = 0; i_h < nhead; i_h++)
             {
                 SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
-                alibi_host.slope = current_slope;
+                alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope
+                                                                                   : -current_slope;
                 for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
                 {
                     for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
......
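For reference, to_seqstarts (whose real signature is visible in the fmha_utils.hpp hunk below) turns the per-batch lengths into the cumulative-offset arrays used above; a minimal, self-contained sketch of the same computation:

#include <cstdint>
#include <vector>

// Minimal sketch (not the library function itself) of what to_seqstarts computes:
// an exclusive prefix sum with a leading 0, so seqstarts.back() is the total
// token count across the whole batch.
std::vector<int32_t> to_seqstarts_sketch(const std::vector<int32_t>& seqlens)
{
    std::vector<int32_t> seqstarts{0};
    for(int32_t len : seqlens)
        seqstarts.push_back(seqstarts.back() + len);
    return seqstarts; // e.g. {100, 200, 300} -> {0, 100, 300, 600}
}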
@@ -78,6 +78,11 @@ BOOL_MAP = {
     "f" : "false"
 }

+TILE_PARTITIONER_MAP = {
+    "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
+    "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
+}
+
 DIRECTIONS = ["fwd"]
 GEN_DIR = ""    # in Cmake, have to generate files in same folder

@@ -107,7 +112,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_dvpad},
                                                    {F_bias},
                                                    {F_lse},
                                                    {F_squant},
                                                    {F_occupancy}>;
 using fmha_mask_{F_idx} = {F_mask};

@@ -136,7 +141,7 @@ using fmha_epilogue_{F_idx} =
                                   {F_spad}, {F_dvpad}>>;

 using fmha_kernel_{F_idx} =
-    ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
+    ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
                            fmha_pipeline_{F_idx},
                            fmha_epilogue_{F_idx}>;

@@ -154,7 +159,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
     auto [kargs, grids]                    = fmha_fwd_create_kargs_and_grids<k_>(a);
     constexpr dim3 blocks                  = k_::BlockSize();
     constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-    return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
+    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
 }}
 """

@@ -389,6 +394,12 @@ class FmhaFwdKernel:
     F_pipeline : FmhaFwdPipeline
     mask_impl : str

+    def get_tp(self) -> str:
+        if self.F_mode == 'group':
+            return 'hbs'
+        else:
+            return 'shb'
+
     @property
     def template(self) -> str:
         kernel_body = str()

@@ -413,7 +424,7 @@ class FmhaFwdKernel:
                 F_spad   = BOOL_MAP[self.F_pipeline.F_spad],
                 F_skpad  = BOOL_MAP[self.F_pipeline.F_skpad],
                 F_dpad   = BOOL_MAP[self.F_pipeline.F_dpad],
                 F_dvpad  = BOOL_MAP[self.F_pipeline.F_dvpad],
                 F_bias   = BIAS_MAP[self.F_pipeline.F_bias],
                 F_lse    = BOOL_MAP[self.F_pipeline.F_lse],
                 F_squant = BOOL_MAP[self.F_pipeline.F_squant],

@@ -421,12 +432,13 @@ class FmhaFwdKernel:
                 F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                 F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                 F_mode          = MODE_MAP[self.F_mode],
-                F_pipeline      = PIPELINE_MAP[self.F_pipeline.tag])
+                F_pipeline      = PIPELINE_MAP[self.F_pipeline.tag],
+                F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])

     @property
     def name(self) -> str:
         # TODO: we don't encode idx here
-        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\
+        return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
                 self.F_tile.name + '_' + self.F_pipeline.name

     @property
......
@@ -28,6 +28,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias
 $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
+$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
 done
 done
......
@@ -4,12 +4,14 @@
 #pragma once

 #include <cstdint>
+#include <cstdlib>
 #include <optional>
 #include <ostream>
 #include <tuple>
 #include <utility>
 #include <vector>
 #include <functional>
+#include <string>

 #include "ck_tile/core/container/span.hpp"

@@ -37,12 +39,14 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)

 std::vector<int32_t> generate_seqlens(mode_enum mode,
                                       unsigned count,
-                                      int32_t seqlens_sum,
+                                      int32_t seqlen_avg,
+                                      int32_t seqlen_max = -1, // if not negative, clamp max
                                       std::optional<unsigned> seed = std::nullopt)
 {
     assert(0 < count);

-    std::vector<int32_t> seqlens(count, seqlens_sum);
+    std::vector<int32_t> seqlens(
+        count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);

     if(mode == mode_enum::group && 1 < count)
     {

@@ -55,7 +59,7 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
         std::uniform_int_distribution<size_type> step_dist(1, count - 1);
         auto next_step = std::bind(step_dist, std::ref(random_engine));

-        for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat)
+        for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
         {
             const size_type to_decrease = next_idx();
             // make sure each elements of seqlens is always greater than 0

@@ -66,6 +70,11 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,

             const size_type to_increase = (to_decrease + next_step()) % count;

+            if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
+            {
+                continue;
+            }
+
             --seqlens[to_decrease];
             ++seqlens[to_increase];
         }

@@ -76,10 +85,91 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,

 std::vector<int32_t> generate_seqstarts(mode_enum mode,
                                         unsigned count,
-                                        int32_t seqlens_sum,
+                                        int32_t seqlen_avg,
+                                        int32_t seqlen_max = -1,
                                         std::optional<unsigned> seed = std::nullopt)
 {
-    return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed));
+    return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
+}
+
+/*
+ * decode the seqlen string from cmdline
+ * example (assume batch=3)
+ *  q_val=1,2,3  k_val=4,5,6  -> OK
+ *  q_val=1,2,3               -> OK, k same as q
+ *  q_val=1,2                 -> OK, q will randomize the remaining 1 element, k same as q
+ *  q_val=1,2    k_val=4,5    -> OK, q/k will randomize the remaining 1 element
+ *  q_val=1,2,3,4             -> OK, but the excess values are ignored
+ *
+ *  q_val=1,2    k_val=4,5,6  -> not OK, k must have the same number of splits as q
+ *  q_val=1,2    k_val=4      -> not OK, k must have the same number of splits as q
+ */
+std::tuple<std::vector<ck_tile::index_t>,
+           std::vector<ck_tile::index_t>,
+           std::vector<ck_tile::index_t>>
+decode_seqlen(mode_enum mode,
+              ck_tile::index_t batch,
+              std::string q_val,
+              std::string k_val,
+              std::string k_pad_val,
+              std::optional<unsigned> seed = std::nullopt)
+{
+#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
+    if(mode == mode_enum::batch)
+    {
+        ck_tile::index_t q = _S2I_(q_val);
+        ck_tile::index_t k = _S2I_(k_val);
+        auto s_q    = std::vector<ck_tile::index_t>(batch, q);
+        auto s_k    = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
+        auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
+        return std::make_tuple(s_q, s_k, s_kpad);
+    }
+    else
+    {
+        ck_tile::index_t idx = 0;
+        std::string::size_type pos_q  = 0;
+        std::string::size_type pos_k  = 0;
+        std::string::size_type pos_kp = 0;
+        std::vector<ck_tile::index_t> s_q;
+        std::vector<ck_tile::index_t> s_k;
+        std::vector<ck_tile::index_t> s_kpad;
+        while(true)
+        {
+            auto found_q  = q_val.find(',', pos_q);
+            auto found_k  = k_val.find(',', pos_k);
+            auto found_kp = k_pad_val.find(',', pos_kp);
+
+            ck_tile::index_t q = _S2I_(
+                q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
+            ck_tile::index_t k = _S2I_(
+                k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
+            ck_tile::index_t kp = _S2I_(k_pad_val.substr(
+                pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
+
+            s_q.push_back(q);
+            s_k.push_back(k < 0 ? q : k);
+            s_kpad.push_back(kp);
+
+            idx++;
+            if(found_q == std::string::npos || idx >= batch)
+            {
+                break;
+            }
+            pos_q  = found_q + 1;
+            pos_k  = found_k == std::string::npos ? pos_k : found_k + 1;
+            pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
+        }
+        if(idx < batch)
+        {
+            auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
+            auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
+            s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
+            s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
+            s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
+        }
+        return std::make_tuple(s_q, s_k, s_kpad);
+    }
+#undef _S2I_
 }

@@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int)
     char* v = getenv(var_name);
     int r   = default_int;
     if(v)
-        r = atoi(v);
+        r = std::atoi(v);
     return r;
 }
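To make the parsing rules above concrete, a worked example for decode_seqlen (group mode, batch = 3; fully specified, so no randomization is involved):

auto [qs, ks, kps] = decode_seqlen(mode_enum::group, 3, "1,2,3", "4,5,6", "-1");
// qs  = {1, 2, 3}
// ks  = {4, 5, 6}
// kps = {-1, -1, -1}
// Passing k_val = "-1" instead makes each k fall back to its q, i.e. ks = {1, 2, 3}.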
@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
     return __builtin_bit_cast(int32x4_t, res);
 }

+namespace impl {
+// the types below indicate the data type used for the buffer_load inline asm
+// clang-format off
+template<index_t N, typename T> struct buffer_load_trait;
+
+template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
+template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
+template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
+template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
+template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
+
+#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
+template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
+template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
+template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
+#endif
+// clang-format on
+} // namespace impl
+
 // TODO: glc/slc/...
 template <index_t bytes>
 struct buffer_load;

@@ -48,7 +67,7 @@ struct buffer_load<16>
                                     index_t /*flag*/ = 0)
     {
         static_assert(sizeof(T) == 16);
-        using mbuf_t = fp32x4_t;
+        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
         asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
                      : "+v"(reinterpret_cast<mbuf_t&>(value))
                      : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)

@@ -68,7 +87,7 @@ struct buffer_load<8>
                                     index_t /*flag*/ = 0)
     {
         static_assert(sizeof(T) == 8);
-        using mbuf_t = fp32x2_t;
+        using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
         asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
                      : "+v"(reinterpret_cast<mbuf_t&>(value))
                      : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)

@@ -88,7 +107,7 @@ struct buffer_load<4>
                                     index_t /*flag*/ = 0)
     {
         static_assert(sizeof(T) == 4);
-        using mbuf_t = float;
+        using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
         asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
                      : "+v"(reinterpret_cast<mbuf_t&>(value))
                      : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)

@@ -108,7 +127,7 @@ struct buffer_load<2>
                                     index_t /*flag*/ = 0)
     {
         static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
-        using mbuf_t = float;
+        using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
         asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
                      : "+v"(reinterpret_cast<mbuf_t&>(value))
                      : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)

@@ -128,7 +147,7 @@ struct buffer_load<1>
                                     index_t /*flag*/ = 0)
     {
         static_assert(sizeof(T) == 4);
-        using mbuf_t = float;
+        using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
         asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
                      : "+v"(reinterpret_cast<mbuf_t&>(value))
                      : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)

@@ -152,7 +171,7 @@ struct buffer_load_if<16>
     {
         static_assert(sizeof(T) == 16);
         auto saved_exec = __builtin_amdgcn_read_exec();
-        using mbuf_t = fp32x4_t;
+        using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
         static_assert(sizeof(mbuf_t) == sizeof(T));
         asm volatile(
             "v_cmpx_le_u32 exec, 1, %5\n"

@@ -177,7 +196,7 @@ struct buffer_load_if<8>
     {
         static_assert(sizeof(T) == 8);
         auto saved_exec = __builtin_amdgcn_read_exec();
-        using mbuf_t = fp32x2_t;
+        using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
         asm volatile(
             "v_cmpx_le_u32 exec, 1, %5\n"
             "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"

@@ -201,7 +220,7 @@ struct buffer_load_if<4>
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
-        using mbuf_t = float;
+        using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
         asm volatile(
             "v_cmpx_le_u32 exec, 1, %5\n"
             "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"

@@ -225,7 +244,7 @@ struct buffer_load_if<2>
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
-        using mbuf_t = float;
+        using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
         asm volatile(
             "v_cmpx_le_u32 exec, 1, %5\n"
             "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"

@@ -249,7 +268,7 @@ struct buffer_load_if<1>
     {
         static_assert(sizeof(T) == 4);
         auto saved_exec = __builtin_amdgcn_read_exec();
-        using mbuf_t = float;
+        using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
         asm volatile(
             "v_cmpx_le_u32 exec, 1, %5\n"
             "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
......
@@ -171,3 +171,7 @@
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
 #endif
+
+#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
+#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
+#endif
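Because the new macro is wrapped in #ifndef, the bf16 buffer-load workaround can be switched off at build time by pre-defining it, e.g. compiling with -DCK_TILE_BUFFER_LOAD_RAW_BF16_WA=0.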
@@ -20,3 +20,4 @@
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
+#include "ck_tile/host/timer.hpp"
@@ -27,7 +27,14 @@ struct DeviceMem
     DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
     DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     {
-        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
     }
     void Realloc(std::size_t mem_size)
     {

@@ -36,7 +43,14 @@ struct DeviceMem
             HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
         }
         mMemSize = mem_size;
-        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
     }
     void* GetDeviceBuffer() const { return mpDeviceBuf; }
     std::size_t GetBufferSize() const { return mMemSize; }

@@ -47,15 +61,18 @@ struct DeviceMem
             HIP_CHECK_ERROR(
                 hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
         }
-        else
-        {
-            throw std::runtime_error("ToDevice with an empty pointer");
-        }
+        // else
+        // {
+        //     throw std::runtime_error("ToDevice with an empty pointer");
+        // }
     }
     void ToDevice(const void* p, const std::size_t cpySize) const
     {
-        HIP_CHECK_ERROR(
-            hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(
+                hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+        }
     }
     void FromDevice(void* p) const
     {

@@ -63,14 +80,17 @@ struct DeviceMem
         {
             HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
         }
-        else
-        {
-            throw std::runtime_error("FromDevice with an empty pointer");
-        }
+        // else
+        // {
+        //     throw std::runtime_error("FromDevice with an empty pointer");
+        // }
     }
     void FromDevice(void* p, const std::size_t cpySize) const
     {
-        HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        }
     }
     void SetZero() const
     {

@@ -82,13 +102,16 @@ struct DeviceMem
     template <typename T>
     void SetValue(T x) const
     {
-        if(mMemSize % sizeof(T) != 0)
-        {
-            throw std::runtime_error("wrong! not entire DeviceMem will be set");
-        }
+        if(mpDeviceBuf)
+        {
+            if(mMemSize % sizeof(T) != 0)
+            {
+                throw std::runtime_error("wrong! not entire DeviceMem will be set");
+            }

-        // TODO: call a gpu kernel to set the value (?)
-        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
+            // TODO: call a gpu kernel to set the value (?)
+            set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
+        }
     }
     ~DeviceMem()
     {
......
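The zero-size path added to DeviceMem above exists so callers can allocate conditionally, exactly as the example does for seqlen_k_buf; a minimal sketch of the pattern (use_kpad is a hypothetical flag standing in for seqlen_kpads[0] >= 0):

// A size of 0 now yields a valid-but-null DeviceMem, and copying into it is a
// no-op instead of an error, so the non-padded path needs no special casing.
ck_tile::DeviceMem seqlen_k_buf(use_kpad ? seqlen_ks.size() * sizeof(int32_t) : 0);
seqlen_k_buf.ToDevice(use_kpad ? seqlen_ks.data() : nullptr);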
@@ -6,6 +6,7 @@
 #include "ck_tile/core/config.hpp"
 #include "ck_tile/host/stream_config.hpp"
 #include "ck_tile/host/hip_check_error.hpp"
+#include "ck_tile/host/timer.hpp"
 #include <hip/hip_runtime.h>
 #include <cstddef>

@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename...
 #if CK_TILE_USE_LAUNCH_BOUNDS
 __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
 #endif
-    __global__ void kentry(Kernel f, Args... args)
+    __global__ void kentry(Args... args)
 {
-    f(args...);
+    Kernel{}(args...);
 }

-template <typename... Args, typename F>
-CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
-                                          F kernel,
-                                          dim3 grid_dim,
-                                          dim3 block_dim,
-                                          std::size_t lds_byte,
-                                          Args... args)
-{
-#if CK_TILE_TIME_KERNEL
-    if(s.time_kernel_)
-    {
-        // warm up
-        for(int i = 0; i < s.cold_niters_; ++i)
-        {
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        const int nrepeat = s.nrepeat_;
-        hipEvent_t start, stop;
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
-        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
-        for(int i = 0; i < nrepeat; ++i)
-        {
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
-        HIP_CHECK_ERROR(hipEventSynchronize(stop));
-        float total_time = 0;
-        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
-        return total_time / nrepeat;
-    }
-    else
-    {
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-    return 0;
-#endif
-}
-
-template <typename... Args, typename F, typename PreProcessFunc>
-CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
-                                                          PreProcessFunc preprocess,
-                                                          F kernel,
-                                                          dim3 grid_dim,
-                                                          dim3 block_dim,
-                                                          std::size_t lds_byte,
-                                                          Args... args)
-{
-#if CK_TILE_TIME_KERNEL
-    if(s.time_kernel_)
-    {
-#if CK_TILE_DEBUG_LOG
-        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-               __func__,
-               grid_dim.x,
-               grid_dim.y,
-               grid_dim.z,
-               block_dim.x,
-               block_dim.y,
-               block_dim.z);
-        printf("Warm up 1 time\n");
-#endif
-        // warm up
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-
-        const int nrepeat = 10;
-#if CK_TILE_DEBUG_LOG
-        printf("Start running %d times...\n", nrepeat);
-#endif
-        hipEvent_t start, stop;
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
-        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
-        for(int i = 0; i < nrepeat; ++i)
-        {
-            preprocess();
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
-        HIP_CHECK_ERROR(hipEventSynchronize(stop));
-        float total_time = 0;
-        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
-        return total_time / nrepeat;
-    }
-    else
-    {
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-    return 0;
-#endif
-}
-
-template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
-          int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
-          typename KernelImpl,
-          typename... Args>
-CK_TILE_HOST float launch_kernel(const stream_config& s,
-                                 KernelImpl kernel_impl,
-                                 dim3 grid_dim,
-                                 dim3 block_dim,
-                                 std::size_t dynamic_smem_byte,
-                                 Args... args)
-{
-    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
-    return launch_and_time_kernel(
-        s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
-}
+//
+// return an anonymous functor (lambda) to be called later
+// the KernelImpl should be a class without non-static data members, i.e. it
+// can be instantiated with "KernelImpl{}"
+//
+// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
+//
+template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
+          int MinBlockPerCu    = CK_TILE_MIN_BLOCK_PER_CU,
+          typename KernelImpl,
+          typename... Args>
+CK_TILE_HOST auto
+make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
+{
+    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
+
+    return [=](const stream_config& s) {
+        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
+    };
+}
+
+// clang-format off
+/*
+ * launch_kernel()
+ *
+ * this is the function to launch an arbitrary number of kernels with an optional timer (selected by stream_config)
+ * the callables should have the signature "operator()(const stream_config& s){ ... }"
+ *
+ * the simplest way is to pass in a lambda, with "[=](const stream_config& s){ call_your_kernel_here(); }"
+ * as the callable (pay attention to the capture list)
+ *
+ * e.g.
+ *  ck_tile::launch_kernel(s,
+ *              [=](const stream_config& s){ hipMemset(ptr, 0, size); },
+ *              [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
+ *              );
+ *
+ * if you use a ck_tile kernel, or something similar in style (a structure with
+ * "static __device__ operator()(...){}"), you can pass your kernel to ck_tile::make_kernel(),
+ * which will create an anonymous functor for you, then pass that to ck_tile::launch_kernel()
+ *
+ * e.g.
+ *  ck_tile::launch_kernel(s,
+ *              ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
+ *              ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
+ *              ...);
+ **/
+// clang-format on
+template <typename... Callables>
+CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
+{
+    // clang-format off
+    if(!s.time_kernel_) {
+        (callables(s),...); hip_check_error(hipGetLastError());
+        return 0;
+    }
+    if(s.is_gpu_timer_) {
+        gpu_timer timer {};
+
+        // warmup
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+
+        timer.start(s.stream_id_);
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        timer.stop(s.stream_id_);
+
+        return timer.duration() / s.nrepeat_;
+    }
+    else {
+        cpu_timer timer {};
+
+        // warmup
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+
+        timer.start(s.stream_id_);
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        timer.stop(s.stream_id_);
+
+        return timer.duration() / s.nrepeat_;
+    }
+    // clang-format on
+}

 } // namespace ck_tile
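Putting the pieces together, a minimal sketch of the new launch path (my_kernel, grids, blocks and kargs are hypothetical stand-ins; the two template arguments are max-threads-per-block and min-blocks-per-CU, as above):

// The last stream_config field picks the GPU event timer (true) or the CPU timer (false).
ck_tile::stream_config cfg{nullptr, /*time_kernel=*/true, /*log_level=*/0,
                           /*cold_niters=*/5, /*nrepeat=*/20, /*is_gpu_timer=*/true};
float avg_ms = ck_tile::launch_kernel(
    cfg, ck_tile::make_kernel<256, 1>(my_kernel{}, grids, blocks, 0, kargs));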
@@ -6,6 +6,22 @@
 #include <hip/hip_runtime.h>

 namespace ck_tile {

+/*
+ * construct this structure with the following behaviors:
+ *
+ * // create stream config with the default stream (NULL), and do not time the kernel
+ * stream_config s = stream_config{};
+ *
+ * // create stream config with _some_stream_id_, and do not time the kernel
+ * stream_config s = stream_config{_some_stream_id_};
+ *
+ * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
+ * stream_config s = stream_config{_some_stream_id_, true};
+ *
+ * // create stream config with _some_stream_id_, and benchmark using the cpu timer
+ * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
+ **/
 struct stream_config
 {
     hipStream_t stream_id_ = nullptr;

@@ -13,5 +29,6 @@ struct stream_config
     int log_level_     = 0;
     int cold_niters_   = 3;
     int nrepeat_       = 10;
+    bool is_gpu_timer_ = true; // keep compatible
 };
 } // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace ck_tile {
struct gpu_timer
{
CK_TILE_HOST gpu_timer()
{
HIP_CHECK_ERROR(hipEventCreate(&start_evt));
HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
}
CK_TILE_HOST ~gpu_timer() noexcept(false)
{
HIP_CHECK_ERROR(hipEventDestroy(start_evt));
HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
}
CK_TILE_HOST void start(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
}
CK_TILE_HOST void stop(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
}
// return in ms
CK_TILE_HOST float duration() const
{
float ms = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
return ms;
}
private:
hipEvent_t start_evt, stop_evt;
};
struct cpu_timer
{
// like torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST void start(const hipStream_t&)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
start_tick = std::chrono::high_resolution_clock::now();
}
// like torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST void stop(const hipStream_t&)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
stop_tick = std::chrono::high_resolution_clock::now();
}
// return in ms
CK_TILE_HOST float duration() const
{
double sec =
std::chrono::duration_cast<std::chrono::duration<double>>(stop_tick - start_tick)
.count();
return static_cast<float>(sec * 1e3);
}
private:
std::chrono::time_point<std::chrono::high_resolution_clock> start_tick;
std::chrono::time_point<std::chrono::high_resolution_clock> stop_tick;
};
} // namespace ck_tile
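Both timers expose the same start/stop/duration interface, so launch_kernel can swap them freely; a minimal standalone use (stream is an assumed, already-created hipStream_t):

ck_tile::gpu_timer timer{};
timer.start(stream); // device sync, then records the start event on `stream`
// ... enqueue the work to be measured on `stream` ...
timer.stop(stream);  // records the stop event and waits for it
float elapsed_ms = timer.duration();

cpu_timer instead brackets full device synchronizations with a wall clock, mirroring the torch.utils.benchmark.Timer behavior referenced in its comments.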
@@ -23,13 +23,13 @@ VERTICAL:
    [0]  1   2   3   4   5
    [0]  1   2   3   4   5

-TOP_LEFT:
+TOP_LEFT(but negative):
    [0]  1   2   3   4   5
     1  [0]  1   2   3   4
     2   1  [0]  1   2   3
     3   2   1  [0]  1   2

-FROM_BOTTOM_RIGHT:
+FROM_BOTTOM_RIGHT(but negative):
     2   1  [0]  1   2   3
     3   2   1  [0]  1   2
     4   3   2   1  [0]  1

@@ -54,7 +54,7 @@ struct Alibi
           index_t x_total_,
           AlibiMode mode_ = AlibiMode::VERTICAL)
     {
-        slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope;
+        slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
         shift_left_up = [&]() {
             if(RowMajor)
......
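Reading the updated diagrams together with the sign fix in the Alibi constructor, the bias applied at query row i and key column j (slope m, lengths s_q and s_k) works out to:

\[
\mathrm{bias}(i,j)=\begin{cases}
\; m\,j & \text{VERTICAL}\\
\; -m\,\lvert j-i\rvert & \text{FROM\_TOP\_LEFT}\\
\; -m\,\lvert j-i-(s_k-s_q)\rvert & \text{FROM\_BOTTOM\_RIGHT}
\end{cases}
\]

This matches the expected matrices in the updated alibi test at the end of this diff (e.g. the 4x6 FROM_BOTTOM_RIGHT case is zero along j = i + 2, where s_k - s_q = 2); it is a reading inferred from the tests rather than a formula stated in the source.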
@@ -76,7 +76,7 @@ struct FmhaFwdKernel
            return n.empty() ? n : std::string("p") + n; }();

     return
         _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
-        "_" + (kIsGroupMode ? "group" : "batch") + "_" +
+        "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" +
         "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
         _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
         "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +

@@ -702,7 +702,7 @@ struct FmhaFwdKernel
             else
             {
                 return Alibi<SaccDataType, true>{
-                    slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
+                    slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
             }
         }
         else
......
@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
     static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
     static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
-                                            ck_tile::index_t nhead_,
-                                            ck_tile::index_t seqlen_q_,
-                                            ck_tile::index_t hdim_v_)
+    static constexpr const char* name = "shb";
+
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
+                                                ck_tile::index_t nhead_,
+                                                ck_tile::index_t seqlen_q_,
+                                                ck_tile::index_t hdim_v_)
     {
         // TODO: this may need tuning
         return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *

@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
     }
 };

+template <typename BlockFmhaShape_>
+using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
+
+template <typename BlockFmhaShape_>
+struct FmhaFwdTilePartitioner_HBS
+{
+    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
+
+    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
+    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
+    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
+    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
+    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
+
+    static constexpr const char* name = "hbs";
+
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
+                                                ck_tile::index_t nhead_,
+                                                ck_tile::index_t seqlen_q_,
+                                                ck_tile::index_t hdim_v_)
+    {
+        // TODO: this may need tuning
+        return dim3(nhead_,
+                    batch_size_,
+                    ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
+                        ck_tile::integer_divide_ceil(hdim_v_, kN1));
+    }
+
+    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
+    {
+        // const index_t num_tile_m0 = seqlen_q / kM0;
+        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
+
+        const index_t i_block = blockIdx.z;
+        const index_t i_nhead = blockIdx.x;
+        const index_t i_batch = blockIdx.y;
+
+        const auto f = [](index_t dividend, index_t divisor) {
+            index_t quotient = dividend / divisor;
+            index_t modulus  = dividend - quotient * divisor;
+            return ck_tile::make_tuple(quotient, modulus);
+        };
+
+        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+    }
+};
+
 } // namespace ck_tile
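For orientation, the two partitioners cover the same tile space and differ only in how the grid dimensions are ordered, which is what the shb/hbs names encode. From the two GridSize functions (the SHB return value is truncated in this view, so its remaining two dimensions are read from its name):

// HBS ("head-batch-seqlen"): x = nhead, y = batch, z = seqlen-tiles * hdim_v-tiles
// SHB ("seqlen-head-batch"): x = seqlen-tiles * hdim_v-tiles, presumably then nhead and batch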
@@ -131,74 +131,74 @@ int main()
                                                         0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
-                                                        1, 0, 1, 2, 3, 4,
-                                                        2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
+                                                        -1,  0, -1, -2, -3, -4,
+                                                        -2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0,
-                                                        4, 3, 2, 1,
-                                                        5, 4, 3, 2});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0,
+                                                        -4, -3, -2, -1,
+                                                        -5, -4, -3, -2});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2,
-                                                        4, 3, 2, 1, 0, 1,
-                                                        5, 4, 3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2,
+                                                        -4, -3, -2, -1,  0, -1,
+                                                        -5, -4, -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
-                                                        1, 2, 3, 4,
-                                                        0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
+                                                        -1, -2, -3, -4,
+                                                         0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
     rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
-                                                        1, 0, 1, 2, 3, 4,
-                                                        2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
+                                                        -1,  0, -1, -2, -3, -4,
+                                                        -2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0,
-                                                        4, 3, 2, 1,
-                                                        5, 4, 3, 2});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0,
+                                                        -4, -3, -2, -1,
+                                                        -5, -4, -3, -2});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2,
-                                                        4, 3, 2, 1, 0, 1,
-                                                        5, 4, 3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2,
+                                                        -4, -3, -2, -1,  0, -1,
+                                                        -5, -4, -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
-                                                        1, 2, 3, 4,
-                                                        0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
+                                                        -1, -2, -3, -4,
+                                                         0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
     rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625});
     rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,
......