"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "f16356d4ff7d7f4f0c28c48c89a55d2372b2a3f7"
Commit 23602a0d authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Use better num_splits heuristic

parent 289d5eb0
......@@ -12,6 +12,7 @@
#include "mask.hpp"
#include "rotary.hpp"
#include <array>
#include <type_traits>
#include <utility>
#include <variant>
......@@ -827,26 +828,27 @@ Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
max_splits = std::min({max_splits, num_SMs});
constexpr std::array<Int, 5> num_splits_array = {1, 2, 4, 8, 16};
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
std::array<float, num_splits_array.size()> efficiency;
for(Int num_splits = 1; num_splits <= max_splits; num_splits *= 2)
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{
float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
efficiency[idx] = eff;
}
for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
if(efficiency[idx] >= 0.85 * max_efficiency)
{
return num_splits;
return num_splits_array[idx];
}
}
return 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment