Commit 21ec113b authored by Khalique's avatar Khalique
Browse files

initial testing

parent 2ed46170
...@@ -875,6 +875,16 @@ struct div : binary ...@@ -875,6 +875,16 @@ struct div : binary
std::string name() const { return "div"; } std::string name() const { return "div"; }
}; };
// Element-wise maximum operator tag; inherits the shared binary-op
// interface (shape/compute come from `binary`, declared elsewhere).
struct max : binary
{
    std::string name() const { return "max"; }
};
// Element-wise minimum operator tag; inherits the shared binary-op
// interface (shape/compute come from `binary`, declared elsewhere).
struct min : binary
{
    std::string name() const { return "min"; }
};
struct load struct load
{ {
shape s; shape s;
......
...@@ -58,11 +58,14 @@ struct onnx_parser ...@@ -58,11 +58,14 @@ struct onnx_parser
add_generic_op("Dropout", op::identity{}); add_generic_op("Dropout", op::identity{});
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", op::identity{});
add_broadcastable_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{}); add_binary_op("Div", op::div{});
add_broadcastable_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_broadcastable_binary_op("Sub", op::sub{}); add_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Sum", op::add{});
add_mem_op("Sum", &onnx_parser::parse_sum);
add_mem_op("Max", &onnx_parser::parse_max);
add_mem_op("Min", &onnx_parser::parse_min);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
...@@ -98,8 +101,9 @@ struct onnx_parser ...@@ -98,8 +101,9 @@ struct onnx_parser
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
template <class T> template <class T>
void add_broadcastable_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
...@@ -118,7 +122,17 @@ struct onnx_parser ...@@ -118,7 +122,17 @@ struct onnx_parser
} }
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
} }
else if(args[0]->get_shape() != args[1]->get_shape()) else
{
return add_broadcastable_binary_op(args[0], args[1], x);
}
});
}
template <class T>
instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
{
if(arg0->get_shape() != arg1->get_shape())
{ {
// Example: // Example:
// s0 = (3,2,4,5) and s1 = (2,1,1) // s0 = (3,2,4,5) and s1 = (2,1,1)
...@@ -134,8 +148,8 @@ struct onnx_parser ...@@ -134,8 +148,8 @@ struct onnx_parser
// output_lens = (3,2,7,5) // output_lens = (3,2,7,5)
// //
// Get lengths for both arguments // Get lengths for both arguments
const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens(); const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens(); const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
// Make sure s0 is the smaller size // Make sure s0 is the smaller size
if(s0->size() > s1->size()) if(s0->size() > s1->size())
...@@ -149,15 +163,14 @@ struct onnx_parser ...@@ -149,15 +163,14 @@ struct onnx_parser
output_lens.begin() + offset, output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [](auto a, auto b) { return std::max(a, b); });
auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]); auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]); auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
return prog.add_instruction(x, l0, l1); return prog.add_instruction(x, l0, l1);
} }
else else
{ {
return prog.add_instruction(x, args); return prog.add_instruction(x, {arg0, arg1});
} }
});
} }
template <class T> template <class T>
...@@ -168,6 +181,48 @@ struct onnx_parser ...@@ -168,6 +181,48 @@ struct onnx_parser
}); });
} }
// Parse ONNX "Sum": variadic element-wise addition with broadcasting,
// lowered as a left fold of broadcastable binary adds.
//
// @param args  one or more input instructions
// @return      instruction computing the element-wise sum of all inputs
instruction_ref
parse_sum(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    // NOTE(review): empty args would make front() undefined behavior;
    // presumably the ONNX importer guarantees >= 1 input — confirm upstream.
    auto curr_sum = args.front();
    // The loop body never runs for a single input, so no explicit
    // size() > 1 guard is needed.
    for(auto it = std::next(args.begin()); it != args.end(); ++it)
    {
        curr_sum = add_broadcastable_binary_op(curr_sum, *it, op::add{});
    }
    return curr_sum;
}
// Parse ONNX "Max": variadic element-wise maximum with broadcasting,
// lowered as a left fold of broadcastable binary max ops.
//
// @param args  one or more input instructions
// @return      instruction computing the element-wise max of all inputs
instruction_ref
parse_max(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    // NOTE(review): empty args would make front() undefined behavior;
    // presumably the ONNX importer guarantees >= 1 input — confirm upstream.
    auto curr_max = args.front();
    // The loop body never runs for a single input, so no explicit
    // size() > 1 guard is needed.
    for(auto it = std::next(args.begin()); it != args.end(); ++it)
    {
        curr_max = add_broadcastable_binary_op(curr_max, *it, op::max{});
    }
    return curr_max;
}
// Parse ONNX "Min": variadic element-wise minimum with broadcasting,
// lowered as a left fold of broadcastable binary min ops.
//
// @param args  one or more input instructions
// @return      instruction computing the element-wise min of all inputs
instruction_ref
parse_min(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
    // NOTE(review): empty args would make front() undefined behavior;
    // presumably the ONNX importer guarantees >= 1 input — confirm upstream.
    auto curr_min = args.front();
    // The loop body never runs for a single input, so no explicit
    // size() > 1 guard is needed.
    for(auto it = std::next(args.begin()); it != args.end(); ++it)
    {
        curr_min = add_broadcastable_binary_op(curr_min, *it, op::min{});
    }
    return curr_min;
}
instruction_ref instruction_ref
parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{ {
......
...@@ -564,6 +564,25 @@ struct div_op ...@@ -564,6 +564,25 @@ struct div_op
} }
}; };
// CPU kernel descriptor for element-wise max: supplies the op name and a
// generic binary functor consumed by cpu_binary.
struct max_op
{
    std::string name() const { return "max"; }

    // Functor returning the larger of its two arguments.
    auto fcn() const
    {
        return [](auto a, auto b) { return std::max(a, b); };
    }
};
// CPU kernel descriptor for element-wise min: supplies the op name and a
// generic binary functor consumed by cpu_binary.
struct min_op
{
    std::string name() const { return "min"; }

    // Functor returning the smaller of its two arguments.
    auto fcn() const
    {
        return [](auto a, auto b) { return std::min(a, b); };
    }
};
template <typename Op> template <typename Op>
struct cpu_binary struct cpu_binary
{ {
...@@ -633,6 +652,8 @@ struct cpu_apply ...@@ -633,6 +652,8 @@ struct cpu_apply
apply_map["sub"] = simple_op<cpu_binary<sub_op>>(); apply_map["sub"] = simple_op<cpu_binary<sub_op>>();
apply_map["mul"] = simple_op<cpu_binary<mul_op>>(); apply_map["mul"] = simple_op<cpu_binary<mul_op>>();
apply_map["div"] = simple_op<cpu_binary<div_op>>(); apply_map["div"] = simple_op<cpu_binary<div_op>>();
apply_map["max"] = simple_op<cpu_binary<max_op>>();
apply_map["min"] = simple_op<cpu_binary<min_op>>();
apply_map["softmax"] = simple_op<softmax2d>(); apply_map["softmax"] = simple_op<softmax2d>();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment