Commit 6ff7fa94 authored by Po Yen Chen
Browse files

Update num_splits heuristic

parent 337f073d
......@@ -235,7 +235,7 @@ int override_num_splits_if_necessary(int batch,
if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
- batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 32);
+ batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 16);
}
return num_splits;
......
......@@ -829,7 +829,7 @@ Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
std::vector<float> efficiency;
efficiency.reserve(max_splits);
- for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
+ for(Int num_splits = 1; num_splits <= max_splits; num_splits *= 2)
{
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment