Commit 6ff7fa94 authored by Po Yen Chen
Browse files

Update num_splits heuristic

parent 337f073d
...@@ -235,7 +235,7 @@ int override_num_splits_if_necessary(int batch, ...@@ -235,7 +235,7 @@ int override_num_splits_if_necessary(int batch,
if(num_splits < 1 && p_drop == 0.0f) if(num_splits < 1 && p_drop == 0.0f)
{ {
return num_splits_heuristic( return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 32); batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 16);
} }
return num_splits; return num_splits;
......
...@@ -829,7 +829,7 @@ Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits) ...@@ -829,7 +829,7 @@ Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
std::vector<float> efficiency; std::vector<float> efficiency;
efficiency.reserve(max_splits); efficiency.reserve(max_splits);
for(Int num_splits = 1; num_splits <= max_splits; num_splits++) for(Int num_splits = 1; num_splits <= max_splits; num_splits *= 2)
{ {
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs; float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks); float eff = n_blocks / std::ceil(n_blocks);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment