Commit 8cd2c875 authored by Paul's avatar Paul
Browse files

Make program use impl idion

parent 894a0fca
......@@ -3,68 +3,48 @@
#include <list>
#include <unordered_map>
#include <rtg/instruction.hpp>
#include <rtg/operation.hpp>
#include <rtg/literal.hpp>
#include <rtg/builtin.hpp>
#include <algorithm>
namespace rtg {
struct instruction;
struct program_impl;
struct program
{
// TODO: A program should be copyable
program() = default;
program(const program&) = delete;
program& operator=(const program&) = delete;
program();;
program(program&&) = default;
program& operator=(program&&) = default;
~program();
template <class... Ts>
instruction* add_instruction(operation op, Ts*... args)
{
shape r = op.compute_shape({args->result...});
instructions.push_back({op, r, {args...}});
return std::addressof(instructions.back());
}
instruction* add_instruction(operation op, std::vector<instruction*> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
instructions.push_back({op, r, args});
assert(instructions.back().arguments == args);
return std::addressof(instructions.back());
return add_instruction(op, {args...});
}
instruction* add_instruction(operation op, std::vector<instruction*> args);
template <class... Ts>
instruction* add_literal(Ts&&... xs)
{
instructions.emplace_back(literal{std::forward<Ts>(xs)...});
return std::addressof(instructions.back());
return add_literal(literal{std::forward<Ts>(xs)...});
}
instruction* add_parameter(std::string name, shape s)
{
instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(instructions.back());
}
instruction* add_literal(literal l);
instruction* add_parameter(std::string name, shape s);
literal eval(std::unordered_map<std::string, argument> params) const;
// TODO: Change to stream operator
void print() const;
bool has_instruction(const instruction* ins) const
{
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {
return ins == std::addressof(x);
}) != instructions.end();
}
bool has_instruction(const instruction* ins) const;
private:
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
std::unique_ptr<program_impl> impl;
};
} // namespace rtg
......
......@@ -43,7 +43,7 @@ struct onnx_parser
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
node_map nodes;
std::unordered_map<std::string, rtg::instruction*> instructions;
std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();
rtg::program prog = std::make_shared<rtg::program>();
std::unordered_map<
std::string,
......@@ -66,7 +66,7 @@ struct onnx_parser
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
return prog->add_instruction(op, args);
return prog.add_instruction(op, args);
});
add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::pooling op{"max"};
......@@ -83,20 +83,20 @@ struct onnx_parser
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog->add_instruction(op, args);
return prog.add_instruction(op, args);
});
add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) {
return prog->add_instruction(rtg::activation{"relu"}, args);
return prog.add_instruction(rtg::activation{"relu"}, args);
});
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
rtg::reshape op;
rtg::literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog->add_instruction(op, args);
return prog.add_instruction(op, args);
});
add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
rtg::literal v = parse_value(attributes.at("value"));
return prog->add_literal(v);
return prog.add_literal(v);
});
}
......@@ -130,7 +130,7 @@ struct onnx_parser
const std::string& name = input.name();
// TODO: Get shape of input parameter
rtg::shape s = parse_type(input.type());
instructions[name] = prog->add_parameter(name, s);
instructions[name] = prog.add_parameter(name, s);
}
for(auto&& p : nodes)
{
......@@ -159,7 +159,7 @@ struct onnx_parser
}
if(ops.count(node.op_type()) == 0)
{
instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
}
else
{
......@@ -306,9 +306,9 @@ int main(int argc, char const* argv[])
catch(...)
{
if(parser.prog)
parser.prog->print();
parser.prog.print();
throw;
}
parser.prog->print();
parser.prog.print();
}
}
#include <rtg/program.hpp>
#include <rtg/stringutils.hpp>
#include <rtg/instruction.hpp>
#include <iostream>
#include <algorithm>
namespace rtg {
struct program_impl
{
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
};
program::program()
: impl(std::make_unique<program_impl>())
{}
program::~program()
{}
instruction* program::add_instruction(operation op, std::vector<instruction*> args)
{
assert(std::all_of(
args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
impl->instructions.push_back({op, r, args});
assert(impl->instructions.back().arguments == args);
return std::addressof(impl->instructions.back());
}
instruction* program::add_literal(literal l)
{
impl->instructions.emplace_back(std::move(l));
return std::addressof(impl->instructions.back());
}
instruction* program::add_parameter(std::string name, shape s)
{
impl->instructions.push_back({builtin::param{std::move(name)}, s, {}});
return std::addressof(impl->instructions.back());
}
bool program::has_instruction(const instruction* ins) const
{
return std::find_if(impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
return ins == std::addressof(x);
}) != impl->instructions.end();
}
literal program::eval(std::unordered_map<std::string, argument> params) const
{
std::unordered_map<const instruction*, argument> results;
argument result;
for(auto& ins : instructions)
for(auto& ins : impl->instructions)
{
if(ins.op.name() == "@literal")
{
......@@ -38,7 +85,7 @@ void program::print() const
std::unordered_map<const instruction*, std::string> names;
int count = 0;
for(auto& ins : instructions)
for(auto& ins : impl->instructions)
{
std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment