Commit 337f073d authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Move num_splits_heuristic() to fmha_fwd.hpp for reusability

parent 2da4b185
...@@ -177,41 +177,6 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method) ...@@ -177,41 +177,6 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
} }
} }
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
return num_splits;
}
}
return 1;
}
int override_num_splits_if_necessary(int batch, int override_num_splits_if_necessary(int batch,
int nhead, int nhead,
int max_seqlen_q, int max_seqlen_q,
......
...@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits ...@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args, fmha_fwd_appendkv_args,
const ck_tile::stream_config&); const ck_tile::stream_config&);
template <typename Int = int>
Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
{
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
return num_splits;
}
}
return 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment