Commit 43ee9854 authored by Paul's avatar Paul
Browse files

Formatting

parent 0fc52912
...@@ -57,12 +57,11 @@ struct onnx_parser ...@@ -57,12 +57,11 @@ struct onnx_parser
add_generic_op("Mul", mul{}); add_generic_op("Mul", mul{});
add_generic_op("Relu", activation{"relu"}); add_generic_op("Relu", activation{"relu"});
add_generic_op("Sub", sub{}); add_generic_op("Sub", sub{});
add_mem_op("Constant",&onnx_parser::parse_constant);
add_mem_op("Conv",&onnx_parser::parse_conv);
add_mem_op("MaxPool",&onnx_parser::parse_pooling);
add_mem_op("Reshape",&onnx_parser::parse_reshape);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
} }
template <class F> template <class F>
...@@ -79,7 +78,7 @@ struct onnx_parser ...@@ -79,7 +78,7 @@ struct onnx_parser
}); });
} }
template<class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
...@@ -99,67 +98,75 @@ struct onnx_parser ...@@ -99,67 +98,75 @@ struct onnx_parser
}); });
} }
instruction_ref parse_conv(std::string, attribute_map attributes, std::vector<instruction_ref> args) { instruction_ref
convolution op; parse_conv(std::string, attribute_map attributes, std::vector<instruction_ref> args)
if(contains(attributes, "pads")) {
{ convolution op;
copy(attributes["pads"].ints(), op.padding.begin()); if(contains(attributes, "pads"))
} {
if(contains(attributes, "strides")) copy(attributes["pads"].ints(), op.padding.begin());
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "dilations"))
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(args.size() == 3)
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(broadcast{axis}, l1, args[2]);
return prog.add_instruction(add{}, l1, l2);
}
return prog.add_instruction(op, args);
} }
if(contains(attributes, "strides"))
instruction_ref parse_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args) { {
pooling op{"max"}; copy(attributes["strides"].ints(), op.stride.begin());
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
}
if(contains(attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "kernel_shape"))
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog.add_instruction(op, args);
} }
if(contains(attributes, "dilations"))
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(args.size() == 3)
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(broadcast{axis}, l1, args[2]);
return prog.add_instruction(add{}, l1, l2);
}
return prog.add_instruction(op, args);
}
instruction_ref parse_reshape(std::string, attribute_map attributes, std::vector<instruction_ref> args) { instruction_ref
reshape op; parse_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args)
if(args.size() == 1) {
{ pooling op{"max"};
literal s = parse_value(attributes.at("shape")); // for(auto&& p:attributes) std::cout << p.first << std::endl;
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); if(contains(attributes, "pads"))
} {
if(args.size() == 2) copy(attributes["pads"].ints(), op.padding.begin());
{
literal s = args[1]->lit;
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
} }
if(contains(attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "kernel_shape"))
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog.add_instruction(op, args);
}
instruction_ref parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>) { instruction_ref
literal v = parse_value(attributes.at("value")); parse_reshape(std::string, attribute_map attributes, std::vector<instruction_ref> args)
return prog.add_literal(v); {
reshape op;
if(args.size() == 1)
{
literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
if(args.size() == 2)
{
literal s = args[1]->lit;
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>)
{
literal v = parse_value(attributes.at("value"));
return prog.add_literal(v);
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment