Commit f3b6bc41 authored by Paul's avatar Paul
Browse files

Use iterators for instruction ref

parent ab87d119
......@@ -4,6 +4,7 @@
#include <rtg/literal.hpp>
#include <rtg/shape.hpp>
#include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <string>
namespace rtg {
......@@ -12,7 +13,7 @@ struct instruction
{
instruction() {}
instruction(operation o, shape r, std::vector<instruction*> args)
instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
{
}
......@@ -21,7 +22,7 @@ struct instruction
operation op;
shape result;
std::vector<instruction*> arguments;
std::vector<instruction_ref> arguments;
literal lit;
};
......
#ifndef RTG_GUARD_INSTRUCTION_REF_HPP
#define RTG_GUARD_INSTRUCTION_REF_HPP
#include <list>
namespace rtg {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
} // namespace rtg
#endif
......@@ -6,11 +6,11 @@
#include <rtg/operation.hpp>
#include <rtg/literal.hpp>
#include <rtg/builtin.hpp>
#include <rtg/instruction_ref.hpp>
#include <algorithm>
namespace rtg {
struct instruction;
struct program_impl;
struct program
......@@ -21,27 +21,27 @@ struct program
~program() noexcept;
template <class... Ts>
instruction* add_instruction(operation op, Ts*... args)
instruction_ref add_instruction(operation op, Ts... args)
{
return add_instruction(op, {args...});
}
instruction* add_instruction(operation op, std::vector<instruction*> args);
instruction_ref add_instruction(operation op, std::vector<instruction_ref> args);
template <class... Ts>
instruction* add_literal(Ts&&... xs)
instruction_ref add_literal(Ts&&... xs)
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
instruction* add_literal(literal l);
instruction_ref add_literal(literal l);
instruction* add_parameter(std::string name, shape s);
instruction_ref add_parameter(std::string name, shape s);
literal eval(std::unordered_map<std::string, argument> params) const;
// TODO: Change to stream operator
void print() const;
bool has_instruction(const instruction* ins) const;
bool has_instruction(instruction_ref ins) const;
private:
std::unique_ptr<program_impl> impl;
......
......@@ -41,16 +41,16 @@ struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>;
using op_func = std::function<rtg::instruction_ref(attribute_map, std::vector<rtg::instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, rtg::instruction*> instructions;
std::unordered_map<std::string, rtg::instruction_ref> instructions;
rtg::program prog = rtg::program();
std::unordered_map<std::string, op_func> ops;
onnx_parser()
{
add_op("Conv", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
add_op("Conv", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::convolution op;
if(contains(attributes, "pads"))
{
......@@ -66,7 +66,7 @@ struct onnx_parser
}
return prog.add_instruction(op, args);
});
add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::pooling op{"max"};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if(contains(attributes, "pads"))
......@@ -83,16 +83,16 @@ struct onnx_parser
}
return prog.add_instruction(op, args);
});
add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) {
add_op("Relu", [this](attribute_map, std::vector<rtg::instruction_ref> args) {
return prog.add_instruction(rtg::activation{"relu"}, args);
});
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction_ref> args) {
rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args);
});
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction_ref>) {
rtg::literal v = parse_value(attributes.at("value"));
return prog.add_literal(v);
});
......@@ -141,7 +141,7 @@ struct onnx_parser
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<rtg::instruction*> args;
std::vector<rtg::instruction_ref> args;
for(auto&& input : node.input())
{
if(nodes.count(input) > 0)
......
......@@ -18,37 +18,37 @@ program::program(program&&) noexcept = default;
program& program::operator=(program&&) noexcept = default;
program::~program() noexcept = default;
instruction* program::add_instruction(operation op, std::vector<instruction*> args)
instruction_ref program::add_instruction(operation op, std::vector<instruction_ref> args)
{
assert(
std::all_of(args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
std::all_of(args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
args.begin(), args.end(), shapes.begin(), [](instruction_ref ins) { return ins->result; });
shape r = op.compute_shape(shapes);
impl->instructions.push_back({op, r, args});
assert(impl->instructions.back().arguments == args);
return std::addressof(impl->instructions.back());
return std::prev(impl->instructions.end());
}
instruction* program::add_literal(literal l)
instruction_ref program::add_literal(literal l)
{
impl->instructions.emplace_back(std::move(l));
return std::addressof(impl->instructions.back());
return std::prev(impl->instructions.end());
}
instruction* program::add_parameter(std::string name, shape s)
instruction_ref program::add_parameter(std::string name, shape s)
{
impl->instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(impl->instructions.back());
return std::prev(impl->instructions.end());
}
bool program::has_instruction(const instruction* ins) const
bool program::has_instruction(instruction_ref ins) const
{
return std::find_if(impl->instructions.begin(),
impl->instructions.end(),
[&](const instruction& x) { return ins == std::addressof(x); }) !=
[&](const instruction& x) { return std::addressof(*ins) == std::addressof(x); }) !=
impl->instructions.end();
}
......@@ -72,7 +72,7 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
std::transform(ins.arguments.begin(),
ins.arguments.end(),
values.begin(),
[&](instruction* i) { return results.at(i); });
[&](instruction_ref i) { return results.at(std::addressof(*i)); });
result = ins.op.compute(values);
}
results.emplace(std::addressof(ins), result);
......@@ -111,7 +111,7 @@ void program::print() const
for(auto&& arg : ins.arguments)
{
assert(this->has_instruction(arg) && "Instruction not found");
std::cout << delim << names.at(arg);
std::cout << delim << names.at(std::addressof(*arg));
delim = ',';
}
std::cout << ")";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment