Commit 8cd2c875 authored by Paul's avatar Paul
Browse files

Make program use impl idion

parent 894a0fca
...@@ -3,68 +3,48 @@ ...@@ -3,68 +3,48 @@
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include <rtg/instruction.hpp>
#include <rtg/operation.hpp> #include <rtg/operation.hpp>
#include <rtg/literal.hpp>
#include <rtg/builtin.hpp> #include <rtg/builtin.hpp>
#include <algorithm> #include <algorithm>
namespace rtg { namespace rtg {
struct instruction;
struct program_impl;
struct program struct program
{ {
// TODO: A program should be copyable program();;
program() = default; program(program&&) = default;
program(const program&) = delete; program& operator=(program&&) = default;
program& operator=(const program&) = delete; ~program();
template <class... Ts> template <class... Ts>
instruction* add_instruction(operation op, Ts*... args) instruction* add_instruction(operation op, Ts*... args)
{ {
shape r = op.compute_shape({args->result...}); return add_instruction(op, {args...});
instructions.push_back({op, r, {args...}});
return std::addressof(instructions.back());
}
instruction* add_instruction(operation op, std::vector<instruction*> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
instructions.push_back({op, r, args});
assert(instructions.back().arguments == args);
return std::addressof(instructions.back());
} }
instruction* add_instruction(operation op, std::vector<instruction*> args);
template <class... Ts> template <class... Ts>
instruction* add_literal(Ts&&... xs) instruction* add_literal(Ts&&... xs)
{ {
instructions.emplace_back(literal{std::forward<Ts>(xs)...}); return add_literal(literal{std::forward<Ts>(xs)...});
return std::addressof(instructions.back());
} }
instruction* add_parameter(std::string name, shape s) instruction* add_literal(literal l);
{
instructions.push_back({builtin::param{std::move(name)}, s, {}}); instruction* add_parameter(std::string name, shape s);
return std::addressof(instructions.back());
}
literal eval(std::unordered_map<std::string, argument> params) const; literal eval(std::unordered_map<std::string, argument> params) const;
// TODO: Change to stream operator // TODO: Change to stream operator
void print() const; void print() const;
bool has_instruction(const instruction* ins) const bool has_instruction(const instruction* ins) const;
{
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {
return ins == std::addressof(x);
}) != instructions.end();
}
private: private:
// A list is used to keep references to an instruction stable std::unique_ptr<program_impl> impl;
std::list<instruction> instructions;
}; };
} // namespace rtg } // namespace rtg
......
...@@ -43,7 +43,7 @@ struct onnx_parser ...@@ -43,7 +43,7 @@ struct onnx_parser
using node_map = std::unordered_map<std::string, onnx::NodeProto>; using node_map = std::unordered_map<std::string, onnx::NodeProto>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, rtg::instruction*> instructions; std::unordered_map<std::string, rtg::instruction*> instructions;
std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>(); rtg::program prog = std::make_shared<rtg::program>();
std::unordered_map< std::unordered_map<
std::string, std::string,
...@@ -66,7 +66,7 @@ struct onnx_parser ...@@ -66,7 +66,7 @@ struct onnx_parser
{ {
copy(attributes["dilations"].ints(), op.dilation.begin()); copy(attributes["dilations"].ints(), op.dilation.begin());
} }
return prog->add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::pooling op{"max"}; rtg::pooling op{"max"};
...@@ -83,20 +83,20 @@ struct onnx_parser ...@@ -83,20 +83,20 @@ struct onnx_parser
{ {
copy(attributes["kernel_shape"].ints(), op.lengths.begin()); copy(attributes["kernel_shape"].ints(), op.lengths.begin());
} }
return prog->add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) { add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) {
return prog->add_instruction(rtg::activation{"relu"}, args); return prog.add_instruction(rtg::activation{"relu"}, args);
}); });
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::reshape op; rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape")); rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog->add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) { add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
rtg::literal v = parse_value(attributes.at("value")); rtg::literal v = parse_value(attributes.at("value"));
return prog->add_literal(v); return prog.add_literal(v);
}); });
} }
...@@ -130,7 +130,7 @@ struct onnx_parser ...@@ -130,7 +130,7 @@ struct onnx_parser
const std::string& name = input.name(); const std::string& name = input.name();
// TODO: Get shape of input parameter // TODO: Get shape of input parameter
rtg::shape s = parse_type(input.type()); rtg::shape s = parse_type(input.type());
instructions[name] = prog->add_parameter(name, s); instructions[name] = prog.add_parameter(name, s);
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
...@@ -159,7 +159,7 @@ struct onnx_parser ...@@ -159,7 +159,7 @@ struct onnx_parser
} }
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
{ {
instructions[name] = prog->add_instruction(unknown{node.op_type()}, args); instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
} }
else else
{ {
...@@ -306,9 +306,9 @@ int main(int argc, char const* argv[]) ...@@ -306,9 +306,9 @@ int main(int argc, char const* argv[])
catch(...) catch(...)
{ {
if(parser.prog) if(parser.prog)
parser.prog->print(); parser.prog.print();
throw; throw;
} }
parser.prog->print(); parser.prog.print();
} }
} }
#include <rtg/program.hpp> #include <rtg/program.hpp>
#include <rtg/stringutils.hpp> #include <rtg/stringutils.hpp>
#include <rtg/instruction.hpp>
#include <iostream> #include <iostream>
#include <algorithm> #include <algorithm>
namespace rtg { namespace rtg {
struct program_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
};
program::program()
: impl(std::make_unique<program_impl>())
{}
program::~program()
{}
instruction* program::add_instruction(operation op, std::vector<instruction*> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
impl->instructions.push_back({op, r, args});
assert(impl->instructions.back().arguments == args);
return std::addressof(impl->instructions.back());
}
instruction* program::add_literal(literal l)
{
impl->instructions.emplace_back(std::move(l));
return std::addressof(impl->instructions.back());
}
instruction* program::add_parameter(std::string name, shape s)
{
impl->instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(impl->instructions.back());
}
bool program::has_instruction(const instruction* ins) const
{
return std::find_if(impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return ins == std::addressof(x);
}) != impl->instructions.end();
}
literal program::eval(std::unordered_map<std::string, argument> params) const literal program::eval(std::unordered_map<std::string, argument> params) const
{ {
std::unordered_map<const instruction*, argument> results; std::unordered_map<const instruction*, argument> results;
argument result; argument result;
for(auto& ins : instructions) for(auto& ins : impl->instructions)
{ {
if(ins.op.name() == "@literal") if(ins.op.name() == "@literal")
{ {
...@@ -38,7 +85,7 @@ void program::print() const ...@@ -38,7 +85,7 @@ void program::print() const
std::unordered_map<const instruction*, std::string> names; std::unordered_map<const instruction*, std::string> names;
int count = 0; int count = 0;
for(auto& ins : instructions) for(auto& ins : impl->instructions)
{ {
std::string var_name = "@" + std::to_string(count); std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param")) if(starts_with(ins.op.name(), "@param"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment