Commit 17ee5428 authored by Paul's avatar Paul
Browse files

Add interface to enable tuning

parent 2956bb3f
...@@ -94,7 +94,7 @@ void compile_ops::apply(module& m) const ...@@ -94,7 +94,7 @@ void compile_ops::apply(module& m) const
continue; continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op; operation preop = any_cast<precompile_op>(ins->get_operator()).op;
compiles.emplace_back([=]() -> compiled_result { compiles.emplace_back([=]() -> compiled_result {
return {compile(*ctx, ins, preop), ins}; return {compile(*ctx, ins, preop, value{}), ins};
}); });
} }
std::vector<compiled_result> results(compiles.size()); std::vector<compiled_result> results(compiles.size());
......
...@@ -28,33 +28,40 @@ namespace migraphx { ...@@ -28,33 +28,40 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
auto& compiler_map() namespace {
{ struct compiler_handle
static std::unordered_map<std::string, compiler_compile> m; // NOLINT {
return m; compiler_compile compile;
compiler_compile_op compile_op;
compiler_tuning_config get_tuning_config;
};
} }
auto& compiler_op_map() auto& compiler_map()
{ {
static std::unordered_map<std::string, compiler_compile_op> m; // NOLINT static std::unordered_map<std::string, compiler_handle> m; // NOLINT
return m; return m;
} }
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop) void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop, compiler_tuning_config ctg)
{ {
compiler_map()[name] = std::move(c); compiler_map()[name] = {std::move(c), std::move(cop), std::move(ctg)};
compiler_op_map()[name] = std::move(cop);
} }
bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; } bool has_compiler_for(const std::string& name) { return compiler_map().count(name) > 0; }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution)
{ {
return compiler_map().at(op.name())(ctx, ins, op); return compiler_map().at(op.name()).compile(ctx, ins, op, solution);
} }
operation operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v) compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v)
{ {
return compiler_op_map().at(name)(ctx, inputs, v); return compiler_map().at(name).compile_op(ctx, inputs, v);
}
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op)
{
return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op);
} }
} // namespace gpu } // namespace gpu
......
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <functional> #include <functional>
namespace migraphx { namespace migraphx {
...@@ -66,16 +68,25 @@ struct compiler_replace ...@@ -66,16 +68,25 @@ struct compiler_replace
} }
}; };
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>; struct tuning_config
{
value problem;
std::vector<value> solutions;
};
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation, const value&)>;
using compiler_compile_op = using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>; std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
using compiler_tuning_config = std::function<optional<tuning_config>(context&, instruction_ref, const operation&)>;
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop); void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop, compiler_tuning_config ctg);
bool has_compiler_for(const std::string& name); bool has_compiler_for(const std::string& name);
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op); compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution);
operation operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v); compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op);
template <class T> template <class T>
void register_compiler() void register_compiler()
...@@ -85,8 +96,9 @@ void register_compiler() ...@@ -85,8 +96,9 @@ void register_compiler()
{ {
register_compiler( register_compiler(
name, name,
[=](auto&&... xs) { return c.compile(std::forward<decltype(xs)>(xs)...); }, [=](auto&&... xs) { return c.invoke_compile(rank<1>{}, std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); }); [=](auto&&... xs) { return c.compile_op(std::forward<decltype(xs)>(xs)...); },
[=](auto&&... xs) { return c.get_tuning_config(std::forward<decltype(xs)>(xs)...); });
} }
} }
...@@ -105,7 +117,29 @@ using auto_register_compiler = auto_register<register_compiler_action, T>; ...@@ -105,7 +117,29 @@ using auto_register_compiler = auto_register<register_compiler_action, T>;
template <class Derived> template <class Derived>
struct compiler : auto_register_compiler<Derived> struct compiler : auto_register_compiler<Derived>
{ {
const Derived& derived() const
{
return static_cast<const Derived&>(*this);
}
optional<tuning_config> get_tuning_config(context&, instruction_ref, const operation&) const
{
return nullopt;
}
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; } operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
template<class D=Derived>
auto invoke_compile(rank<1>, context& ctx, instruction_ref ins, operation op, const value& solution) const -> decltype(std::declval<D>().compile(ctx, ins, std::move(op), solution))
{
return derived().compile(ctx, ins, std::move(op), solution);
}
template<class D=Derived>
auto invoke_compile(rank<0>, context& ctx, instruction_ref ins, operation op, const value& solution) const -> decltype(std::declval<D>().compile(ctx, ins, std::move(op)))
{
assert(solution.empty());
(void)solution;
return derived().compile(ctx, ins, std::move(op));
}
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment