Commit a536b16b authored by Paul's avatar Paul
Browse files

Make operator class extensible

parent b1e9363f
#ifndef RTG_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP
namespace rtg {
namespace builtin {
static const char * literal = "@literal";
static const char * param = "@param";
}
} // namespace rtg
#endif
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <rtg/literal.hpp> #include <rtg/literal.hpp>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/builtin.hpp>
#include <string> #include <string>
namespace rtg { namespace rtg {
...@@ -16,7 +17,7 @@ struct instruction ...@@ -16,7 +17,7 @@ struct instruction
{} {}
instruction(literal l) instruction(literal l)
: name("literal"), result(l.get_shape()), lit(std::move(l)) : name(builtin::literal), result(l.get_shape()), lit(std::move(l))
{} {}
std::string name; std::string name;
......
...@@ -3,16 +3,135 @@ ...@@ -3,16 +3,135 @@
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/argument.hpp> #include <rtg/argument.hpp>
namespace rtg { namespace rtg {
/*
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
struct operand struct operand
{ {
std::string name; // Constructors
std::function<shape(std::vector<shape>)> compute_shape; operand() = default;
std::function<argument(std::vector<argument>)> compute;
template <typename TypeErased_T_>
operand(TypeErased_T_ value)
: handle_mem_var_(
std::make_shared<handle_type_<typename std::remove_reference<TypeErased_T_>::type>>(
std::forward<TypeErased_T_>(value)))
{
}
// Assignment
template <typename TypeErased_T_>
operand& operator=(TypeErased_T_ value)
{
if(handle_mem_var_.unique())
*handle_mem_var_ = std::forward<TypeErased_T_>(value);
else if(!handle_mem_var_)
handle_mem_var_ = std::make_shared<TypeErased_T_>(std::forward<TypeErased_T_>(value));
return *this;
}
std::string name() const
{
assert(handle_mem_var_);
return get_handle_().name();
}
shape compute_shape(std::vector<shape> input) const
{
assert(handle_mem_var_);
return get_handle_().compute_shape(std::move(input));
}
argument compute(std::vector<argument> input) const
{
assert(handle_mem_var_);
return get_handle_().compute(std::move(input));
}
private:
struct handle_base_type_
{
virtual ~handle_base_type_() {}
virtual std::shared_ptr<handle_base_type_> clone() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0;
};
template <typename TypeErased_T_>
struct handle_type_ : handle_base_type_
{
template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value,
typename std::enable_if<std::is_reference<TypeErased_U_>::value>::type* = 0)
: value_(value)
{
}
template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value,
typename std::enable_if<!std::is_reference<TypeErased_U_>::value, int>::type* =
0) noexcept : value_(std::move(value))
{
}
virtual std::shared_ptr<handle_base_type_> clone() const
{
return std::make_shared<handle_type_>(value_);
}
virtual std::string name() const { return value_.name(); }
virtual shape compute_shape(std::vector<shape> input) const
{
return value_.compute_shape(std::move(input));
}
virtual argument compute(std::vector<argument> input) const
{
return value_.compute(std::move(input));
}
TypeErased_T_ value_;
};
template <typename TypeErased_T_>
struct handle_type_<std::reference_wrapper<TypeErased_T_>> : handle_type_<TypeErased_T_&>
{
handle_type_(std::reference_wrapper<TypeErased_T_> ref)
: handle_type_<TypeErased_T_&>(ref.get())
{
}
};
const handle_base_type_& get_handle_() const { return *handle_mem_var_; }
handle_base_type_& get_handle_()
{
if(!handle_mem_var_.unique())
handle_mem_var_ = handle_mem_var_->clone();
return *handle_mem_var_;
}
std::shared_ptr<handle_base_type_> handle_mem_var_;
}; };
} }
......
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
namespace rtg {
} // namespace rtg
#endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <unordered_map> #include <unordered_map>
#include <rtg/instruction.hpp> #include <rtg/instruction.hpp>
#include <rtg/operand.hpp> #include <rtg/operand.hpp>
#include <rtg/builtin.hpp>
namespace rtg { namespace rtg {
...@@ -27,18 +28,13 @@ struct program ...@@ -27,18 +28,13 @@ struct program
instruction * add_parameter(std::string name, shape s) instruction * add_parameter(std::string name, shape s)
{ {
instructions.push_back({"param:"+std::move(name), s, {}}); instructions.push_back({builtin::param+std::move(name), s, {}});
return std::addressof(instructions.back()); return std::addressof(instructions.back());
} }
template<class Op, class Shape> void add_operator(operand op)
void add_operator(std::string name, Op op, Shape s)
{ {
operand result; ops.emplace(op.name(), op);
result.name = name;
result.compute = op;
result.compute_shape = s;
ops.emplace(name, result);
} }
literal eval(std::unordered_map<std::string, argument> params) const; literal eval(std::unordered_map<std::string, argument> params) const;
...@@ -48,7 +44,6 @@ private: ...@@ -48,7 +44,6 @@ private:
std::list<instruction> instructions; std::list<instruction> instructions;
std::unordered_map<std::string, operand> ops; std::unordered_map<std::string, operand> ops;
}; };
} }
......
...@@ -10,11 +10,11 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -10,11 +10,11 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
argument result; argument result;
for(auto& ins:instructions) for(auto& ins:instructions)
{ {
if(ins.name == "literal") if(ins.name == builtin::literal)
{ {
result = ins.lit.get_argument(); result = ins.lit.get_argument();
} }
else if(starts_with(ins.name, "param:")) else if(starts_with(ins.name, builtin::param))
{ {
result = params.at(ins.name.substr(6)); result = params.at(ins.name.substr(6));
} }
......
...@@ -4,10 +4,15 @@ ...@@ -4,10 +4,15 @@
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include "test.hpp" #include "test.hpp"
void literal_test() {
rtg::program p; struct sum_op
p.add_operator("sum", {
[](std::vector<rtg::argument> args) { std::string name() const
{
return "sum";
}
rtg::argument compute(std::vector<rtg::argument> args) const
{
rtg::argument result; rtg::argument result;
if(args.size() != 2) throw "Wrong args"; if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args"; if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
...@@ -20,12 +25,18 @@ void literal_test() { ...@@ -20,12 +25,18 @@ void literal_test() {
}); });
}); });
return result; return result;
}, }
[](std::vector<rtg::shape> inputs) {
rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
{
if(inputs.size() != 2) throw "Wrong inputs"; if(inputs.size() != 2) throw "Wrong inputs";
return inputs.front(); return inputs.front();
} }
); };
void literal_test() {
rtg::program p;
p.add_operator(sum_op{});
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
...@@ -37,26 +48,7 @@ void literal_test() { ...@@ -37,26 +48,7 @@ void literal_test() {
void param_test() { void param_test() {
rtg::program p; rtg::program p;
p.add_operator("sum", p.add_operator(sum_op{});
[](std::vector<rtg::argument> args) {
rtg::argument result;
if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument();
});
});
return result;
},
[](std::vector<rtg::shape> inputs) {
if(inputs.size() != 2) throw "Wrong inputs";
return inputs.front();
}
);
auto x = p.add_parameter("x", {rtg::shape::int_type}); auto x = p.add_parameter("x", {rtg::shape::int_type});
auto y = p.add_parameter("y", {rtg::shape::int_type}); auto y = p.add_parameter("y", {rtg::shape::int_type});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment