Commit 612a35d6 authored by Po Yen Chen
Browse files

Use maximum 8 splits

parent 45721793
...@@ -234,8 +234,7 @@ int override_num_splits_if_necessary(int batch, ...@@ -234,8 +234,7 @@ int override_num_splits_if_necessary(int batch,
if(num_splits < 1 && p_drop == 0.0f) if(num_splits < 1 && p_drop == 0.0f)
{ {
return num_splits_heuristic( return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 16);
} }
return num_splits; return num_splits;
...@@ -1042,7 +1041,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1042,7 +1041,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.num_splits = num_splits; args.num_splits = num_splits;
if (1 < num_splits) { if(1 < num_splits)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
...@@ -1053,7 +1053,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1053,7 +1053,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.batch_stride_o_acc = batch_stride_o_acc; args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc; args.split_stride_o_acc = split_stride_o_acc;
} else { }
else
{
// following attribues are ignored by fmha_fwd_splitkv() // following attribues are ignored by fmha_fwd_splitkv()
args.lse_acc_ptr = nullptr; args.lse_acc_ptr = nullptr;
args.o_acc_ptr = nullptr; args.o_acc_ptr = nullptr;
...@@ -1088,7 +1090,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1088,7 +1090,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const float fwd_ave_time = [&] { const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API #if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || use_kvcache) if(1 <= num_splits || use_kvcache)
{ {
fmha_fwd_splitkv_traits fmha_splitkv_traits; fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits); init_traits(fmha_splitkv_traits);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment