Commit 1dc97206 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add some shape test cases to have better coverage

parent f7154cce
...@@ -47,14 +47,14 @@ struct broadcast ...@@ -47,14 +47,14 @@ struct broadcast
broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; })) broadcast_lens.cbegin(), broadcast_lens.cend(), [&](auto x) { return x == 1; }))
{ {
if(axis != 0) if(axis != 0)
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0"); MIGRAPHX_THROW("BROADCAST: when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_lens, std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
else else
{ {
assert(broadcast_lens.size() - axis >= input.lens().size()); assert(broadcast_lens.size() - axis >= input.lens().size());
if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis)) if(!std::equal(input.lens().begin(), input.lens().end(), broadcast_lens.begin() + axis))
MIGRAPHX_THROW("when broadcasting success sizes must match"); MIGRAPHX_THROW("BROADCAST: when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis); std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_lens, std::move(bcast_strides)}; return {t, broadcast_lens, std::move(bcast_strides)};
} }
......
...@@ -229,6 +229,38 @@ TEST_CASE(multibroadcast) ...@@ -229,6 +229,38 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(broadcast)
{
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}, {0, 0}},
migraphx::op::broadcast{0, lens},
input);
}
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {4, 1, 3}};
throws_shape(migraphx::op::broadcast{1, lens},
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 3}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 2, 4, 3}, {0, 0, 3, 1}},
migraphx::op::broadcast{2, lens},
input);
}
{
std::vector<std::size_t> lens{3, 2, 4, 3};
migraphx::shape input{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::op::broadcast{2, lens},
input);
}
}
TEST_CASE(gather) TEST_CASE(gather)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment