"vscode:/vscode.git/clone" did not exist on "9661bd57466c445545a4f432133d1581330fd8a1"
Commit 9f046d67 authored by Paul

Parse ONNX and convert to internal IR

parent 2f8e4e83
@@ -28,7 +28,9 @@ struct literal : raw_data<literal>
     {
         assert(s.packed());
         static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
-        std::copy(x.begin(), x.end(), reinterpret_cast<T*>(buffer.data()));
+        s.visit_type([&](auto as) {
+            std::copy(x.begin(), x.end(), as.from(buffer.data()));
+        });
     }

     template<class T>
@@ -37,7 +39,19 @@ struct literal : raw_data<literal>
     {
         assert(s.packed());
         static_assert(std::is_trivial<T>{}, "Literals can only be trivial types");
-        std::copy(x.begin(), x.end(), reinterpret_cast<T*>(buffer.data()));
+        s.visit_type([&](auto as) {
+            std::copy(x.begin(), x.end(), as.from(buffer.data()));
+        });
+    }
+
+    template<class Iterator>
+    literal(shape s, Iterator start, Iterator end)
+        : buffer(s.bytes(), 0), shape_(s)
+    {
+        assert(s.packed());
+        s.visit_type([&](auto as) {
+            std::copy(start, end, as.from(buffer.data()));
+        });
     }

     literal(shape s, const char* x)
...
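The new literal constructors no longer cast the byte buffer straight to T; they ask the shape to dispatch on its runtime element type and copy through a typed view of the buffer. A minimal standalone sketch of that dispatch pattern follows; the elem_type, as_type, and visit_type names are invented for illustration and are not the actual rtg::shape interface.

// Illustrative only: a stripped-down "visit the element type, then copy
// through a typed pointer" pattern, not the real rtg::shape API.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class elem_type { float_type, int32_type };

template <class T>
struct as_type
{
    T* from(char* p) const { return reinterpret_cast<T*>(p); }
};

template <class F>
void visit_type(elem_type t, F f)
{
    switch (t)
    {
        case elem_type::float_type: f(as_type<float>{}); break;
        case elem_type::int32_type: f(as_type<std::int32_t>{}); break;
    }
}

int main()
{
    std::vector<std::int32_t> x = {1, 2, 3};
    std::vector<char> buffer(x.size() * sizeof(std::int32_t), 0);
    // The element type is only known at runtime, so dispatch first,
    // then copy through the correctly typed pointer.
    visit_type(elem_type::int32_type, [&](auto as) {
        std::copy(x.begin(), x.end(), as.from(buffer.data()));
    });
    std::cout << reinterpret_cast<std::int32_t*>(buffer.data())[2] << "\n"; // prints 3
}

The same lambda shape also backs the new iterator-range constructor, so every element type in the shape's type list gets a correctly typed copy without a separate overload.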
@@ -3,6 +3,7 @@

 #include <rtg/operand.hpp>
 #include <rtg/stringutils.hpp>
+#include <cmath>

 namespace rtg {

@@ -10,11 +11,11 @@ struct not_computable
 {
     argument compute(std::vector<argument>) const
     {
-        throw "not computable";
+        throw std::runtime_error("not computable");
     }
 };

-struct convolution : not_computable
+struct convolution
 {
     std::array<std::size_t, 2> padding = {0, 0};
     std::array<std::size_t, 2> stride = {1, 1};

@@ -28,26 +29,31 @@ struct convolution : not_computable
     }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        if(inputs.size() != 2) throw "Wrong number of arguments";
+        if(inputs.size() != 2) throw std::runtime_error("Wrong number of arguments");
         const shape& input = inputs.at(0);
         const shape& weights = inputs.at(1);
-        if(input.type() != weights.type()) throw "Type doesn't match";
-        if(input.size() != weights.size()) throw "Dimensions don't match";
-        if(input.size() != 4) throw "Only 4d convolution supported";
+        if(input.type() != weights.type()) throw std::runtime_error("Type doesn't match");
+        if(input.lens().size() != weights.lens().size()) throw std::runtime_error("Dimensions don't match");
+        if(input.lens().size() != 4) throw std::runtime_error("Only 4d convolution supported");
         auto t = input.type();
         return {t, {
-            input[0],
-            weights[0],
-            std::max<std::ptrdiff_t>(
-                1, (input[2] - (1 + dilation[0] * (weights[2] - 1)) + 2 * padding[0]) / stride[0] + 1),
-            std::max<std::ptrdiff_t>(
-                1, (input[3] - (1 + dilation[1] * (weights[3] - 1)) + 2 * padding[1]) / stride[1] + 1),
+            input.lens()[0],
+            weights.lens()[0],
+            std::size_t(std::max<std::ptrdiff_t>(
+                1, (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) + 2 * padding[0]) / stride[0] + 1)),
+            std::size_t(std::max<std::ptrdiff_t>(
+                1, (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) + 2 * padding[1]) / stride[1] + 1)),
         }};
     }
+
+    argument compute(std::vector<argument>) const
+    {
+        throw std::runtime_error("not computable");
+    }
 };

-struct pooling : not_computable
+struct pooling
 {
     std::string mode;
     std::array<std::size_t, 2> padding = {0, 0};

@@ -62,24 +68,29 @@ struct pooling : not_computable
     }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        if(!inputs.empty()) throw "Wrong number of arguments";
+        if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
         const shape& input = inputs.at(0);
-        if(input.size() != 4) throw "Only 4d pooling supported";
+        if(input.lens().size() != 4) throw std::runtime_error("Only 4d pooling supported");
         auto t = input.type();
         return {t, {
-            input[0],
-            input[1],
-            std::max<std::ptrdiff_t>(
-                1, std::ceil((input[3] + 2 * padding[0] - lengths[0]) / static_cast<float>(stride[0])) + 1),
-            std::max<std::ptrdiff_t>(
-                1, std::ceil((input[4] + 2 * padding[1] - lengths[1]) / static_cast<float>(stride[1])) + 1),
+            input.lens()[0],
+            input.lens()[1],
+            std::size_t(std::max<std::ptrdiff_t>(
+                1, std::ceil((input.lens()[3] + 2 * padding[0] - lengths[0]) / static_cast<float>(stride[0])) + 1)),
+            std::size_t(std::max<std::ptrdiff_t>(
+                1, std::ceil((input.lens()[4] + 2 * padding[1] - lengths[1]) / static_cast<float>(stride[1])) + 1)),
         }};
     }
+
+    argument compute(std::vector<argument>) const
+    {
+        throw std::runtime_error("not computable");
+    }
 };

-struct activation : not_computable
+struct activation
 {
     std::string mode;
     std::string name() const

@@ -88,9 +99,14 @@ struct activation : not_computable
     }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        if(!inputs.empty()) throw "Wrong number of arguments";
+        if(inputs.empty()) throw std::runtime_error("Wrong number of arguments");
         return inputs.front();
     }
+
+    argument compute(std::vector<argument>) const
+    {
+        throw std::runtime_error("not computable");
+    }
 };
...
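The convolution shape computation above uses the usual output-size arithmetic: out = (in - (1 + dilation * (kernel - 1)) + 2 * padding) / stride + 1, clamped to at least 1. A quick standalone check of that formula with made-up sizes (the numbers below are only an example, not taken from the commit):

#include <algorithm>
#include <cstddef>
#include <iostream>

// Same arithmetic as convolution::compute_shape, pulled out for a worked example.
std::size_t conv_out_dim(std::ptrdiff_t in, std::ptrdiff_t kernel, std::ptrdiff_t padding,
                         std::ptrdiff_t stride, std::ptrdiff_t dilation)
{
    return std::size_t(std::max<std::ptrdiff_t>(
        1, (in - (1 + dilation * (kernel - 1)) + 2 * padding) / stride + 1));
}

int main()
{
    // 224 input, 7x7 kernel, padding 3, stride 2, dilation 1 -> 112
    std::cout << conv_out_dim(224, 7, 3, 2, 1) << "\n";
    // 224 input, 3x3 kernel, padding 1, stride 1, dilation 1 -> 224
    std::cout << conv_out_dim(224, 3, 1, 1, 1) << "\n";
}

The pooling shape follows the same pattern but rounds up with std::ceil before adding 1, which is why the cmath include was added.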
@@ -13,7 +13,15 @@ struct shape
     // Add new types here
 #define RTG_SHAPE_VISIT_TYPES(m) \
     m(float_type, float) \
-    m(int_type, int) \
+    m(double_type, double) \
+    m(uint8_type, uint8_t) \
+    m(int8_type, int8_t) \
+    m(uint16_type, uint16_t) \
+    m(int16_type, int16_t) \
+    m(int32_type, int32_t) \
+    m(int64_type, int64_t) \
+    m(uint32_type, uint32_t) \
+    m(uint64_type, uint64_t) \

 #define RTG_SHAPE_ENUM_TYPES(x, t) x,
     enum type_t
...
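RTG_SHAPE_VISIT_TYPES is an X-macro: the single type list is expanded at every use site, so adding an entry extends the enum and any per-type visitor code at once. A toy version of the same pattern, with macro and type names invented purely for illustration:

#include <cstddef>
#include <iostream>

// Toy X-macro: one list, expanded twice with different definitions of "m".
#define MY_VISIT_TYPES(m) \
    m(float_type, float)  \
    m(int32_type, int)

// Expansion 1: generate the enumerators.
#define MY_ENUM_ENTRY(x, t) x,
enum type_t { MY_VISIT_TYPES(MY_ENUM_ENTRY) };

// Expansion 2: generate a switch over the same list.
#define MY_SIZE_CASE(x, t) case x: return sizeof(t);
std::size_t type_size(type_t t)
{
    switch (t)
    {
        MY_VISIT_TYPES(MY_SIZE_CASE)
    }
    return 0;
}

int main()
{
    std::cout << type_size(float_type) << " " << type_size(int32_type) << "\n"; // 4 4
}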
@@ -72,7 +72,7 @@ inline std::string to_string(const Range& r)
     if(!r.empty())
     {
         ss << r.front();
-        std::for_each(++r.begin(), r.end(), [&](auto&& x)
+        std::for_each(std::next(r.begin()), r.end(), [&](auto&& x)
         {
             ss << ", " << x;
         });
...
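Replacing ++r.begin() with std::next(r.begin()) matters because begin() returns a temporary: pre-incrementing that temporary only compiles when the iterator is a class type, while std::next works for any iterator, including the raw pointers you get from arrays. A small sketch of the same join logic over a plain array, where the iterator is a raw pointer:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <sstream>

int main()
{
    int r[] = {1, 2, 3};
    std::stringstream ss;
    ss << *std::begin(r);
    // ++std::begin(r) would not compile here: std::begin(r) is a prvalue int*.
    std::for_each(std::next(std::begin(r)), std::end(r), [&](auto&& x) {
        ss << ", " << x;
    });
    std::cout << ss.str() << "\n"; // 1, 2, 3
}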
@@ -51,7 +51,10 @@ void program::print() const
         if(ins.op.name() == "@literal")
         {
-            std::cout << "{" << ins.lit << "}";
+            if (ins.lit.get_shape().elements() > 10)
+                std::cout << "{ ... }";
+            else
+                std::cout << "{" << ins.lit << "}";
         }
         if(!ins.arguments.empty())
...
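This print change just elides large literals so dumping a program that now carries real weight tensors stays readable; anything over ten elements is shown as { ... }. A standalone sketch of the same idea for printing any container (the threshold and function name here are arbitrary):

#include <cstddef>
#include <iostream>
#include <vector>

// Print a container, eliding the contents once it gets large.
template <class Range>
void print_elided(const Range& r, std::size_t limit = 10)
{
    if (r.size() > limit)
    {
        std::cout << "{ ... }\n";
        return;
    }
    std::cout << "{";
    for (auto&& x : r) std::cout << " " << x;
    std::cout << " }\n";
}

int main()
{
    print_elided(std::vector<int>{1, 2, 3}); // { 1 2 3 }
    print_elided(std::vector<int>(100, 0));  // { ... }
}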
@@ -5,8 +5,10 @@
 #include <iostream>
 #include <fstream>
 #include <unordered_map>
+#include <functional>

 #include <rtg/program.hpp>
+#include <rtg/operators.hpp>

 struct unknown
 {
@@ -26,12 +28,95 @@ struct unknown
     }
 };

+template<class C, class T>
+bool contains(C&& c, T&& x)
+{
+    return c.find(x) != c.end();
+}
+
+template<class Range, class Iterator>
+void copy(Range&& r, Iterator it)
+{
+    std::copy(r.begin(), r.end(), it);
+}
+
 struct onnx_parser
 {
-    std::unordered_map<std::string, onnx::NodeProto> nodes;
+    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
+    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
+
+    node_map nodes;
     std::unordered_map<std::string, rtg::instruction*> instructions;
     std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();
+    std::unordered_map<std::string, std::function<rtg::instruction*(attribute_map, std::vector<rtg::instruction*>)>> ops;
+
+    onnx_parser()
+    {
+        add_op("Conv", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
+            rtg::convolution op;
+            if(contains(attributes, "pads"))
+            {
+                copy(attributes["pads"].ints(), op.padding.begin());
+            }
+            if(contains(attributes, "strides"))
+            {
+                copy(attributes["strides"].ints(), op.stride.begin());
+            }
+            if(contains(attributes, "dilations"))
+            {
+                copy(attributes["dilations"].ints(), op.dilation.begin());
+            }
+            return prog->add_instruction(op, args);
+        });
+        add_op("MaxPool", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
+            rtg::pooling op{"max"};
+            // for(auto&& p:attributes) std::cout << p.first << std::endl;
+            if(contains(attributes, "pads"))
+            {
+                copy(attributes["pads"].ints(), op.padding.begin());
+            }
+            if(contains(attributes, "strides"))
+            {
+                copy(attributes["strides"].ints(), op.stride.begin());
+            }
+            if(contains(attributes, "kernel_shape"))
+            {
+                copy(attributes["kernel_shape"].ints(), op.lengths.begin());
+            }
+            return prog->add_instruction(op, args);
+        });
+        add_op("Relu", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
+            return prog->add_instruction(rtg::activation{"relu"}, args);
+        });
+        add_op("Constant", [this](attribute_map attributes, std::vector<rtg::instruction*>) {
+            rtg::literal v = parse_value(attributes.at("value"));
+            return prog->add_literal(v);
+        });
+    }
+
+    template<class F>
+    void add_op(std::string name, F f)
+    {
+        ops.emplace(name, f);
+    }
+
+    void parse_from(std::istream& is)
+    {
+        onnx::ModelProto model;
+        if(model.ParseFromIstream(&is))
+        {
+            if(model.has_graph())
+            {
+                this->parse_graph(model.graph());
+            }
+        }
+        else
+        {
+            throw std::runtime_error("Failed reading");
+        }
+    }
+
     void parse_graph(const onnx::GraphProto& graph)
     {
         nodes = get_nodes(graph);
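The parser now keeps a table of op handlers: each ONNX op_type maps to a std::function that takes the node's attributes plus the already-parsed argument instructions and appends the matching instruction to the program, falling back to the unknown{} placeholder otherwise. A stripped-down sketch of that dispatch-table pattern, independent of the rtg and onnx types (all names below are stand-ins):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-ins for attributes and instructions, just to show the registry shape.
using attribute_map = std::unordered_map<std::string, int>;

struct registry
{
    std::unordered_map<std::string,
        std::function<std::string(attribute_map, std::vector<std::string>)>> ops;

    template <class F>
    void add_op(std::string name, F f) { ops.emplace(std::move(name), f); }

    std::string handle(const std::string& op_type, attribute_map attrs,
                       std::vector<std::string> args)
    {
        // Unregistered ops become an "unknown" placeholder, mirroring unknown{} above.
        if (ops.count(op_type) == 0) return "unknown:" + op_type;
        return ops[op_type](std::move(attrs), std::move(args));
    }
};

int main()
{
    registry r;
    r.add_op("Relu", [](attribute_map, std::vector<std::string> args) {
        return "relu(" + args.at(0) + ")";
    });
    std::cout << r.handle("Relu", {}, {"x"}) << "\n";    // relu(x)
    std::cout << r.handle("Softmax", {}, {"x"}) << "\n"; // unknown:Softmax
}

Registering handlers in the constructor keeps adding a new op to a single add_op call, which is why Conv, MaxPool, Relu, and Constant are all wired up the same way.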
@@ -39,7 +124,8 @@ struct onnx_parser
         {
             std::string name = input.name();
             // TODO: Get shape of input parameter
-            instructions[name] = prog->add_parameter(name, rtg::shape{});
+            rtg::shape s = parse_type(input.type());
+            instructions[name] = prog->add_parameter(name, s);
         }
         for(auto&& p:nodes)
         {
@@ -66,11 +152,18 @@ struct onnx_parser
                     args.push_back(instructions.at(input));
                 }
            }
-            instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
+            if (ops.count(node.op_type()) == 0)
+            {
+                instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
+            }
+            else
+            {
+                instructions[name] = ops[node.op_type()](get_attributes(node), args);
+            }
         }
     }

-    static std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node)
+    static attribute_map get_attributes(const onnx::NodeProto& node)
     {
         std::unordered_map<std::string, onnx::AttributeProto> result;
         for(auto&& attr:node.attribute())
@@ -80,7 +173,7 @@ struct onnx_parser
         return result;
     }

-    static std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph)
+    static node_map get_nodes(const onnx::GraphProto& graph)
     {
         std::unordered_map<std::string, onnx::NodeProto> result;
         for(auto&& node:graph.node())
@@ -94,21 +187,80 @@ struct onnx_parser
         }
         return result;
     }
-};
-
-std::shared_ptr<rtg::program> parse_onnx(std::istream& is)
-{
-    onnx_parser parser;
-    onnx::ModelProto model;
-    if(model.ParseFromIstream(&is)) {
-        if(model.has_graph()) {
-            parser.parse_graph(model.graph());
-        }
-    } else {
-        throw "Failed reading";
-    }
-    return parser.prog;
-}
+
+    static rtg::literal parse_value(const onnx::AttributeProto& attr)
+    {
+        switch(attr.type())
+        {
+            case onnx::AttributeProto::UNDEFINED: return {};
+            case onnx::AttributeProto::FLOAT: return rtg::literal{attr.f()};
+            case onnx::AttributeProto::INT: return rtg::literal{attr.i()};
+            case onnx::AttributeProto::STRING: return {};
+            case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
+            case onnx::AttributeProto::GRAPH: return {};
+            case onnx::AttributeProto::FLOATS: return rtg::literal{rtg::shape::float_type, attr.floats().begin(), attr.floats().end()};
+            case onnx::AttributeProto::INTS: return rtg::literal{rtg::shape::int32_type, attr.ints().begin(), attr.ints().end()};
+            case onnx::AttributeProto::STRINGS: return {};
+            case onnx::AttributeProto::TENSORS: return {};
+            case onnx::AttributeProto::GRAPHS: return {};
+        }
+    }
+
+    static rtg::literal parse_tensor(const onnx::TensorProto& t)
+    {
+        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
+        switch(t.data_type())
+        {
+            case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
+            case onnx::TensorProto::FLOAT: return rtg::literal{{rtg::shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
+            case onnx::TensorProto::UINT8: throw std::runtime_error("");
+            case onnx::TensorProto::INT8: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            case onnx::TensorProto::UINT16: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            case onnx::TensorProto::INT16: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            case onnx::TensorProto::INT32: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            case onnx::TensorProto::INT64: return rtg::literal{{rtg::shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
+            case onnx::TensorProto::STRING: throw std::runtime_error("");
+            case onnx::TensorProto::BOOL: return rtg::literal{{rtg::shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
+            case onnx::TensorProto::DOUBLE: return rtg::literal{{rtg::shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
+            case onnx::TensorProto::UINT32: throw std::runtime_error("");
+            case onnx::TensorProto::UINT64: throw std::runtime_error("");
+            case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
+            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
+        }
+    }
+
+    static rtg::shape parse_type(const onnx::TypeProto& t)
+    {
+        rtg::shape::type_t shape_type;
+        switch(t.tensor_type().elem_type())
+        {
+            case onnx::TensorProto::UNDEFINED: break; // throw std::runtime_error("Unsupported type UNDEFINED");
+            case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type;
+            case onnx::TensorProto::UINT8: break; // throw std::runtime_error("Unsupported type UINT8");
+            case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type;
+            case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type;
+            case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type;
+            case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type;
+            case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type;
+            case onnx::TensorProto::STRING: break; // throw std::runtime_error("Unsupported type STRING");
+            case onnx::TensorProto::BOOL: break; // throw std::runtime_error("Unsupported type BOOL");
+            case onnx::TensorProto::FLOAT16: break; // throw std::runtime_error("Unsupported type FLOAT16");
+            case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type;
+            case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type;
+            case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type;
+            case onnx::TensorProto::COMPLEX64: break; // throw std::runtime_error("Unsupported type COMPLEX64");
+            case onnx::TensorProto::COMPLEX128: break; // throw std::runtime_error("Unsupported type COMPLEX128");
+        }
+        std::vector<std::size_t> dims;
+        // TODO: Use std::transform
+        for(auto&& d:t.tensor_type().shape().dim())
+        {
+            dims.push_back(d.dim_value());
+        }
+        return {shape_type, dims};
+    }
+};

 int main(int argc, char const *argv[])
 {
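parse_value, parse_tensor, and parse_type all follow the same shape: switch over a protobuf enum and produce the matching internal type. A self-contained sketch of that mapping, written with a return per case so no case can fall through into the next; the enum values below are invented stand-ins, not the real onnx constants:

#include <iostream>
#include <stdexcept>

// Invented stand-ins for the protobuf element-type enum and the internal one.
enum class proto_type { FLOAT, INT32, INT64, DOUBLE, STRING };
enum class shape_type { float_type, int32_type, int64_type, double_type };

// Returning from every case keeps the mapping total and fallthrough-free.
shape_type map_elem_type(proto_type t)
{
    switch (t)
    {
        case proto_type::FLOAT:  return shape_type::float_type;
        case proto_type::INT32:  return shape_type::int32_type;
        case proto_type::INT64:  return shape_type::int64_type;
        case proto_type::DOUBLE: return shape_type::double_type;
        case proto_type::STRING: throw std::runtime_error("Unsupported type STRING");
    }
    throw std::runtime_error("Unknown type");
}

int main()
{
    std::cout << (map_elem_type(proto_type::INT64) == shape_type::int64_type) << "\n"; // 1
}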
@@ -116,7 +268,16 @@ int main(int argc, char const *argv[])
     {
         std::string file = argv[1];
         std::fstream input(file.c_str(), std::ios::in | std::ios::binary);
-        auto prog = parse_onnx(input);
-        prog->print();
+        onnx_parser parser;
+        try
+        {
+            parser.parse_from(input);
+        }
+        catch(...)
+        {
+            if(parser.prog) parser.prog->print();
+            throw;
+        }
+        parser.prog->print();
     }
 }
@@ -48,8 +48,8 @@ void literal_test() {
 void param_test() {
     rtg::program p;
-    auto x = p.add_parameter("x", {rtg::shape::int_type});
-    auto y = p.add_parameter("y", {rtg::shape::int_type});
+    auto x = p.add_parameter("x", {rtg::shape::int64_type});
+    auto y = p.add_parameter("y", {rtg::shape::int64_type});
     p.add_instruction(sum_op{}, x, y);
     auto result = p.eval({
...
@@ -44,7 +44,7 @@ void literal_os2()
 void literal_os3()
 {
-    rtg::shape s{rtg::shape::int_type, {3}};
+    rtg::shape s{rtg::shape::int64_type, {3}};
     rtg::literal l{s, {1, 2, 3}};
     std::stringstream ss;
     ss << l;
...