Commit f3b6bc41 authored by Paul's avatar Paul
Browse files

Use iterators for instruction ref

parent ab87d119
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <rtg/literal.hpp> #include <rtg/literal.hpp>
#include <rtg/shape.hpp> #include <rtg/shape.hpp>
#include <rtg/builtin.hpp> #include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <string> #include <string>
namespace rtg { namespace rtg {
...@@ -12,7 +13,7 @@ struct instruction ...@@ -12,7 +13,7 @@ struct instruction
{ {
instruction() {} instruction() {}
instruction(operation o, shape r, std::vector<instruction*> args) instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{ {
} }
...@@ -21,7 +22,7 @@ struct instruction ...@@ -21,7 +22,7 @@ struct instruction
operation op; operation op;
shape result; shape result;
std::vector<instruction*> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
}; };
......
#ifndef RTG_GUARD_INSTRUCTION_REF_HPP
#define RTG_GUARD_INSTRUCTION_REF_HPP
#include <list>
namespace rtg {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
} // namespace rtg
#endif
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
#include <rtg/operation.hpp> #include <rtg/operation.hpp>
#include <rtg/literal.hpp> #include <rtg/literal.hpp>
#include <rtg/builtin.hpp> #include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <algorithm> #include <algorithm>
namespace rtg { namespace rtg {
struct instruction;
struct program_impl; struct program_impl;
struct program struct program
...@@ -21,27 +21,27 @@ struct program ...@@ -21,27 +21,27 @@ struct program
~program() noexcept; ~program() noexcept;
template <class... Ts> template <class... Ts>
instruction* add_instruction(operation op, Ts*... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
return add_instruction(op, {args...}); return add_instruction(op, {args...});
} }
instruction* add_instruction(operation op, std::vector<instruction*> args); instruction_ref add_instruction(operation op, std::vector<instruction_ref> args);
template <class... Ts> template <class... Ts>
instruction* add_literal(Ts&&... xs) instruction_ref add_literal(Ts&&... xs)
{ {
return add_literal(literal{std::forward<Ts>(xs)...}); return add_literal(literal{std::forward<Ts>(xs)...});
} }
instruction* add_literal(literal l); instruction_ref add_literal(literal l);
instruction* add_parameter(std::string name, shape s); instruction_ref add_parameter(std::string name, shape s);
literal eval(std::unordered_map<std::string, argument> params) const; literal eval(std::unordered_map<std::string, argument> params) const;
// TODO: Change to stream operator // TODO: Change to stream operator
void print() const; void print() const;
bool has_instruction(const instruction* ins) const; bool has_instruction(instruction_ref ins) const;
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
......
...@@ -41,16 +41,16 @@ struct onnx_parser ...@@ -41,16 +41,16 @@ struct onnx_parser
{ {
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>; using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>; using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>; using op_func = std::function<rtg::instruction_ref(attribute_map, std::vector<rtg::instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, rtg::instruction*> instructions; std::unordered_map<std::string, rtg::instruction_ref> instructions;
rtg::program prog = rtg::program(); rtg::program prog = rtg::program();
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
onnx_parser() onnx_parser()
{ {
add_op("Conv", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("Conv", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::convolution op; rtg::convolution op;
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
...@@ -66,7 +66,7 @@ struct onnx_parser ...@@ -66,7 +66,7 @@ struct onnx_parser
} }
return prog.add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::pooling op{"max"}; rtg::pooling op{"max"};
// for(auto&& p:attributes) std::cout << p.first << std::endl; // for(auto&& p:attributes) std::cout << p.first << std::endl;
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
...@@ -83,16 +83,16 @@ struct onnx_parser ...@@ -83,16 +83,16 @@ struct onnx_parser
} }
return prog.add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) { add_op("Relu", [this](attribute_map, std::vector<rtg::instruction_ref> args) {
return prog.add_instruction(rtg::activation{"relu"}, args); return prog.add_instruction(rtg::activation{"relu"}, args);
}); });
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::reshape op; rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape")); rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) { add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction_ref>) {
rtg::literal v = parse_value(attributes.at("value")); rtg::literal v = parse_value(attributes.at("value"));
return prog.add_literal(v); return prog.add_literal(v);
}); });
...@@ -141,7 +141,7 @@ struct onnx_parser ...@@ -141,7 +141,7 @@ struct onnx_parser
if(instructions.count(name) == 0) if(instructions.count(name) == 0)
{ {
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
std::vector<rtg::instruction*> args; std::vector<rtg::instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
......
...@@ -18,37 +18,37 @@ program::program(program&&) noexcept = default; ...@@ -18,37 +18,37 @@ program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default; program& program::operator=(program&&) noexcept = default;
program::~program() noexcept = default; program::~program() noexcept = default;
instruction* program::add_instruction(operation op, std::vector<instruction*> args) instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args)
{ {
assert( assert(
std::all_of(args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) && std::all_of(args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction"); "Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
std::transform( std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; }); args.begin(), args.end(), shapes.begin(), [](instruction_ref ins) { return ins->result; });
shape r = op.compute_shape(shapes); shape r = op.compute_shape(shapes);
impl->instructions.push_back({op, r, args}); impl->instructions.push_back({op, r, args});
assert(impl->instructions.back().arguments == args); assert(impl->instructions.back().arguments == args);
return std::addressof(impl->instructions.back()); return std::prev(impl->instructions.end());
} }
instruction* program::add_literal(literal l) instruction_ref program::add_literal(literal l)
{ {
impl->instructions.emplace_back(std::move(l)); impl->instructions.emplace_back(std::move(l));
return std::addressof(impl->instructions.back()); return std::prev(impl->instructions.end());
} }
instruction* program::add_parameter(std::string name, shape s) instruction_ref program::add_parameter(std::string name, shape s)
{ {
impl->instructions.push_back({builtin::param{std::move(name)}, s, {}}); impl->instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(impl->instructions.back()); return std::prev(impl->instructions.end());
} }
bool program::has_instruction(const instruction* ins) const bool program::has_instruction(instruction_ref ins) const
{ {
return std::find_if(impl->instructions.begin(), return std::find_if(impl->instructions.begin(),
impl->instructions.end(), impl->instructions.end(),
[&](const instruction& x) { return ins == std::addressof(x); }) != [&](const instruction& x) { return std::addressof(*ins) == std::addressof(x); }) !=
impl->instructions.end(); impl->instructions.end();
} }
...@@ -72,7 +72,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const ...@@ -72,7 +72,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
std::transform(ins.arguments.begin(), std::transform(ins.arguments.begin(),
ins.arguments.end(), ins.arguments.end(),
values.begin(), values.begin(),
[&](instruction* i) { return results.at(i); }); [&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(values); result = ins.op.compute(values);
} }
results.emplace(std::addressof(ins), result); results.emplace(std::addressof(ins), result);
...@@ -111,7 +111,7 @@ void program::print() const ...@@ -111,7 +111,7 @@ void program::print() const
for(auto&& arg : ins.arguments) for(auto&& arg : ins.arguments)
{ {
assert(this->has_instruction(arg) && "Instruction not found"); assert(this->has_instruction(arg) && "Instruction not found");
std::cout << delim << names.at(arg); std::cout << delim << names.at(std::addressof(*arg));
delim = ','; delim = ',';
} }
std::cout << ")"; std::cout << ")";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment