Commit 6acbd4e4 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_unsqueeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_squeeze

parents d229d3e1 fd2921b5
......@@ -81,7 +81,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
}
auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform(
s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(),
......@@ -92,7 +92,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{
return a;
}
else if(contains(one_dyn_dims, a) or contains(one_dyn_dims, b))
else if(a == one_dyn_dim or b == one_dyn_dim)
{
// setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
......@@ -70,7 +70,8 @@ struct broadcast
// value of axis anymore
if(axis >= broadcast_lens.size())
{
MIGRAPHX_THROW("BROADCAST : axis is out of range");
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
" is out of range");
}
if(broadcast_lens.size() - axis < s0.lens().size())
{
......@@ -107,21 +108,28 @@ struct broadcast
}
if(axis >= s1.ndim())
{
MIGRAPHX_THROW("BROADCAST_2in: axis is out of range");
MIGRAPHX_THROW("BROADCAST_2in: axis " + migraphx::to_string(axis) +
" is out of range");
}
if(s1.dynamic())
{
s0 = s0.to_dynamic();
if(s0.dyn_dims()[0] != s1.dyn_dims()[axis])
{
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length");
"dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
return s1;
}
if(s0.lens()[0] != s1.lens()[axis])
{
MIGRAPHX_THROW(
"BROADCAST_2in: s0 length doesn't match with static s1 axis dimension length");
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")");
}
std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
......
......@@ -507,8 +507,8 @@ bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
// don't check opt if both are fixed
bool check_opt = not(x.is_fixed() and y.is_fixed());
return (x.min == y.min and x.max == y.max and (check_opt ? x.opt == y.opt : true));
return (x.min == y.min and x.max == y.max and
((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......
......@@ -145,8 +145,7 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
const std::size_t min_block_size = 64;
const std::size_t base_block_size = 32;
auto block_size = (((n - 1) / base_block_size + 1)) * base_block_size;
auto block_size = (((n - 1) / min_block_size + 1)) * min_block_size;
return std::min(std::max(min_block_size, block_size), max_block_size);
}
......
......@@ -1177,25 +1177,23 @@ TEST_CASE(multibroadcast)
}
}
TEST_CASE(multibroadcast_2in)
TEST_CASE(multibroadcast_2in_static_dyn0)
{
// static-dyn
{
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_dyn1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
......@@ -1207,65 +1205,78 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_dyn_error0)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_dyn_error1)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_dyn_error2)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
}
// dyn-dyn
{
TEST_CASE(multibroadcast_2in_dyn_dyn0)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
{
// max doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_dyn_dyn_error1)
{
// opt doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
}
// static-static
{
TEST_CASE(multibroadcast_2in_static_static0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
......@@ -1276,8 +1287,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_static1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {0, 1}},
......@@ -1288,8 +1301,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_static2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {0, 0, 1}},
......@@ -1300,8 +1315,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_static3)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {16, 4, 1}},
......@@ -1312,8 +1329,10 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_static4)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {4, 0, 1}},
......@@ -1324,13 +1343,14 @@ TEST_CASE(multibroadcast_2in)
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
}
{
}
TEST_CASE(multibroadcast_2in_static_static_error0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
}
}
TEST_CASE(multinomial)
......@@ -2055,6 +2075,17 @@ TEST_CASE(test_unsqueeze_dyn)
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
}
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
TEST_CASE(test_unsqueeze_step)
{
migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
......
......@@ -1930,7 +1930,7 @@ TEST_CASE(div_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(div_dynamic_test)
TEST_CASE(div_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
......
......@@ -185,6 +185,31 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_ndim_static)
{
migraphx::shape s0{migraphx::shape::float_type, {2, 2}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 4, 4}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type, {2, 4, 4, 1, 3}};
EXPECT(s1.ndim() == 5);
}
TEST_CASE(test_shape_ndim_dyn)
{
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type,
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
EXPECT(s1.ndim() == 5);
}
TEST_CASE(test_shape_non_packed_single_dim)
{
migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}};
......@@ -212,6 +237,21 @@ TEST_CASE(test_shape_transposed2)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_static_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1);
}
TEST_CASE(test_shape_overlap)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
......
......@@ -51,7 +51,7 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap
template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
struct test_reduce_mean : verify_program<test_reduce_mean>
struct test_reduce_mean_1 : verify_program<test_reduce_mean_1>
{
migraphx::program create_program() const
{
......@@ -63,3 +63,16 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
return p;
};
};
struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {336, 400}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::op::reduce_mean{{1}}, x);
return p;
};
};
......@@ -21,5 +21,5 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
numpy==1.18.5
numpy==1.21.6
onnxruntime==1.10.0
......@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX"
rbuild prepare -d $PREFIX -s develop
# install onnx package for unit tests
pip3 install onnx==1.8.1 numpy==1.18.5 typing==3.7.4 pytest==6.0.1 packaging==16.8
pip3 install onnx==1.8.1 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==16.8
# pin version of protobuf in Python for onnx runtime unit tests
pip3 install protobuf==3.20.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment