Commit 8fce4170 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added more tests plus ONNX parsing

parent 26c33a16
...@@ -776,14 +776,14 @@ struct multibroadcast ...@@ -776,14 +776,14 @@ struct multibroadcast
MIGRAPH_THROW("inputs dimensions should <= output size"); MIGRAPH_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0); std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto extra = output_lens.size()-input.lens().size(); auto offset = output_lens.size()-input.lens().size();
if (input.lens().size() < output_lens.size()) if (input.lens().size() < output_lens.size())
{ {
for (std::size_t i = output_lens.size()-1; i > 0; i--) for (std::size_t i = output_lens.size()-1; i > 0; i--)
{ {
if (output_lens[i] == input.lens()[i-extra]) if (output_lens[i] == input.lens()[i-offset])
{ {
bcast_strides[i] = input.strides()[i-extra]; bcast_strides[i] = input.strides()[i-offset];
} }
} }
} }
......
...@@ -48,13 +48,14 @@ struct onnx_parser ...@@ -48,13 +48,14 @@ struct onnx_parser
onnx_parser() onnx_parser()
{ {
add_generic_op("Add", op::add{});
add_generic_op("Div", op::div{});
add_generic_op("MatMul", op::dot{}); add_generic_op("MatMul", op::dot{});
add_generic_op("Mul", op::mul{});
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", op::relu{});
add_generic_op("Sub", op::sub{});
add_generic_op("Sum", op::add{}); add_broadcastable_binary_op("Add", op::add{});
add_broadcastable_binary_op("Div", op::div{});
add_broadcastable_binary_op("Mul", op::mul{});
add_broadcastable_binary_op("Sub", op::sub{});
add_broadcastable_binary_op("Sum", op::add{});
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
...@@ -88,6 +89,70 @@ struct onnx_parser ...@@ -88,6 +89,70 @@ struct onnx_parser
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
template <class T>
void add_broadcastable_binary_op(std::string name, T x)
{
    // Registers a binary ONNX op that supports both broadcasting styles:
    //  * legacy explicit broadcast via the "broadcast"/"axis" attributes
    //    (pre-opset-7 semantics), lowered to op::broadcast;
    //  * numpy-style multidirectional broadcasting when the operand shapes
    //    differ, lowered to op::multibroadcast on both operands.
    ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
        if(args.size() != 2)
            MIGRAPH_THROW("binary operators should have 2 operands"); // fixed typo: "binaGry"
        if(contains(attributes, "broadcast"))
        {
            uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
            if(broadcasted != 0)
            {
                // Broadcast the second operand along "axis" (default 0) to
                // the first operand's shape, then apply the op.
                uint64_t axis = (contains(attributes, "axis"))
                                    ? parse_value(attributes.at("axis")).at<uint64_t>()
                                    : 0;
                auto l =
                    prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
                return prog.add_instruction(x, args[0], l);
            }
            // "broadcast" == 0: the model explicitly disabled broadcasting;
            // fall through and apply the op directly.
        }
        else if(args[0]->get_shape().lens() != args[1]->get_shape().lens())
        {
            // Numpy-style broadcasting: align shapes at the trailing
            // dimension and take the per-dimension maximum.
            //
            // Example: s0 = (3,2,4,5), s1 = (2,1,1) -> output_lens = (3,2,4,5)
            // Example: s0 = (3,2,1,5), s1 = (2,7,5) -> output_lens = (3,2,7,5)
            const std::vector<std::size_t>& s0 = args[0]->get_shape().lens();
            const std::vector<std::size_t>& s1 = args[1]->get_shape().lens();
            // Start from the higher-rank shape, then widen each aligned
            // dimension by the lower-rank shape.
            std::vector<std::size_t> output_lens = (s0.size() >= s1.size()) ? s0 : s1;
            const std::vector<std::size_t>& smaller = (s0.size() >= s1.size()) ? s1 : s0;
            std::size_t offset = output_lens.size() - smaller.size();
            for(std::size_t i = 0; i < smaller.size(); i++)
            {
                output_lens[i + offset] = std::max(output_lens[i + offset], smaller[i]);
            }
            // Bug fix: the original computed output_lens but never used it,
            // so operands with mismatched shapes were passed straight to the
            // op. Broadcast both operands to the common shape first.
            auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
            auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
            return prog.add_instruction(x, l0, l1);
        }
        return prog.add_instruction(x, args);
    });
}
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
...@@ -591,8 +656,10 @@ struct onnx_parser ...@@ -591,8 +656,10 @@ struct onnx_parser
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform( std::transform(tensor_dims.begin(),
tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) -> std::size_t { tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) -> std::size_t {
if(not d.has_dim_value()) if(not d.has_dim_value())
{ {
long default_batch_size = 1; // FIXME long default_batch_size = 1; // FIXME
......
...@@ -487,6 +487,7 @@ void broadcast_test() ...@@ -487,6 +487,7 @@ void broadcast_test()
} }
void add_broadcast_test() void add_broadcast_test()
{ {
{
migraph::program p; migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}}; migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
...@@ -504,6 +505,26 @@ void add_broadcast_test() ...@@ -504,6 +505,26 @@ void add_broadcast_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
}
{
// Elementwise add where the second operand is broadcast along its trailing
// axis of size 1 using explicit multibroadcast instructions.
migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
std::vector<float> a_data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
// b's last dimension is 1, so each b value is repeated 3 times on broadcast.
migraph::shape b_shape{migraph::shape::float_type, {2, 2, 1}};
std::vector<float> b_data{0, -1, -2, -3};
auto l1 = p.add_literal(migraph::literal{a_shape, a_data});
auto l2 = p.add_literal(migraph::literal{b_shape, b_data});
// Broadcast both operands to the common shape {2, 2, 3}; l3 is a shape-
// preserving broadcast of a, l4 repeats b along the last axis.
auto l3 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l1);
auto l4 = p.add_instruction(migraph::op::multibroadcast{{2, 2, 3}}, l2);
p.add_instruction(migraph::op::add{}, l3, l4);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
// The broadcasted add must produce a packed (dense) result shape.
EXPECT(result.get_shape().packed());
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
// gold = a + b broadcast along last dim, e.g. 3 + (-1) = 2, 4 + (-1) = 3, ...
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraph::verify_range(results_vector, gold));
}
} }
void sub_test() void sub_test()
......
...@@ -149,30 +149,54 @@ void slice_shape() ...@@ -149,30 +149,54 @@ void slice_shape()
void multibroadcast_shape() void multibroadcast_shape()
{ {
{ {
std::vector<std::size_t> lens{4,2,5,3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2,1,3}}; migraph::shape input{migraph::shape::float_type, {2, 1, 3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0,3,0,1}}, expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 3, 0, 1}},
migraph::op::multibroadcast{lens}, input); migraph::op::multibroadcast{lens},
input);
}
{
// Rank-3 input into a rank-4 output: only the dim of size 2 matches an
// output dim, so it alone keeps its (non-zero) stride; broadcast dims get 0.
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 1, 0, 0}},
migraph::op::multibroadcast{lens},
input);
}
{
// Rank-2 input {5, 1}: only the 5 matches output dim 2, so stride 1 there.
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {5, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 1, 0}},
migraph::op::multibroadcast{lens},
input);
}
{
// Equal-rank input {4, 1, 1, 1}: only the leading dim of size 4 matches,
// keeping stride 1 (the equal-rank path is outside this diff — TODO confirm).
std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 0, 0, 0}},
migraph::op::multibroadcast{lens},
input);
}
{ {
std::vector<std::size_t> lens{4,2,5,3}; std::vector<std::size_t> lens{4, 2, 5, 3};
migraph::shape input{migraph::shape::float_type, {2,1,1}}; migraph::shape input{migraph::shape::float_type, {3}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {0,1,0,0}}, expect_shape(migraph::shape{migraph::shape::float_type, lens, {0, 0, 0, 1}},
migraph::op::multibroadcast{lens}, input); migraph::op::multibroadcast{lens},
input);
} }
{ {
std::vector<std::size_t> lens{4,1,1,3}; std::vector<std::size_t> lens{4, 1, 1, 3};
migraph::shape input{migraph::shape::float_type, {4,1,1,1}}; migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
expect_shape(migraph::shape{migraph::shape::float_type, lens, {1,1,1,0}}, expect_shape(migraph::shape{migraph::shape::float_type, lens, {1, 1, 1, 0}},
migraph::op::multibroadcast{lens}, input); migraph::op::multibroadcast{lens},
input);
} }
{ {
std::vector<std::size_t> lens{4,1,3}; std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {4,1,1,1}}; migraph::shape input{migraph::shape::float_type, {4, 1, 1, 1}};
throws_shape(migraph::op::multibroadcast{lens}, input); throws_shape(migraph::op::multibroadcast{lens}, input);
} }
{ {
std::vector<std::size_t> lens{4,1,3}; std::vector<std::size_t> lens{4, 1, 3};
migraph::shape input{migraph::shape::float_type, {}}; migraph::shape input{migraph::shape::float_type, {}};
throws_shape(migraph::op::multibroadcast{lens}, input); throws_shape(migraph::op::multibroadcast{lens}, input);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment