"vscode:/vscode.git/clone" did not exist on "c71e140d32a4fcadee95415b7cb8fde05e68e02a"
Commit 23602a0d authored by Po Yen Chen
Browse files

Use better num_splits heuristic

parent 289d5eb0
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "mask.hpp" #include "mask.hpp"
#include "rotary.hpp" #include "rotary.hpp"
#include <array>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <variant> #include <variant>
...@@ -827,26 +828,27 @@ Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits) ...@@ -827,26 +828,27 @@ Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
max_splits = std::min({max_splits, num_SMs}); max_splits = std::min({max_splits, num_SMs});
constexpr std::array<Int, 5> num_splits_array = {1, 2, 4, 8, 16};
float max_efficiency = 0.f; float max_efficiency = 0.f;
std::vector<float> efficiency; std::array<float, num_splits_array.size()> efficiency;
efficiency.reserve(max_splits);
for(Int num_splits = 1; num_splits <= max_splits; num_splits *= 2) for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{ {
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs; float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks); float eff = n_blocks / std::ceil(n_blocks);
if(eff > max_efficiency) if(eff > max_efficiency)
{ {
max_efficiency = eff; max_efficiency = eff;
} }
efficiency.push_back(eff); efficiency[idx] = eff;
} }
for(Int num_splits = 1; num_splits <= max_splits; num_splits++) for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{ {
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) if(efficiency[idx] >= 0.85 * max_efficiency)
{ {
return num_splits; return num_splits_array[idx];
} }
} }
return 1; return 1;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment