Commit 0fc52912 authored by Paul's avatar Paul
Browse files

Refactor operators in onnx parser

parent 02e0dd2a
......@@ -51,7 +51,55 @@ struct onnx_parser
onnx_parser()
{
add_op("Conv", [this](attribute_map attributes, std::vector<instruction_ref> args) {
add_generic_op("Add", add{});
add_generic_op("Div", div{});
add_generic_op("MatMul", gemm{});
add_generic_op("Mul", mul{});
add_generic_op("Relu", activation{"relu"});
add_generic_op("Sub", sub{});
add_mem_op("Constant",&onnx_parser::parse_constant);
add_mem_op("Conv",&onnx_parser::parse_conv);
add_mem_op("MaxPool",&onnx_parser::parse_pooling);
add_mem_op("Reshape",&onnx_parser::parse_reshape);
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, f);
}
template <class F>
void add_mem_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
});
}
template<class T>
void add_generic_op(std::string name, T x)
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() == 2 and contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(broadcast{axis}, args);
return prog.add_instruction(x, args[0], l);
}
}
return prog.add_instruction(x, args);
});
}
instruction_ref parse_conv(std::string, attribute_map attributes, std::vector<instruction_ref> args) {
convolution op;
if(contains(attributes, "pads"))
{
......@@ -73,11 +121,9 @@ struct onnx_parser
return prog.add_instruction(add{}, l1, l2);
}
return prog.add_instruction(op, args);
});
add_op("MatMul", [this](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(gemm{}, args);
});
add_op("MaxPool", [this](attribute_map attributes, std::vector<instruction_ref> args) {
}
instruction_ref parse_pooling(std::string, attribute_map attributes, std::vector<instruction_ref> args) {
pooling op{"max"};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if(contains(attributes, "pads"))
......@@ -93,11 +139,9 @@ struct onnx_parser
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
return prog.add_instruction(op, args);
});
add_op("Relu", [this](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(activation{"relu"}, args);
});
add_op("Reshape", [this](attribute_map attributes, std::vector<instruction_ref> args) {
}
instruction_ref parse_reshape(std::string, attribute_map attributes, std::vector<instruction_ref> args) {
reshape op;
if(args.size() == 1)
{
......@@ -110,78 +154,12 @@ struct onnx_parser
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
});
add_op("Constant", [this](attribute_map attributes, std::vector<instruction_ref>) {
}
instruction_ref parse_constant(std::string, attribute_map attributes, std::vector<instruction_ref>) {
literal v = parse_value(attributes.at("value"));
return prog.add_literal(v);
});
add_op("Add", [this](attribute_map attributes, std::vector<instruction_ref> args) {
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(broadcast{axis}, args);
return prog.add_instruction(add{}, args[0], l);
}
}
return prog.add_instruction(add{}, args);
});
add_op("Sub", [this](attribute_map attributes, std::vector<instruction_ref> args) {
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(broadcast{axis}, args);
return prog.add_instruction(sub{}, args[0], l);
}
}
return prog.add_instruction(sub{}, args);
});
add_op("Mul", [this](attribute_map attributes, std::vector<instruction_ref> args) {
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(broadcast{axis}, args);
return prog.add_instruction(mul{}, args[0], l);
}
}
return prog.add_instruction(mul{}, args);
});
add_op("Div", [this](attribute_map attributes, std::vector<instruction_ref> args) {
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = (contains(attributes, "axis"))
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = prog.add_instruction(broadcast{axis}, args);
return prog.add_instruction(div{}, args[0], l);
}
}
return prog.add_instruction(div{}, args);
});
}
template <class F>
void add_op(std::string name, F f)
{
ops.emplace(name, f);
}
}
void parse_from(std::istream& is)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment