Commit 0fc287d3 authored by Paul's avatar Paul
Browse files

Print program

parent 1aad1ec3
......@@ -18,6 +18,6 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_executable(read_onnx src/read_onnx.cpp ${PROTO_SRCS})
target_include_directories(read_onnx PUBLIC ${PROTOBUF_INLCUDE_DIR})
target_link_libraries(read_onnx ${PROTOBUF_LIBRARY})
target_link_libraries(read_onnx ${PROTOBUF_LIBRARY} rtg)
add_subdirectory(test)
#ifndef RTG_GUARD_OPERATORS_HPP
#define RTG_GUARD_OPERATORS_HPP
#include <rtg/operand.hpp>
#include <rtg/stringutils.hpp>
namespace rtg {
struct not_computable
{
argument compute(std::vector<argument>) const
{
throw "not computable";
}
};
struct convolution : not_computable
{
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> dilation = {1, 1};
std::string name() const
{
return "convolution[padding={" + to_string(padding) +
"}, stride={" + to_string(stride) +
"}, dilation={" + to_string(dilation) +
"}]";
}
shape compute_shape(std::vector<shape> inputs) const
{
if(inputs.size() != 2) throw "Wrong number of arguments";
const shape& input = inputs.at(0);
const shape& weights = inputs.at(1);
if(input.type() != weights.type()) throw "Type doesn't match";
if(input.size() != weights.size()) throw "Dimensions don't match";
if(input.size() != 4) throw "Only 4d convolution supported";
auto t = input.type();
return {t, {
input[0],
weights[0],
std::max<std::ptrdiff_t>(
1, (input[2] - (1 + dilation[0] * (weights[2] - 1)) + 2 * padding[0]) / stride[0] + 1),
std::max<std::ptrdiff_t>(
1, (input[3] - (1 + dilation[1] * (weights[3] - 1)) + 2 * padding[1]) / stride[1] + 1),
}};
}
};
struct pooling : not_computable
{
std::string mode;
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> lengths = {1, 1};
std::string name() const
{
return "pooling:" + mode + "[padding={" + to_string(padding) +
"}, stride={" + to_string(stride) +
"}, lengths={" + to_string(lengths) +
"}]";
}
shape compute_shape(std::vector<shape> inputs) const
{
if(!inputs.empty()) throw "Wrong number of arguments";
const shape& input = inputs.at(0);
if(input.size() != 4) throw "Only 4d pooling supported";
auto t = input.type();
return {t, {
input[0],
input[1],
std::max<std::ptrdiff_t>(
1, std::ceil((input[3] + 2 * padding[0] - lengths[0]) / static_cast<float>(stride[0])) + 1),
std::max<std::ptrdiff_t>(
1, std::ceil((input[4] + 2 * padding[1] - lengths[1]) / static_cast<float>(stride[1])) + 1),
}};
}
};
struct activation : not_computable
{
std::string mode;
std::string name() const
{
return "activation:" + mode;
}
shape compute_shape(std::vector<shape> inputs) const
{
if(!inputs.empty()) throw "Wrong number of arguments";
return inputs.front();
}
};
} // namespace rtg
#endif
......@@ -33,6 +33,9 @@ struct program
literal eval(std::unordered_map<std::string, argument> params) const;
// TODO: Change to stream operator
void print() const;
private:
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
......
......@@ -3,7 +3,7 @@
#include <vector>
#include <cassert>
#include <ostream>
namespace rtg {
......@@ -53,6 +53,7 @@ struct shape
friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
template<class T>
struct as
......@@ -122,6 +123,7 @@ private:
void calculate_strides();
std::size_t element_space() const;
std::string type_string() const;
};
}
......
......@@ -4,6 +4,7 @@
#include <algorithm>
#include <numeric>
#include <string>
#include <sstream>
namespace rtg {
......@@ -64,6 +65,21 @@ inline std::string remove_prefix(std::string s, std::string prefix)
return s;
}
template<class Range>
inline std::string to_string(const Range& r)
{
std::stringstream ss;
if(!r.empty())
{
ss << r.front();
std::for_each(++r.begin(), r.end(), [&](auto&& x)
{
ss << ", " << x;
});
}
return ss.str();
}
} // namespace rtg
#endif
\ No newline at end of file
#include <rtg/program.hpp>
#include <rtg/stringutils.hpp>
#include <iostream>
#include <algorithm>
namespace rtg {
......@@ -31,5 +32,47 @@ literal program::eval(std::unordered_map<std::string, argument> params) const
return literal{result.get_shape(), result.data()};
}
void program::print() const
{
std::unordered_map<const instruction*, std::string> names;
int count = 0;
for(auto& ins:instructions)
{
std::string var_name = "@" + std::to_string(count);
if(starts_with(ins.op.name(), "@param"))
{
var_name = ins.op.name().substr(7);
}
std::cout << var_name << " = ";
std::cout << ins.op.name();
if(ins.op.name() == "@literal")
{
std::cout << "{" << ins.lit << "}";
}
if(!ins.arguments.empty())
{
char delim = '(';
for(auto&& arg:ins.arguments)
{
std::cout << delim << names.at(arg);
delim = ',';
}
std::cout << ")";
}
std::cout << " -> " << ins.result;
std::cout << std::endl;
names.emplace(std::addressof(ins), var_name);
}
}
}
......@@ -6,6 +6,25 @@
#include <fstream>
#include <unordered_map>
#include <rtg/program.hpp>
struct unknown
{
rtg::shape s;
std::string op;
std::string name() const
{
return "unknown:"+op;
}
rtg::shape compute_shape(std::vector<rtg::shape> input) const
{
return s;
}
rtg::argument compute(std::vector<rtg::argument> input) const
{
throw "not computable";
}
};
std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node)
{
......@@ -17,6 +36,16 @@ std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx:
return result;
}
std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node())
{
result[node.name()] = node;
}
return result;
}
void parse_graph(onnx::GraphProto graph)
{
std::cout << "Graph name: " << graph.name() << std::endl;
......@@ -27,6 +56,12 @@ void parse_graph(onnx::GraphProto graph)
std::cout << " Input: " << node.input(0) << std::endl;
if(node.output_size() > 0)
std::cout << " Output: " << node.output(0) << std::endl;
std::cout << " Attributes: " << std::endl;
for(auto&& attr:node.attribute())
{
std::cout << " " << attr.name() << std::endl;
}
}
}
......
#include <rtg/shape.hpp>
#include <rtg/stringutils.hpp>
#include <numeric>
#include <algorithm>
#include <functional>
......@@ -98,6 +99,19 @@ std::size_t shape::element_space() const
1;
}
std::string shape::type_string() const
{
switch(this->type_)
{
#define RTG_SHAPE_TYPE_STRING_CASE(x, t) \
case x: \
return #x;
RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_TYPE_STRING_CASE)
#undef RTG_SHAPE_TYPE_STRING_CASE
}
throw "Invalid type";
}
bool operator==(const shape& x, const shape& y)
{
return x.type() == y.type() && x.lens() == y.lens() && x.strides() == y.strides();
......@@ -107,4 +121,12 @@ bool operator!=(const shape& x, const shape& y)
return !(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape& x)
{
os << x.type_string() << ", ";
os << "{" << to_string(x.lens()) << "}, ";
os << "{" << to_string(x.strides()) << "}";
return os;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment