Commit 17ee5428 authored by Paul's avatar Paul
Browse files

Add interface to enable tuning

parent 2956bb3f
......@@ -94,7 +94,7 @@ void compile_ops::apply(module& m) const
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
compiles.emplace_back([=]() -> compiled_result {
return {compile(*ctx, ins, preop), ins};
return {compile(*ctx, ins, preop, value{}), ins};
});
}
std::vector<compiled_result> results(compiles.size());
......
......@@ -28,33 +28,40 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
auto& compiler_map()
{
static std::unordered_map<std::string, compiler_compile> m; // NOLINT
return m;
namespace {
struct compiler_handle
{
compiler_compile compile;
compiler_compile_op compile_op;
compiler_tuning_config get_tuning_config;
};
}
auto& compiler_op_map()
auto& compiler_map()
{
static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT
static std::unordered_map<std::string, compiler_handle> m; // NOLINT
return m;
}
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop)
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop, compiler_tuning_config ctg)
{
compiler_map()[name] = std::move(c);
compiler_op_map()[name] = std::move(cop);
compiler_map()[name] = {std::move(c), std::move(cop), std::move(ctg)};
}
bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op)
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution)
{
return compiler_map().at(op.name())(ctx, ins, op);
return compiler_map().at(op.name()).compile(ctx, ins, op, solution);
}
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{
return compiler_op_map().at(name)(ctx, inputs, v);
return compiler_map().at(name).compile_op(ctx, inputs, v);
}
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op)
{
return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op);
}
} // namespace gpu
......
......@@ -30,6 +30,8 @@
#include <migraphx/value.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <functional>
namespace migraphx {
......@@ -66,16 +68,25 @@ struct compiler_replace
}
};
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
struct tuning_config
{
value problem;
std::vector<value> solutions;
};
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
using compiler_tuning_config = std::function<optional<tuning_config>(context&, instruction_ref, const operation&)>;
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop, compiler_tuning_config ctg);
bool has_compiler_for(const std::string& name);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution);
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op);
template <class T>
void register_compiler()
......@@ -85,8 +96,9 @@ void register_compiler()
{
register_compiler(
name,
[=](auto&&... xs) { return c.compile(std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); });
[=](auto&&... xs) { return c.invoke_compile(rank<1>{}, std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.get_tuning_config(std::forward<decltype(xs)>(xs)...); });
}
}
......@@ -105,7 +117,29 @@ using auto_register_compiler = auto_register<register_compiler_action, T>;
template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
const Derived& derived() const
{
return static_cast<const Derived&>(*this);
}
optional<tuning_config> get_tuning_config(context&, instruction_ref, const operation&) const
{
return nullopt;
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
template<class D=Derived>
auto invoke_compile(rank<1>, context& ctx, instruction_ref ins, operation op, const value& solution) const -> decltype(std::declval<D>().compile(ctx, ins, std::move(op), solution))
{
return derived().compile(ctx, ins, std::move(op), solution);
}
template<class D=Derived>
auto invoke_compile(rank<0>, context& ctx, instruction_ref ins, operation op, const value& solution) const -> decltype(std::declval<D>().compile(ctx, ins, std::move(op)))
{
assert(solution.empty());
(void)solution;
return derived().compile(ctx, ins, std::move(op));
}
};
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment