Unverified Commit 8698cd2c authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

fix shape tests for broadcast op (#698)



* change transpose func

* formatting

* fix tf file

* add tests, change broadcast

* formatting

* revert if statement

* add nonzero axis test

* formatting

* remove test and add test file

* fix test

* formatting

* add test for more coverage

* change error message

* change error message
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent d13e052a
...@@ -46,25 +46,21 @@ struct broadcast ...@@ -46,25 +46,21 @@ struct broadcast
MIGRAPHX_THROW("BROADCAST : axis is out of range"); MIGRAPHX_THROW("BROADCAST : axis is out of range");
} }
if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
{
return {t, broadcast_lens, std::move(bcast_strides)};
}
else
{
if(broadcast_lens.size() - axis < input.lens().size()) if(broadcast_lens.size() - axis < input.lens().size())
{ {
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match"); MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
} }
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis)) if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{ {
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match"); MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
} }
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)};
} shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
......
...@@ -392,7 +392,7 @@ TEST_CASE(broadcast) ...@@ -392,7 +392,7 @@ TEST_CASE(broadcast)
{ {
{ {
std::vector<std::size_t> lens{1, 1}; std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}}; migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}), migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input); input);
...@@ -400,10 +400,14 @@ TEST_CASE(broadcast) ...@@ -400,10 +400,14 @@ TEST_CASE(broadcast)
{ {
std::vector<std::size_t> lens{1, 1}; std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}}; migraphx::shape input{migraphx::shape::float_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}}, throws_shape(migraphx::op::broadcast{1, lens}, input);
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", lens}}), }
input);
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
} }
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment