Commit 612a35d6 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Use maximum 8 splits

parent 45721793
......@@ -234,8 +234,7 @@ int override_num_splits_if_necessary(int batch,
if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 16);
return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
}
return num_splits;
......@@ -625,11 +624,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits ? std::array<ck_tile::index_t, 5>{shape_batch,
nhead,
num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
nhead,
num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
......@@ -1042,7 +1041,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.num_splits = num_splits;
if (1 < num_splits) {
if(1 < num_splits)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
......@@ -1053,7 +1053,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
} else {
}
else
{
// following attribues are ignored by fmha_fwd_splitkv()
args.lse_acc_ptr = nullptr;
args.o_acc_ptr = nullptr;
......@@ -1088,7 +1090,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || use_kvcache)
if(1 <= num_splits || use_kvcache)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment