Commit 6acbd4e4 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_unsqueeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_squeeze

parents d229d3e1 fd2921b5
...@@ -81,7 +81,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -81,7 +81,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}}; shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform( std::transform(
s0.dyn_dims().cbegin(), s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(), s0.dyn_dims().cend(),
...@@ -92,7 +92,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -92,7 +92,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{ {
return a; return a;
} }
else if(contains(one_dyn_dims, a) or contains(one_dyn_dims, b)) else if(a == one_dyn_dim or b == one_dyn_dim)
{ {
// setting opt to 0, may need to be changed // setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0}; return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
...@@ -70,7 +70,8 @@ struct broadcast ...@@ -70,7 +70,8 @@ struct broadcast
// value of axis anymore // value of axis anymore
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{ {
MIGRAPHX_THROW("BROADCAST : axis is out of range"); MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
} }
if(broadcast_lens.size() - axis < s0.lens().size()) if(broadcast_lens.size() - axis < s0.lens().size())
{ {
...@@ -107,21 +108,28 @@ struct broadcast ...@@ -107,21 +108,28 @@ struct broadcast
} }
if(axis >= s1.ndim()) if(axis >= s1.ndim())
{ {
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range"); MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
" is out of range");
} }
if(s1.dynamic()) if(s1.dynamic())
{ {
s0 = s0.to_dynamic(); s0 = s0.to_dynamic();
if(s0.dyn_dims()[0] != s1.dyn_dims()[axis]) if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis " MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length"); "dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
return s1; return s1;
} }
if(s0.lens()[0] != s1.lens()[axis]) if(s0.lens()[0] != s1.lens()[axis])
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"BROADCAST_2in: s0 length doesn't match with static s1 axis dimension length"); "dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
} }
std::vector<size_t> bcast_strides(s1.ndim(), 0); std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
......
...@@ -507,8 +507,8 @@ bool shape::dynamic_dimension::has_optimal() const { return opt != 0; } ...@@ -507,8 +507,8 @@ bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{ {
// don't check opt if both are fixed // don't check opt if both are fixed
bool check_opt = not(x.is_fixed() and y.is_fixed()); return (x.min == y.min and x.max == y.max and
return (x.min == y.min and x.max == y.max and (check_opt ? x.opt == y.opt : true)); ((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
} }
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y) bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......
...@@ -144,9 +144,8 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over) ...@@ -144,9 +144,8 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size) std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{ {
const std::size_t min_block_size = 64; const std::size_t min_block_size = 64;
const std::size_t base_block_size = 32; auto block_size = (((n - 1) / min_block_size + 1)) * min_block_size;
auto block_size = (((n - 1) / base_block_size + 1)) * base_block_size;
return std::min(std::max(min_block_size, block_size), max_block_size); return std::min(std::max(min_block_size, block_size), max_block_size);
} }
......
...@@ -1177,160 +1177,180 @@ TEST_CASE(multibroadcast) ...@@ -1177,160 +1177,180 @@ TEST_CASE(multibroadcast)
} }
} }
TEST_CASE(multibroadcast_2in) TEST_CASE(multibroadcast_2in_static_dyn0)
{ {
// static-dyn migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
{ std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, b};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}}; expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::shape b_shape{migraphx::shape::float_type, b}; migraphx::make_op("multibroadcast"),
expect_shape( a_shape,
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}}, b_shape);
migraphx::make_op("multibroadcast"), expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
a_shape, migraphx::make_op("multibroadcast"),
b_shape); b_shape,
expect_shape( a_shape);
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}}, }
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
// dyn-dyn TEST_CASE(multibroadcast_2in_static_dyn1)
{ {
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
migraphx::shape a_shape{migraphx::shape::float_type, a}; std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}}; migraphx::shape b_shape{migraphx::shape::float_type, b};
migraphx::shape b_shape{migraphx::shape::float_type, b}; expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape( migraphx::make_op("multibroadcast"),
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}}, a_shape,
migraphx::make_op("multibroadcast"), b_shape);
a_shape, expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
b_shape); migraphx::make_op("multibroadcast"),
expect_shape( b_shape,
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}}, a_shape);
migraphx::make_op("multibroadcast"), }
b_shape,
a_shape); TEST_CASE(multibroadcast_2in_static_dyn_error0)
} {
{ // doesn't match on first dimension
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape a_shape{migraphx::shape::float_type, a}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}}; migraphx::shape b_shape{migraphx::shape::float_type, b};
migraphx::shape b_shape{migraphx::shape::float_type, b}; throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape); }
}
{ TEST_CASE(multibroadcast_2in_static_dyn_error1)
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}; {
migraphx::shape a_shape{migraphx::shape::float_type, a}; // doesn't match on first dimension
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b}; std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape); throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
} throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
// static-static
{ TEST_CASE(multibroadcast_2in_static_dyn_error2)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}}; {
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}}; // doesn't match on first dimension
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}}, migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::make_op("multibroadcast"), std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
a_shape, migraphx::shape b_shape{migraphx::shape::float_type, b};
b_shape); throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}}, throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
migraphx::make_op("multibroadcast"), }
b_shape,
a_shape); TEST_CASE(multibroadcast_2in_dyn_dyn0)
} {
{ std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}}; migraphx::shape a_shape{migraphx::shape::float_type, a};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 8}}; std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {0, 1}}, migraphx::shape b_shape{migraphx::shape::float_type, b};
migraphx::make_op("multibroadcast"), expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
a_shape, migraphx::make_op("multibroadcast"),
b_shape); a_shape,
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {8, 1}}, b_shape);
migraphx::make_op("multibroadcast"), expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
b_shape, migraphx::make_op("multibroadcast"),
a_shape); b_shape,
} a_shape);
{ }
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4, 1}}; TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {0, 0, 1}}, {
migraphx::make_op("multibroadcast"), // max doesn't match on second dimension of a
a_shape, std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
b_shape); migraphx::shape a_shape{migraphx::shape::float_type, a};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {4, 1, 0}}, std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
migraphx::make_op("multibroadcast"), migraphx::shape b_shape{migraphx::shape::float_type, b};
b_shape, throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
a_shape); throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
} }
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}}; TEST_CASE(multibroadcast_2in_dyn_dyn_error1)
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}}; {
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {16, 4, 1}}, // opt doesn't match on second dimension of a
migraphx::make_op("multibroadcast"), std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
a_shape, migraphx::shape a_shape{migraphx::shape::float_type, a};
b_shape); std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}}, migraphx::shape b_shape{migraphx::shape::float_type, b};
migraphx::make_op("multibroadcast"), throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
b_shape, throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
a_shape); }
}
{ TEST_CASE(multibroadcast_2in_static_static0)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}}; {
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {4, 0, 1}}, migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}};
migraphx::make_op("multibroadcast"), expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
a_shape, migraphx::make_op("multibroadcast"),
b_shape); a_shape,
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}}, b_shape);
migraphx::make_op("multibroadcast"), expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
b_shape, migraphx::make_op("multibroadcast"),
a_shape); b_shape,
} a_shape);
{ }
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}}; TEST_CASE(multibroadcast_2in_static_static1)
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape); {
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape); migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}};
} migraphx::shape b_shape{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {8, 1}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {0, 0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {4, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static3)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {16, 4, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static4)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {4, 0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {0, 1, 0}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
TEST_CASE(multibroadcast_2in_static_static_error0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
} }
TEST_CASE(multinomial) TEST_CASE(multinomial)
...@@ -2055,6 +2075,17 @@ TEST_CASE(test_unsqueeze_dyn) ...@@ -2055,6 +2075,17 @@ TEST_CASE(test_unsqueeze_dyn)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1); throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
} }
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
TEST_CASE(test_unsqueeze_step) TEST_CASE(test_unsqueeze_step)
{ {
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}}; migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
......
...@@ -1930,7 +1930,7 @@ TEST_CASE(div_test) ...@@ -1930,7 +1930,7 @@ TEST_CASE(div_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(div_dynamic_test) TEST_CASE(div_dyn_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
......
...@@ -185,6 +185,31 @@ TEST_CASE(test_shape_packed) ...@@ -185,6 +185,31 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_ndim_static)
{
migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4, 4}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type, {2, 4, 4, 1, 3}};
EXPECT(s1.ndim() == 5);
}
TEST_CASE(test_shape_ndim_dyn)
{
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type,
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
EXPECT(s1.ndim() == 5);
}
TEST_CASE(test_shape_non_packed_single_dim) TEST_CASE(test_shape_non_packed_single_dim)
{ {
migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}}; migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}};
...@@ -212,6 +237,21 @@ TEST_CASE(test_shape_transposed2) ...@@ -212,6 +237,21 @@ TEST_CASE(test_shape_transposed2)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_static_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1);
}
TEST_CASE(test_shape_overlap) TEST_CASE(test_shape_overlap)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
......
...@@ -51,7 +51,7 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap ...@@ -51,7 +51,7 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap
template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
struct test_reduce_mean : verify_program<test_reduce_mean> struct test_reduce_mean_1 : verify_program<test_reduce_mean_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -63,3 +63,16 @@ struct test_reduce_mean : verify_program<test_reduce_mean> ...@@ -63,3 +63,16 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
return p; return p;
}; };
}; };
struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {336, 400}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
...@@ -21,5 +21,5 @@ ...@@ -21,5 +21,5 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
numpy==1.18.5 numpy==1.21.6
onnxruntime==1.10.0 onnxruntime==1.10.0
...@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX" ...@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX"
rbuild prepare -d $PREFIX -s develop rbuild prepare -d $PREFIX -s develop
# install onnx package for unit tests # install onnx package for unit tests
pip3 install onnx==1.8.1 numpy==1.18.5 typing==3.7.4 pytest==6.0.1 packaging==16.8 pip3 install onnx==1.8.1 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==16.8
# pin version of protobuf in Python for onnx runtime unit tests # pin version of protobuf in Python for onnx runtime unit tests
pip3 install protobuf==3.20.0 pip3 install protobuf==3.20.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment