Unverified Commit 8698cd2c authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

fix shape tests for broadcast op (#698)



* change transpose func

* formatting

* fix tf file

* add tests, change broadcast

* formatting

* revert if statement

* add nonzero axis test

* formatting

* remove test and add test file

* fix test

* formatting

* add test for more coverage

* change error message

* change error message
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent d13e052a
......@@ -46,25 +46,21 @@ struct broadcast
MIGRAPHX_THROW("BROADCAST : axis is out of range");
}
if(std::all_of(
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
if(broadcast_lens.size() - axis < input.lens().size())
{
return {t, broadcast_lens, std::move(bcast_strides)};
MIGRAPHX_THROW("BROADCAST: (broadcast ndims - axis) is less than input ndims");
}
else
{
if(broadcast_lens.size() - axis < input.lens().size())
{
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
}
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
}
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)};
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
{
MIGRAPHX_THROW("BROADCAST: when broadcasting, succeeding sizes must match");
}
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < input.elements())
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to input size");
return output;
}
argument compute(shape output_shape, std::vector<argument> args) const
{
......
......@@ -392,7 +392,7 @@ TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 0}, {"dims", lens}}),
input);
......@@ -400,10 +400,14 @@ TEST_CASE(broadcast)
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", lens}}),
input);
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
}
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment