"llama/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "03e40efa51a75a4e8385b64996af6468f42f6c06"
Commit a536b16b authored by Paul's avatar Paul
Browse files

Make operator class extensible

parent b1e9363f
#ifndef RTG_GUARD_BUILTIN_HPP
#define RTG_GUARD_BUILTIN_HPP
namespace rtg {
namespace builtin {
static const char * literal = "@literal";
static const char * param = "@param";
}
} // namespace rtg
#endif
......@@ -3,6 +3,7 @@
#include <rtg/literal.hpp>
#include <rtg/shape.hpp>
#include <rtg/builtin.hpp>
#include <string>
namespace rtg {
......@@ -16,7 +17,7 @@ struct instruction
{}
instruction(literal l)
: name("literal"), result(l.get_shape()), lit(std::move(l))
: name(builtin::literal), result(l.get_shape()), lit(std::move(l))
{}
std::string name;
......
......@@ -3,16 +3,135 @@
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
namespace rtg {
struct operand
/*
* Type-erased interface for:
*
* struct operand
* {
* std::string name() const;
* shape compute_shape(std::vector<shape> input) const;
* argument compute(std::vector<argument> input) const;
* };
*
*/
struct operand
{
std::string name;
std::function<shape(std::vector<shape>)> compute_shape;
std::function<argument(std::vector<argument>)> compute;
// Constructors
operand() = default;
template <typename TypeErased_T_>
operand(TypeErased_T_ value)
: handle_mem_var_(
std::make_shared<handle_type_<typename std::remove_reference<TypeErased_T_>::type>>(
std::forward<TypeErased_T_>(value)))
{
}
// Assignment
template <typename TypeErased_T_>
operand& operator=(TypeErased_T_ value)
{
if(handle_mem_var_.unique())
*handle_mem_var_ = std::forward<TypeErased_T_>(value);
else if(!handle_mem_var_)
handle_mem_var_ = std::make_shared<TypeErased_T_>(std::forward<TypeErased_T_>(value));
return *this;
}
std::string name() const
{
assert(handle_mem_var_);
return get_handle_().name();
}
shape compute_shape(std::vector<shape> input) const
{
assert(handle_mem_var_);
return get_handle_().compute_shape(std::move(input));
}
argument compute(std::vector<argument> input) const
{
assert(handle_mem_var_);
return get_handle_().compute(std::move(input));
}
private:
struct handle_base_type_
{
virtual ~handle_base_type_() {}
virtual std::shared_ptr<handle_base_type_> clone() const = 0;
virtual std::string name() const = 0;
virtual shape compute_shape(std::vector<shape> input) const = 0;
virtual argument compute(std::vector<argument> input) const = 0;
};
template <typename TypeErased_T_>
struct handle_type_ : handle_base_type_
{
template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value,
typename std::enable_if<std::is_reference<TypeErased_U_>::value>::type* = 0)
: value_(value)
{
}
template <typename TypeErased_U_ = TypeErased_T_>
handle_type_(TypeErased_T_ value,
typename std::enable_if<!std::is_reference<TypeErased_U_>::value, int>::type* =
0) noexcept : value_(std::move(value))
{
}
virtual std::shared_ptr<handle_base_type_> clone() const
{
return std::make_shared<handle_type_>(value_);
}
virtual std::string name() const { return value_.name(); }
virtual shape compute_shape(std::vector<shape> input) const
{
return value_.compute_shape(std::move(input));
}
virtual argument compute(std::vector<argument> input) const
{
return value_.compute(std::move(input));
}
TypeErased_T_ value_;
};
template <typename TypeErased_T_>
struct handle_type_<std::reference_wrapper<TypeErased_T_>> : handle_type_<TypeErased_T_&>
{
handle_type_(std::reference_wrapper<TypeErased_T_> ref)
: handle_type_<TypeErased_T_&>(ref.get())
{
}
};
const handle_base_type_& get_handle_() const { return *handle_mem_var_; }
handle_base_type_& get_handle_()
{
if(!handle_mem_var_.unique())
handle_mem_var_ = handle_mem_var_->clone();
return *handle_mem_var_;
}
std::shared_ptr<handle_base_type_> handle_mem_var_;
};
}
......
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
namespace rtg {
} // namespace rtg
#endif
......@@ -5,6 +5,7 @@
#include <unordered_map>
#include <rtg/instruction.hpp>
#include <rtg/operand.hpp>
#include <rtg/builtin.hpp>
namespace rtg {
......@@ -27,18 +28,13 @@ struct program
instruction * add_parameter(std::string name, shape s)
{
instructions.push_back({"param:"+std::move(name), s, {}});
instructions.push_back({builtin::param+std::move(name), s, {}});
return std::addressof(instructions.back());
}
template<class Op, class Shape>
void add_operator(std::string name, Op op, Shape s)
void add_operator(operand op)
{
operand result;
result.name = name;
result.compute = op;
result.compute_shape = s;
ops.emplace(name, result);
ops.emplace(op.name(), op);
}
literal eval(std::unordered_map<std::string, argument> params) const;
......@@ -48,7 +44,6 @@ private:
std::list<instruction> instructions;
std::unordered_map<std::string, operand> ops;
};
}
......
......@@ -10,11 +10,11 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
argument result;
for(auto& ins:instructions)
{
if(ins.name == "literal")
if(ins.name == builtin::literal)
{
result = ins.lit.get_argument();
}
else if(starts_with(ins.name, "param:"))
else if(starts_with(ins.name, builtin::param))
{
result = params.at(ins.name.substr(6));
}
......
......@@ -4,28 +4,39 @@
#include <rtg/shape.hpp>
#include "test.hpp"
void literal_test() {
rtg::program p;
p.add_operator("sum",
[](std::vector<rtg::argument> args) {
rtg::argument result;
if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument();
});
struct sum_op
{
std::string name() const
{
return "sum";
}
rtg::argument compute(std::vector<rtg::argument> args) const
{
rtg::argument result;
if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument();
});
return result;
},
[](std::vector<rtg::shape> inputs) {
if(inputs.size() != 2) throw "Wrong inputs";
return inputs.front();
}
);
});
return result;
}
rtg::shape compute_shape(std::vector<rtg::shape> inputs) const
{
if(inputs.size() != 2) throw "Wrong inputs";
return inputs.front();
}
};
void literal_test() {
rtg::program p;
p.add_operator(sum_op{});
auto one = p.add_literal(1);
auto two = p.add_literal(2);
......@@ -37,26 +48,7 @@ void literal_test() {
void param_test() {
rtg::program p;
p.add_operator("sum",
[](std::vector<rtg::argument> args) {
rtg::argument result;
if(args.size() != 2) throw "Wrong args";
if(args[0].get_shape() != args[1].get_shape()) throw "Wrong args";
if(args[0].get_shape().lens().size() != 1) throw "Wrong args";
if(args[0].get_shape().lens().front() != 1) throw "Wrong args";
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) {
result = rtg::literal{x + y}.get_argument();
});
});
return result;
},
[](std::vector<rtg::shape> inputs) {
if(inputs.size() != 2) throw "Wrong inputs";
return inputs.front();
}
);
p.add_operator(sum_op{});
auto x = p.add_parameter("x", {rtg::shape::int_type});
auto y = p.add_parameter("y", {rtg::shape::int_type});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment