Commit 626ab5b6 authored by Po Yen Chen
Browse files

Update num_splits heuristic for prefill phase

parent 287a53bf
......@@ -211,8 +211,14 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
return 1;
}
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_q, int hdim_v, float p_drop, int num_splits)
int override_num_splits_if_necessary(int batch,
int nhead,
int max_seqlen_q,
int hdim_q,
int hdim_v,
float p_drop,
bool is_prefill,
int num_splits)
{
int device;
auto status = hipGetDevice(&device);
......@@ -229,6 +235,13 @@ int override_num_splits_if_necessary(
}
const int kM0 = [&] {
// get kM0 for prefill phase
if(is_prefill)
{
return 128;
}
// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
......@@ -553,8 +566,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
// legalize num_splits according to other options
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_q, hdim_v, p_drop, num_splits);
num_splits = override_num_splits_if_necessary(batch,
nhead,
max_seqlen_q,
hdim_q,
hdim_v,
p_drop,
/*is_prefill=*/mode == mode_enum::group &&
0 < page_block_size,
num_splits);
}
if(128 < num_splits)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment