Commit bb7c3a25 authored by charlie's avatar charlie
Browse files

Add op_shape tests

parent 2929d51c
......@@ -82,7 +82,7 @@ std::vector<std::size_t> compute_broadcasted_opt_lens(std::vector<std::size_t> s
{
return a;
}
else if(a == 1 or b == 1)
else if((a == 1 or b == 1) and a != 0 and b != 0)
{
return std::max(a, b);
}
......
......@@ -59,7 +59,7 @@ struct multibroadcast
auto t = inputs.at(0).type();
auto s0 = inputs.at(0);
if(s0.lens().empty())
if(s0.max_lens().empty())
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
......
......@@ -1127,47 +1127,95 @@ TEST_CASE(multibroadcast)
TEST_CASE(multibroadcast_2in)
{
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, a},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
// weirdness begins
{
// dynamic_dimensions must be the same or one is {1, 1, 0}
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {2, 4, 0}, {1, 1, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, a},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
migraphx::shape b_shape{migraphx::shape::float_type, {1, 6, 2}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 6, 0}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {10, 3, 8}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 0}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{10, 10, 0}, {3, 4, 0}, {8, 8, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
}
// opt handling
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 2}, {6, 6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {6, 6, 6}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 3}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 1}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3, 3}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
// dyn-dyn not handled
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
}
// both inputs are fixed
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment