Commit bb7c3a25 authored by charlie's avatar charlie
Browse files

Add op_shape tests

parent 2929d51c
...@@ -82,7 +82,7 @@ std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s ...@@ -82,7 +82,7 @@ std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s
{ {
return a; return a;
} }
else if(a == 1 or b == 1) else if((a == 1 or b == 1) and a != 0 and b != 0)
{ {
return std::max(a, b); return std::max(a, b);
} }
......
...@@ -59,7 +59,7 @@ struct multibroadcast ...@@ -59,7 +59,7 @@ struct multibroadcast
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
if(s0.lens().empty()) if(s0.max_lens().empty())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0"); MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
} }
......
...@@ -1127,47 +1127,95 @@ TEST_CASE(multibroadcast) ...@@ -1127,47 +1127,95 @@ TEST_CASE(multibroadcast)
TEST_CASE(multibroadcast_2in) TEST_CASE(multibroadcast_2in)
{ {
{ {
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a}; migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, a}, expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
} }
// weirdness begins
{ {
// dynamic_dimensions must be the same or one is {1, 1, 0} migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {2, 4, 0}, {1, 1, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, a}, expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"), migraphx::make_op("multibroadcast"),
a_shape, a_shape,
b_shape); b_shape);
} }
{ {
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape a_shape{migraphx::shape::float_type, a}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, {1, 6, 2}}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape( throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 6, 0}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
} }
{ {
migraphx::shape a_shape{migraphx::shape::float_type, {10, 3, 8}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape( throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
migraphx::shape{migraphx::shape::float_type, {{10, 10, 0}, {3, 4, 0}, {8, 8, 0}}}, }
migraphx::make_op("multibroadcast"),
a_shape, // opt handling
b_shape); {
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 2}, {6, 6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 6}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 3}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 1}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
} }
// dyn-dyn not handled
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
}
// both inputs are fixed // both inputs are fixed
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment