Commit 3a58dbf9 authored by Paul's avatar Paul
Browse files

Update compile_ops to handle tuning config

parent 247ce1d2
...@@ -76,6 +76,53 @@ struct compiled_result ...@@ -76,6 +76,53 @@ struct compiled_result
instruction_ref ins; instruction_ref ins;
}; };
struct compile_plan
{
context* ctx;
operation preop;
instruction_ref ins;
optional<tuning_config> config = nullopt;
std::vector<compiled_result> results = {};
void update_config()
{
config = get_tuning_config(*ctx, ins, preop);
}
template<class Vector>
void add_compiles(Vector& compiles)
{
if (config.has_value())
{
const auto& solutions = config.value().solutions;
results.resize(solutions.size());
for(auto i:range(solutions.size()))
{
auto solution = solutions[i];
compiles.emplace_back([=] {
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
}
}
else
{
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, value{}), ins};
});
}
}
void replace(module& m) const
{
if(results.size() == 1)
{
results.front().replace.replace(m, results.front().ins);
}
else
{
// TODO: Benchmark
}
}
};
template <class F> template <class F>
void par_compile(std::size_t n, F f) void par_compile(std::size_t n, F f)
{ {
...@@ -86,22 +133,27 @@ void par_compile(std::size_t n, F f) ...@@ -86,22 +133,27 @@ void par_compile(std::size_t n, F f)
void compile_ops::apply(module& m) const void compile_ops::apply(module& m) const
{ {
std::vector<std::function<compiled_result()>> compiles; std::vector<compile_plan> cps;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "gpu::precompile_op") if(ins->name() != "gpu::precompile_op")
continue; continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op; operation preop = any_cast<precompile_op>(ins->get_operator()).op;
compiles.emplace_back([=]() -> compiled_result { cps.push_back({ctx, preop, ins});
return {compile(*ctx, ins, preop, value{}), ins};
});
} }
std::vector<compiled_result> results(compiles.size()); par_compile(cps.size(), [&](auto i) {
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); }); cps[i].update_config();
for(const auto& cr : results) });
std::vector<std::function<void()>> compiles;
for(auto& cp:cps)
{
cp.add_compiles(compiles);
}
par_compile(compiles.size(), [&](auto i) { compiles[i](); });
for(const auto& cp:cps)
{ {
cr.replace.replace(m, cr.ins); cp.replace(m);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment