"docs/archive_en_US/Tuner/BatchTuner.md" did not exist on "b7366b685afdde156e551f8ba5008857f789e368"
Unverified Commit e7ec374f authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Refactor dynamic_dimension to have multiple optimals (#1625)

Makes the optimals into a std::set<std::size_t>
Changes shape object functions to handle the opts change
Changes to convolution, flatten, pooling, and convolution in that they no longer calculate the output optimal dimensions. Instead returns empty opts. Will need to change this in the future if we want to support dynamic shapes fully.
Many changes to tests and shape calls with respect to the new optimals
parent 1329b9be
......@@ -121,28 +121,24 @@ TEST_CASE(argmax_axis_outofbounds)
TEST_CASE(argmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {1, 1, 0}, {4, 4, 0}, {5, 5, 0}}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {{1, 4}, {1, 1}, {4, 4}, {5, 5}}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
TEST_CASE(argmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::int64_type, {{1, 4, 0}, {3, 3, 0}, {1, 1, 0}, {4, 6, 0}}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 6}, {4, 6}}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {{1, 4}, {3, 3}, {1, 1}, {4, 6}}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
TEST_CASE(binary_dyn_static_error)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1, 0}, {4, 4, 4}, {4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 1}, {4, 4, {4}}, {4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("add"), a_shape, b_shape);
}
......@@ -216,13 +212,13 @@ TEST_CASE(broadcast_2in_not_matching_error)
TEST_CASE(broadcast_2in_dynamic_s0_error1)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
}
TEST_CASE(broadcast_2in_dynamic_s0_error2)
{
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4}};
migraphx::shape a_input{migraphx::shape::float_type, dd};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
......@@ -231,9 +227,9 @@ TEST_CASE(broadcast_2in_dynamic_s0_error2)
TEST_CASE(broadcast_2in_static_dyn)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
......@@ -243,11 +239,11 @@ TEST_CASE(broadcast_2in_static_dyn)
TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {2, 2}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(convolution_shape)
TEST_CASE(conv_2d_0)
{
migraphx::shape output{migraphx::shape::float_type, {4, 4, 1, 1}};
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
......@@ -257,13 +253,19 @@ TEST_CASE(convolution_shape)
throws_shape(
migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input);
}
TEST_CASE(conv_2d_1)
{
migraphx::shape weights{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::shape input2{migraphx::shape::float_type, {3, 3}};
migraphx::shape weights2{migraphx::shape::float_type, {3, 3}};
throws_shape(migraphx::make_op("convolution"), input2, weights2);
throws_shape(migraphx::make_op("convolution"), input2, weights);
}
// 1D convolution
TEST_CASE(conv_1d)
{
migraphx::shape output_1d{migraphx::shape::float_type, {4, 4, 1}};
migraphx::shape input_1d{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape weights_1d{migraphx::shape::float_type, {4, 3, 3}};
......@@ -272,12 +274,17 @@ TEST_CASE(convolution_shape)
migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
input_1d,
weights_1d);
}
// channel numbers mismatch
weights_1d = {migraphx::shape::float_type, {4, 8, 3}};
TEST_CASE(conv_channel_mismatch)
{
migraphx::shape input_1d{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape weights_1d = {migraphx::shape::float_type, {4, 8, 3}};
throws_shape(migraphx::make_op("convolution"), input_1d, weights_1d);
}
// 3D convolution
TEST_CASE(conv_3D)
{
migraphx::shape output_3d{migraphx::shape::float_type, {4, 4, 1, 1, 1}};
migraphx::shape input_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
migraphx::shape weights_3d{migraphx::shape::float_type, {4, 3, 3, 3, 3}};
......@@ -289,93 +296,82 @@ TEST_CASE(convolution_shape)
weights_3d);
throws_shape(migraphx::make_op("convolution"), input_3d, weights_3d);
}
// dynamic batch
TEST_CASE(conv_dyn_batch)
{
migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
{{1, 100}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape{migraphx::shape::float_type,
{{
1,
100,
0,
},
{1, 1, 0},
{3, 3, 0},
{3, 3, 0}}};
{{1, 100}, {1, 1}, {3, 3}, {3, 3}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// dynamic image
input_dyn_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 20, 0}, {5, 20, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
output_dyn_shape = {migraphx::shape::float_type,
{{
1,
1,
0,
},
{1, 1, 0},
{3, 18, 0},
{3, 18, 0}}};
TEST_CASE(conv_dyn_img)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {3, 18}, {3, 18}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// dynamic weights
input_dyn_shape = {migraphx::shape::float_type, {1, 3, 10, 10}};
weights_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}};
output_dyn_shape = {migraphx::shape::float_type,
{{
1,
1,
0,
},
{1, 1, 0},
{7, 9, 0},
{7, 9, 0}}};
TEST_CASE(conv_dyn_weights)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type, {1, 3, 10, 10}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {7, 9}, {7, 9}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// dynamic img and weights
input_dyn_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 20, 0}, {5, 20, 0}}};
weights_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}};
output_dyn_shape = {migraphx::shape::float_type,
{{
1,
1,
0,
},
{1, 1, 0},
{2, 19, 0},
{2, 19, 0}}};
TEST_CASE(conv_dyn_img_weights)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 20}, {5, 20}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {2, 19}, {2, 19}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
// input attr shape mismatch
input_dyn_shape = {migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}, {5, 5, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3, 3}};
TEST_CASE(conv_attr_shape_mismatch)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 100}, {3, 3}, {5, 5}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3, 3}};
throws_shape(migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
input_dyn_shape,
weights_shape);
}
TEST_CASE(conv_autopad_dyn_batch)
{
// auto_pad dynamic batch
input_dyn_shape = {migraphx::shape::float_type, {{1, 10, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
output_dyn_shape = {migraphx::shape::float_type, {{1, 10, 0}, {1, 1, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 10}, {1, 1}, {5, 5}, {5, 5}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"stride", {1, 1}},
......@@ -383,12 +379,16 @@ TEST_CASE(convolution_shape)
{"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
input_dyn_shape,
weights_shape);
}
TEST_CASE(conv_autopad_dyn_img)
{
// auto_pad dynamic img
input_dyn_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {5, 10, 0}, {5, 10, 0}}};
weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
output_dyn_shape = {migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {5, 10, 0}, {5, 10, 0}}};
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {5, 10}, {5, 10}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {1, 3, 3, 3}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {5, 10}, {5, 10}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"stride", {1, 1}},
......@@ -396,13 +396,15 @@ TEST_CASE(convolution_shape)
{"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
input_dyn_shape,
weights_shape);
}
// auto_pad dynamic kernel
input_dyn_shape = {migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {10, 10, 0}, {10, 10, 0}}};
weights_shape = {migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 4, 0}, {2, 4, 0}}};
output_dyn_shape = {migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {10, 10, 0}, {10, 10, 0}}};
TEST_CASE(conv_autopad_dyn_kernel)
{
migraphx::shape input_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {3, 3}, {10, 10}, {10, 10}}};
migraphx::shape weights_shape = {migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 4}, {2, 4}}};
migraphx::shape output_dyn_shape = {migraphx::shape::float_type,
{{1, 1}, {1, 1}, {10, 10}, {10, 10}}};
expect_shape(output_dyn_shape,
migraphx::make_op("convolution",
{{"stride", {1, 1}},
......@@ -425,7 +427,7 @@ TEST_CASE(contiguous_shape)
TEST_CASE(contiguous_dyn_shape)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 2}}};
migraphx::shape s0{migraphx::shape::float_type, {{1, 4}, {2, 2, {2}}}};
expect_shape(s0, migraphx::make_op("contiguous"), s0);
}
......@@ -618,9 +620,9 @@ TEST_CASE(dot_4D_test)
TEST_CASE(dot_dyn_static_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {8, 8, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
......@@ -628,16 +630,16 @@ TEST_CASE(dot_dyn_static_test0)
TEST_CASE(dot_dyn_static_mismatch_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5, 0}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5}, {6, 8, {8}}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {6, 8, {8}}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
......@@ -645,9 +647,9 @@ TEST_CASE(dot_dyn_dyn_test0)
TEST_CASE(dot_dyn_dyn_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {4, 5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, 5}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {4, 5, {5}}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, {5}}, {6, 8, {8}}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {6, 8, {8}}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
......@@ -655,14 +657,14 @@ TEST_CASE(dot_dyn_dyn_test1)
TEST_CASE(dot_dyn_mismatch_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_mismatch_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4, 0}, {5, 5, 0}, {2, 5, 0}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4}, {5, 5}, {2, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
......@@ -697,12 +699,11 @@ TEST_CASE(flatten_shape)
TEST_CASE(flatten_dyn_axis0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {192, 768, 0}}},
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {6, 6}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {192, 768}}},
migraphx::make_op("flatten", {{"axis", 0}}),
input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {192, 768, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {192, 768}}},
migraphx::make_op("flatten", {{"axis", -4}}),
input);
}
......@@ -710,13 +711,13 @@ TEST_CASE(flatten_dyn_axis0)
TEST_CASE(flatten_dyn_axis1)
{
migraphx::shape input{migraphx::shape::float_type,
{{2, 2, 2}, {4, 4, 0}, {4, 6, 5}, {4, 6, 5}}};
{{2, 2, {2}}, {4, 4}, {4, 6, {5}}, {4, 6, {5}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 2, 2}, {4 * 4 * 4, 4 * 6 * 6, 0}}},
migraphx::shape{migraphx::shape::float_type, {{2, 2, {2}}, {4 * 4 * 4, 4 * 6 * 6}}},
migraphx::make_op("flatten", {{"axis", 1}}),
input);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 2, 2}, {4 * 4 * 4, 4 * 6 * 6, 0}}},
migraphx::shape{migraphx::shape::float_type, {{2, 2, {2}}, {4 * 4 * 4, 4 * 6 * 6}}},
migraphx::make_op("flatten", {{"axis", -3}}),
input);
}
......@@ -724,31 +725,27 @@ TEST_CASE(flatten_dyn_axis1)
TEST_CASE(flatten_dyn_axis2)
{
migraphx::shape input{migraphx::shape::float_type,
{{2, 2, 2}, {4, 4, 0}, {4, 6, 5}, {4, 6, 5}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2 * 4, 2 * 4, 0}, {4 * 4, 6 * 6, 5 * 5}}},
migraphx::make_op("flatten", {{"axis", 2}}),
input);
{{2, 2, {2}}, {4, 4}, {4, 6, {5}}, {4, 6, {5}}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2 * 4, 2 * 4}, {4 * 4, 6 * 6}}},
migraphx::make_op("flatten", {{"axis", 2}}),
input);
}
TEST_CASE(flatten_dyn_axis3)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6, 4 * 4 * 6, 0}, {8, 8, 0}}},
migraphx::make_op("flatten", {{"axis", 3}}),
input);
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {6, 6}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6, 4 * 4 * 6}, {8, 8}}},
migraphx::make_op("flatten", {{"axis", 3}}),
input);
}
TEST_CASE(flatten_dyn_axis4)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {4, 4, 0}, {6, 6, 0}, {8, 8, 0}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1 * 4 * 6 * 8, 4 * 4 * 6 * 8, 0}, {1, 1, 0}}},
migraphx::make_op("flatten", {{"axis", 4}}),
input);
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {6, 6}, {8, 8}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1 * 4 * 6 * 8, 4 * 4 * 6 * 8}, {1, 1}}},
migraphx::make_op("flatten", {{"axis", 4}}),
input);
}
TEST_CASE(gather)
......@@ -842,11 +839,11 @@ TEST_CASE(gather_dyn0)
{
// Insert dynamic index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 7, 3}, {3, 3, 0}}};
{{2, 3, {2}}, {3, 4, {3}}, {6, 9, {7}}, {12, 14, {13}}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 7, {3}}, {3, 3}}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {2, 7, 3}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
{{2, 3, {2}}, {2, 7, {3}}, {3, 3}, {6, 9, {7}}, {12, 14, {13}}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -856,11 +853,11 @@ TEST_CASE(gather_dyn1)
{
// Insert static index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
{{2, 3, {2}}, {3, 4, {3}}, {6, 9, {7}}, {12, 14, {13}}}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {2, 2, 0}, {3, 3, 0}, {6, 9, 7}, {12, 14, 13}}},
{{2, 3, {2}}, {2, 2}, {3, 3}, {6, 9, {7}}, {12, 14, {13}}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -870,27 +867,28 @@ TEST_CASE(gather_dyn2)
{
// Insert scalar (static) index into dynamic shape
migraphx::shape input{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {6, 9, 7}, {12, 14, 13}}};
{{2, 3, {2}}, {3, 4, {3}}, {6, 9, {7}}, {12, 14, {13}}}};
std::vector<std::size_t> mins;
std::vector<std::size_t> maxes;
std::vector<std::size_t> opts;
std::vector<std::set<std::size_t>> opts;
migraphx::shape indices{migraphx::shape::int32_type, mins, maxes, opts};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2, 3, 2}, {6, 9, 7}, {12, 14, 13}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{2, 3, {2}}, {6, 9, {7}}, {12, 14, {13}}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
}
TEST_CASE(gather_dyn3)
{
// Insert dynamic index into static shape, axis 1
migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, {2}}, {3, 4, {3}}}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 2, 0}, {2, 3, 2}, {3, 4, 3}, {6, 6, 0}, {12, 12, 0}}},
{{2, 2}, {2, 3, {2}}, {3, 4, {3}}, {6, 6}, {12, 12}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -900,10 +898,10 @@ TEST_CASE(gather_dyn4)
{
// Insert dynamic index into static shape, axis 0
migraphx::shape input{migraphx::shape::float_type, {2, 3, 6, 12}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, 2}, {3, 4, 3}}};
migraphx::shape indices{migraphx::shape::int32_type, {{2, 3, {2}}, {3, 4, {3}}}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{2, 3, 2}, {3, 4, 3}, {3, 3, 0}, {6, 6, 0}, {12, 12, 0}}},
{{2, 3, {2}}, {3, 4, {3}}, {3, 3}, {6, 6}, {12, 12}}},
migraphx::make_op("gather", {{"axis", axis}}),
input,
indices);
......@@ -1439,13 +1437,13 @@ TEST_CASE(multibroadcast)
TEST_CASE(multibroadcast_2in_static_dyn0)
{
migraphx::shape a_shape{migraphx::shape::float_type, {4, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {4, 4, 4}, {4, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4}, {4, 4, {4}}, {4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
......@@ -1454,13 +1452,13 @@ TEST_CASE(multibroadcast_2in_static_dyn0)
TEST_CASE(multibroadcast_2in_static_dyn1)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
......@@ -1469,13 +1467,13 @@ TEST_CASE(multibroadcast_2in_static_dyn1)
TEST_CASE(multibroadcast_2in_static_dyn2)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8, 0}, {6, 6, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{8, 8}, {6, 6}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(b)}}),
b_shape,
a_shape);
......@@ -1485,7 +1483,7 @@ TEST_CASE(multibroadcast_2in_static_dyn_error0)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 3}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1495,7 +1493,7 @@ TEST_CASE(multibroadcast_2in_static_dyn_error1)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 4}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1505,7 +1503,7 @@ TEST_CASE(multibroadcast_2in_static_dyn_error2)
{
// doesn't match on first dimension
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2, 0}, {6, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{1, 2}, {6, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1513,15 +1511,15 @@ TEST_CASE(multibroadcast_2in_static_dyn_error2)
TEST_CASE(multibroadcast_2in_dyn_dyn0)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, {2}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast"),
b_shape,
a_shape);
......@@ -1529,15 +1527,15 @@ TEST_CASE(multibroadcast_2in_dyn_dyn0)
TEST_CASE(multibroadcast_2in_dyn_dyn1)
{
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, {2}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
a_shape,
b_shape);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {2, 4, 2}, {2, 4, 0}}},
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {2, 4, {2}}, {2, 4}}},
migraphx::make_op("multibroadcast", {{"out_dyn_dims", migraphx::to_value(a)}}),
b_shape,
a_shape);
......@@ -1546,9 +1544,9 @@ TEST_CASE(multibroadcast_2in_dyn_dyn1)
TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
{
// max doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 5, {2}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1557,9 +1555,9 @@ TEST_CASE(multibroadcast_2in_dyn_dyn_error0)
TEST_CASE(multibroadcast_2in_dyn_dyn_error1)
{
// opt doesn't match on second dimension of a
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4, 0}, {2, 4, 2}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> a{{1, 4}, {2, 4, {2}}, {2, 4}};
migraphx::shape a_shape{migraphx::shape::float_type, a};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, 3}, {2, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 4, {3}}, {2, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
throws_shape(migraphx::make_op("multibroadcast"), b_shape, a_shape);
......@@ -1670,7 +1668,7 @@ TEST_CASE(nms_shape)
score_thres_s);
// use_dyn_output == true
output_s = {migraphx::shape::int64_type, {{0, 6, 0}, {3, 3, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 6}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1681,9 +1679,9 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic batches
boxes_s = {migraphx::shape::float_type, {{1, 3, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 3, 0}, {1, 1, 0}, {6, 6, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 18, 0}, {3, 3, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 3}, {6, 6}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 3}, {1, 1}, {6, 6}}};
output_s = {migraphx::shape::int64_type, {{0, 18}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1694,9 +1692,9 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic num boxes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 20, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {6, 20, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 20, 0}, {3, 3, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 1}, {6, 20}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 1}, {6, 20}}};
output_s = {migraphx::shape::int64_type, {{0, 20}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1716,9 +1714,9 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic classes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {6, 6, 0}}};
output_s = {migraphx::shape::int64_type, {{0, 6, 0}, {3, 3, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 1}, {6, 6}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 3}, {6, 6}}};
output_s = {migraphx::shape::int64_type, {{0, 6}, {3, 3}}};
expect_shape(output_s,
migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
......@@ -1751,8 +1749,8 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic mismatch batches
boxes_s = {migraphx::shape::float_type, {{1, 4, 0}, {6, 6, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{2, 8, 0}, {1, 1, 0}, {6, 6, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 4}, {6, 6}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{2, 8}, {1, 1}, {6, 6}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1762,8 +1760,8 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic mismatch num boxes
boxes_s = {migraphx::shape::float_type, {{1, 1, 0}, {6, 8, 0}, {4, 4, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {3, 9, 0}}};
boxes_s = {migraphx::shape::float_type, {{1, 1}, {6, 8}, {4, 4}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 1}, {3, 9}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1774,7 +1772,7 @@ TEST_CASE(nms_shape)
// dynamic number of classes, fixed boxes_s, mismatch batches
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {{1, 3, 0}, {1, 3, 0}, {6, 6, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 3}, {1, 3}, {6, 6}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1784,7 +1782,7 @@ TEST_CASE(nms_shape)
score_thres_s);
// dynamic number of classes, fixed boxes_s, mismatch num boxes
boxes_s = {migraphx::shape::float_type, {1, 6, 4}};
scores_s = {migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {4, 8, 0}}};
scores_s = {migraphx::shape::float_type, {{1, 1}, {1, 3}, {4, 8}}};
throws_shape(migraphx::make_op("nonmaxsuppression",
{{"center_point_box", true}, {"use_dyn_output", true}}),
boxes_s,
......@@ -1810,19 +1808,17 @@ TEST_CASE(pad_shape1)
TEST_CASE(pad_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {3, 5, 0}, {3, 5, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {5, 7, 0}, {5, 7, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4, {2}}, {3, 3}, {3, 5}, {3, 5}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4, {2}}, {3, 3}, {5, 7}, {5, 7}}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pad_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {3, 5, 5}, {3, 5, 5}}};
{{1, 4, {2}}, {3, 3}, {3, 5, {5}}, {3, 5, {5}}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 2}, {3, 3, 0}, {5, 7, 7}, {5, 7, 7}}};
{{1, 4, {2}}, {3, 3}, {5, 7, {7}}, {5, 7, {7}}}};
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
......@@ -1880,8 +1876,7 @@ TEST_CASE(pooling_shape3)
TEST_CASE(pooling_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}};
throws_shape(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
{"padding", {1}},
......@@ -1892,10 +1887,8 @@ TEST_CASE(pooling_dyn_shape0)
TEST_CASE(pooling_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 3}, {1, 1, 1}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3, {3}}, {3, 3, {3}}, {3, 3}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {3, 3}, {1, 1}, {1, 1}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -1907,10 +1900,8 @@ TEST_CASE(pooling_dyn_shape1)
TEST_CASE(pooling_dyn_shape2)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {5, 5, 0}, {3, 3, 3}, {3, 3, 0}}};
migraphx::shape output{migraphx::shape::float_type,
{{1, 4, 0}, {5, 5, 0}, {2, 2, 2}, {2, 2, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {5, 5}, {3, 3, {3}}, {3, 3}}};
migraphx::shape output{migraphx::shape::float_type, {{1, 4}, {5, 5}, {2, 2}, {2, 2}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -1924,9 +1915,8 @@ TEST_CASE(pooling_dyn_shape2)
TEST_CASE(pooling_dyn_shape3)
{
migraphx::shape input{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {4, 12, 8}, {4, 12, 8}}};
migraphx::shape output{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {2, 4, 3}, {2, 4, 3}}};
{{4, 4}, {3, 3}, {4, 12, {8}}, {4, 12, {8}}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4}, {3, 3}, {2, 4}, {2, 4}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -1939,9 +1929,8 @@ TEST_CASE(pooling_dyn_shape3)
TEST_CASE(pooling_dyn_shape4)
{
migraphx::shape input{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {4, 12, 8}, {4, 12, 8}}};
migraphx::shape output{migraphx::shape::float_type,
{{4, 4, 0}, {3, 3, 0}, {3, 6, 4}, {3, 6, 4}}};
{{4, 4}, {3, 3}, {4, 12, {8}}, {4, 12, {8}}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4}, {3, 3}, {3, 6}, {3, 6}}};
expect_shape(output,
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -2061,32 +2050,32 @@ template <class T>
void test_dyn_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{2, 3, 3}, {1, 1, 0}})},
T{{-1}},
input);
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})},
T{{-1}},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{1, 1, 0}, {2, 4, 4}})},
T{{0}},
input);
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})},
T{{0}},
input);
}
{
// Empty axis argument reduces all axes
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>(
{{1, 1, 0}, {1, 1, 0}})},
T{{}},
input);
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})},
T{{}},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
throws_shape(T{{4}}, input);
}
}
......@@ -2114,7 +2103,7 @@ TEST_CASE(reshape_shape)
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0, 0}, {3, 2, 0}})
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}})
{
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
......@@ -2138,8 +2127,7 @@ TEST_CASE(reshape_shape)
TEST_CASE(reshape_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
{-1, 1, 1, 24}, {0, 8, 3, 1}, {-1, 3, 4, 2}, {0, 2, 4, 3}})
{
......@@ -2153,7 +2141,7 @@ TEST_CASE(reshape_dyn_shape)
else
{
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d, 0});
out_dyn_dims.push_back({d, d});
}
}
migraphx::shape output{migraphx::shape::float_type, out_dyn_dims};
......@@ -2163,24 +2151,21 @@ TEST_CASE(reshape_dyn_shape)
TEST_CASE(reshape_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 20, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 20}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 1, 0, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_fixed_ele_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {10, 10, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 10}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 1, 5, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_non_fixed_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {2, 1, 1, 24};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
}
......@@ -2400,28 +2385,28 @@ TEST_CASE(slice_shape)
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
// Slice axis 1 to size 4-1=3
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {3, 3, 0}, {2, 3, 0}}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {3, 3}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}),
input);
}
TEST_CASE(slice_dyn_shape1)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
// Slice axis 1 with negative index
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {2, 3, 0}}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {2, 2}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {-4}}}),
input);
}
TEST_CASE(slice_dyn_shape2)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
// Sliced range max bigger than dimension; is clipped
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {6, 6, 0}, {2, 3, 0}}},
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {6, 6}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {10}}}),
input);
}
......@@ -2430,11 +2415,11 @@ TEST_CASE(slice_dyn_shape3)
{
// TODO: When variable dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error.
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 8, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}};
throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
input);
// clang-format off
// expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {1, 1, 0}, {2, 3, 0}}},
// expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3}, {1, 1}, {2, 3}}},
// migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
// input);
// clang-format on
......@@ -2442,10 +2427,10 @@ TEST_CASE(slice_dyn_shape3)
TEST_CASE(slice_dyn_shape4)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 2, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 2}, {7, 7}, {2, 3}}};
// Slice multiple axes: axis 0 to size 2-1=1 and axis 1 to size 4-1=3
expect_shape(
migraphx::shape{migraphx::shape::int32_type, {{1, 1, 0}, {3, 3, 0}, {2, 3, 0}}},
migraphx::shape{migraphx::shape::int32_type, {{1, 1}, {3, 3}, {2, 3}}},
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {2, 4}}}),
input);
}
......@@ -2453,7 +2438,7 @@ TEST_CASE(slice_dyn_shape4)
TEST_CASE(slice_dyn_shape5)
{
// Axis out of range.
migraphx::shape input{migraphx::shape::int32_type, {{2, 2, 0}, {7, 7, 0}, {2, 3, 0}}};
migraphx::shape input{migraphx::shape::int32_type, {{2, 2}, {7, 7}, {2, 3}}};
throws_shape(
migraphx::make_op("slice", {{"axes", {0, 20}}, {"starts", {1, 1}}, {"ends", {2, 4}}}),
input);
......@@ -2463,15 +2448,13 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
TEST_CASE(softmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {5, 8, 6}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 1}, {3, 3}, {4, 6}, {5, 8, {6}}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
......@@ -2636,7 +2619,7 @@ TEST_CASE(test_gathernd_dynamic0)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{8, 8}};
migraphx::shape ds{dtype, b};
int batch_dims(1);
......@@ -2649,7 +2632,7 @@ TEST_CASE(test_gathernd_dynamic1)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 4}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}};
migraphx::shape ds{dtype, b};
int batch_dims(1);
......@@ -2662,7 +2645,7 @@ TEST_CASE(test_gathernd_dynamic2)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 1}};
migraphx::shape ds{dtype, {{2, 3, 3}, {5, 6, 5}, {6, 9, 7}, {7, 8, 8}}};
migraphx::shape ds{dtype, {{2, 3, {3}}, {5, 6, {5}}, {6, 9, {7}}, {7, 8, {8}}}};
int batch_dims(3);
throws_shape(migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2674,10 +2657,10 @@ TEST_CASE(test_gathernd_dynamic3)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {1}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}};
migraphx::shape ds{dtype, b};
migraphx::shape::dynamic_dimension ddout{1, 1, 0};
migraphx::shape::dynamic_dimension ddout{1, 1};
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd"), ds, is);
}
......@@ -2688,10 +2671,10 @@ TEST_CASE(test_gathernd_dynamic4)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 2}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}, {2, 2}};
migraphx::shape ds{dtype, b};
migraphx::shape::dynamic_dimension ddout{2, 2, 0};
migraphx::shape::dynamic_dimension ddout{2, 2};
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd"), ds, is);
}
......@@ -2703,10 +2686,10 @@ TEST_CASE(test_gathernd_dynamic5)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 1}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}, {2, 2, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}, {2, 2}, {2, 2}};
migraphx::shape ds{dtype, b};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2}, {2, 2}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2718,11 +2701,11 @@ TEST_CASE(test_gathernd_dynamic6)
// index dynamic shape, data static
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
std::vector<migraphx::shape::dynamic_dimension> b{{2, 3, 0}, {1, 1, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 3}, {1, 1}};
migraphx::shape is{itype, b};
migraphx::shape ds{dtype, {2, 2, 2}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 3, 0}, {2, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 3}, {2, 2}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2733,7 +2716,7 @@ TEST_CASE(test_gathernd_dynamic6a)
// indices with non-fixed dynamic dimension k
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2, 0}, {1, 3, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{2, 2}, {1, 3}};
migraphx::shape is{itype, b};
migraphx::shape ds{dtype, {2, 2, 2}};
......@@ -2747,12 +2730,12 @@ TEST_CASE(test_gathernd_dynamic7)
// index and data both dynamic shapes
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
std::vector<migraphx::shape::dynamic_dimension> idyn{{2, 5, 0}, {1, 1, 0}};
std::vector<migraphx::shape::dynamic_dimension> idyn{{2, 5}, {1, 1}};
migraphx::shape is{itype, idyn};
std::vector<migraphx::shape::dynamic_dimension> bdyn{{1, 2, 0}, {1, 2, 0}, {1, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> bdyn{{1, 2}, {1, 2}, {1, 2}};
migraphx::shape ds{dtype, bdyn};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 5, 0}, {1, 2, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 5}, {1, 2}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2765,10 +2748,10 @@ TEST_CASE(test_gathernd_dynamic8)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape is{itype, {2, 5, 1}};
std::vector<migraphx::shape::dynamic_dimension> b{{6, 7, 7}, {3, 3, 0}, {1, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> b{{6, 7, {7}}, {3, 3}, {1, 4}};
migraphx::shape ds{dtype, b};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2, 0}, {5, 5, 0}, {1, 4, 0}};
std::vector<migraphx::shape::dynamic_dimension> ddout{{2, 2}, {5, 5}, {1, 4}};
int batch_dims(1);
migraphx::shape s0{dtype, {ddout}};
expect_shape(s0, migraphx::make_op("gathernd", {{"batch_dims", batch_dims}}), ds, is);
......@@ -2848,7 +2831,7 @@ TEST_CASE(test_scatternd_dyn0)
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {4}};
migraphx::shape is{itype, {4, 13}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2860,7 +2843,7 @@ TEST_CASE(test_scatternd_dyn1)
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {8}};
migraphx::shape is{itype, {4, 1}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
expect_shape(ds, migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2873,7 +2856,7 @@ TEST_CASE(test_scatternd_dyn2)
migraphx::shape ds{dtype, {2, 3, 1, 4}, {0, 1, 1, 0}};
migraphx::shape ds_std{dtype, {2, 3, 1, 4}};
migraphx::shape is{itype, {4, 4}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
expect_shape(ds_std, migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2885,7 +2868,7 @@ TEST_CASE(test_scatternd_dyn3)
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 3, 1, 4}};
migraphx::shape is{itype, {4, 4}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape us{dtype, {dd}};
expect_shape(ds, migraphx::make_op("scatternd_none"), ds, is, us);
}
......@@ -2896,7 +2879,7 @@ TEST_CASE(test_scatternd_dyn4)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 3, 1, 4}};
migraphx::shape::dynamic_dimension dd{4, 5, 0};
migraphx::shape::dynamic_dimension dd{4, 5};
migraphx::shape is{itype, {dd, dd}};
migraphx::shape us{dtype, {dd}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
......@@ -2908,8 +2891,8 @@ TEST_CASE(test_scatternd_dyn5)
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {2, 3, 1, 4}};
migraphx::shape::dynamic_dimension dd{4, 4, 0};
migraphx::shape::dynamic_dimension dbad{2, 3, 0};
migraphx::shape::dynamic_dimension dd{4, 4};
migraphx::shape::dynamic_dimension dbad{2, 3};
migraphx::shape is{itype, {dd, dd}};
migraphx::shape us{dtype, {dbad}};
throws_shape(migraphx::make_op("scatternd_none"), ds, is, us);
......@@ -2931,12 +2914,11 @@ TEST_CASE(test_squeeze_all)
TEST_CASE(test_squeeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {3, 3}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {3}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
expect_shape(s3, migraphx::make_op("squeeze"), s1);
throws_shape(migraphx::make_op("squeeze", {{"axes", {0}}}), s1);
......@@ -2944,12 +2926,11 @@ TEST_CASE(test_squeeze_dyn)
TEST_CASE(test_squeeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {3, 3}}};
expect_shape(s2, migraphx::make_op("squeeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
expect_shape(s3, migraphx::make_op("squeeze", {{"axes", {-2, -4}}}), s1);
}
......@@ -2996,12 +2977,11 @@ TEST_CASE(test_unsqueeze)
TEST_CASE(test_unsqueeze_dyn)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}, {1, 1}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), s1);
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}, {"steps", {2}}}), s1);
......@@ -3009,12 +2989,11 @@ TEST_CASE(test_unsqueeze_dyn)
TEST_CASE(test_unsqueeze_dyn_neg_axes)
{
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
migraphx::shape s3{migraphx::shape::float_type,
{{1, 4, 3}, {2, 5, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}}};
migraphx::shape s3{migraphx::shape::float_type, {{1, 4, {3}}, {2, 5}, {1, 1}, {3, 3}, {1, 1}}};
expect_shape(s3, migraphx::make_op("unsqueeze", {{"axes", {-1, -3}}}), s1);
}
......@@ -3177,16 +3156,16 @@ TEST_CASE(transpose_shape)
TEST_CASE(transpose_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{2, 2, 0}, {1, 4, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {2, 2}}};
migraphx::shape output{migraphx::shape::float_type, {{2, 2}, {1, 4}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {1, 0}}}), input);
}
TEST_CASE(transpose_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4, 0}, {4, 4, 0}, {1, 4, 0}}};
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {4, 4}, {4, 4}}};
migraphx::shape output{migraphx::shape::float_type, {{4, 4}, {4, 4}, {1, 4}}};
expect_shape(input, migraphx::make_op("transpose", {{"permutation", {0, 1, 2}}}), input);
expect_shape(output, migraphx::make_op("transpose", {{"permutation", {2, 1, 0}}}), input);
}
......@@ -3243,8 +3222,8 @@ TEST_CASE(where_broadcast_input)
TEST_CASE(where_dyn_input0)
{
// dynamic shapes not the same
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {2, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3}, {2, 3}}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
......@@ -3253,7 +3232,7 @@ TEST_CASE(where_dyn_input1)
{
// mixed static/dynamic inputs (not allowed)
migraphx::shape s1{migraphx::shape::float_type, {2, 2}, {2, 1}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 2}, {2, 2}}};
migraphx::shape s3{migraphx::shape::bool_type, {2, 2}, {2, 1}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
......@@ -3261,18 +3240,18 @@ TEST_CASE(where_dyn_input1)
TEST_CASE(where_dyn_input2)
{
// dynamic shapes
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3}, {3, 3}}};
expect_shape(s2, migraphx::make_op("where"), s3, s1, s2);
}
TEST_CASE(where_dyn_input3)
{
// dynamic shapes, predicate shape is different
migraphx::shape s1{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3, 0}, {3, 3, 0}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3, 0}, {3, 4, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s2{migraphx::shape::float_type, {{2, 3}, {3, 3}}};
migraphx::shape s3{migraphx::shape::bool_type, {{2, 3}, {3, 4}}};
throws_shape(migraphx::make_op("where"), s3, s1, s2);
}
......@@ -3325,9 +3304,9 @@ TEST_CASE(test_concat)
TEST_CASE(test_dyn_concat)
{
migraphx::shape sx{migraphx::shape::float_type, {{1, 3, 3}, {4, 4}, {1, 5, 5}, {6, 6}}};
migraphx::shape sy{migraphx::shape::float_type, {{1, 3, 3}, {4, 4}, {1, 4, 4}, {6, 6}}};
migraphx::shape sout{migraphx::shape::float_type, {{1, 3, 3}, {4, 4, 0}, {2, 9, 0}, {6, 6}}};
migraphx::shape sx{migraphx::shape::float_type, {{1, 3, {3}}, {4, 4}, {1, 5, {5}}, {6, 6}}};
migraphx::shape sy{migraphx::shape::float_type, {{1, 3, {3}}, {4, 4}, {1, 4, {4}}, {6, 6}}};
migraphx::shape sout{migraphx::shape::float_type, {{1, 3, {3}}, {4, 4}, {2, 9}, {6, 6}}};
expect_shape(sout, migraphx::make_op("concat", {{"axis", 2}}), sx, sy);
......@@ -3335,7 +3314,7 @@ TEST_CASE(test_dyn_concat)
throws_shape(migraphx::make_op("concat", {{"axis", 4}}), sx, sy);
// rank doesn't match
migraphx::shape srank{migraphx::shape::int64_type, {{1, 3, 3}, {4, 4}}};
migraphx::shape srank{migraphx::shape::int64_type, {{1, 3, {3}}, {4, 4}}};
throws_shape(migraphx::make_op("concat", {{"axis", 0}}), sx, srank);
// non-matching dimension 2
......
......@@ -1197,7 +1197,7 @@ TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
......@@ -1250,8 +1250,7 @@ TEST_CASE(dot_dyn_4D_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 1}, {1, 1}, {4, 6, {4}}, {5, 5}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
......
......@@ -64,7 +64,7 @@ TEST_CASE(abs_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 8, 0}, {2, 2, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 8}, {2, 2}}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("abs"), input);
p.compile(migraphx::make_target("ref"));
......@@ -102,7 +102,7 @@ TEST_CASE(acos_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("acos"), input);
......@@ -143,7 +143,7 @@ TEST_CASE(acosh_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.1f, 1.2f, 2.0f};
......@@ -230,7 +230,7 @@ TEST_CASE(add_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -330,7 +330,7 @@ TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 6, 0}, {3, 6, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {3, 6}, {3, 6}}};
auto dl = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("argmax", {{"axis", 0}}), dl);
p.compile(migraphx::make_target("ref"));
......@@ -446,7 +446,7 @@ TEST_CASE(asin_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("asin"), input);
......@@ -487,7 +487,7 @@ TEST_CASE(asinh_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("asinh"), input);
......@@ -528,7 +528,7 @@ TEST_CASE(atan_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("atan"), input);
......@@ -569,7 +569,7 @@ TEST_CASE(atanh_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("atanh"), input);
......@@ -615,7 +615,7 @@ TEST_CASE(avgpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}}};
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
......@@ -767,7 +767,7 @@ TEST_CASE(broadcast_2in_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 2, 0}, {2, 4, 0}}};
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 2}, {2, 4}}};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
uint64_t axis = 0;
......@@ -810,7 +810,7 @@ TEST_CASE(ceil_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 12, 0};
migraphx::shape::dynamic_dimension dd{4, 12};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("ceil"), input);
......@@ -958,9 +958,9 @@ TEST_CASE(concat_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, 2}, {2, 3, 2}}};
migraphx::shape s1{migraphx::shape::int32_type, {{3, 4, 4}, {2, 3, 2}}};
migraphx::shape s2{migraphx::shape::int32_type, {{1, 5, 3}, {2, 3, 2}}};
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2}}, {2, 3, {2}}}};
migraphx::shape s1{migraphx::shape::int32_type, {{3, 4, {4}}, {2, 3, {2}}}};
migraphx::shape s2{migraphx::shape::int32_type, {{1, 5, {3}}, {2, 3, {2}}}};
auto input0 = mm->add_parameter("X", s0);
auto input1 = mm->add_parameter("Y", s1);
......@@ -1039,8 +1039,7 @@ TEST_CASE(contiguous_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape dyn_shape{migraphx::shape::float_type,
{{1, 1, 0}, {2, 6, 0}, {2, 2, 0}, {2, 2, 0}}};
migraphx::shape dyn_shape{migraphx::shape::float_type, {{1, 1}, {2, 6}, {2, 2}, {2, 2}}};
auto input = mm->add_parameter("X", dyn_shape);
mm->add_instruction(migraphx::make_op("contiguous"), input);
p.compile(migraphx::make_target("ref"));
......@@ -1068,7 +1067,7 @@ TEST_CASE(conv_dyn_batch_test)
auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 100, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}};
{{1, 100}, {3, 3}, {4, 4}, {4, 4}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape);
......@@ -1184,8 +1183,7 @@ TEST_CASE(conv_dyn_img_shape_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
migraphx::shape input_dyn_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {4, 6}, {4, 6}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape);
......@@ -1274,8 +1272,7 @@ TEST_CASE(conv_dyn_weights_shape_test)
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 3}, {2, 3}}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
......@@ -1350,8 +1347,7 @@ TEST_CASE(conv_dyn_img_same_upper_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape input_dyn_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {4, 6, 0}}};
migraphx::shape input_dyn_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {4, 6}, {4, 6}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {1, 3, 3, 3}};
auto input = mm->add_parameter("X", input_dyn_shape);
......@@ -1422,8 +1418,7 @@ TEST_CASE(conv_dyn_kernel_same_upper_test)
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 3}, {2, 3}}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
......@@ -1496,8 +1491,7 @@ TEST_CASE(conv_dyn_kernel_same_lower_test)
auto* mm = p.get_main_module();
migraphx::shape input_shape{migraphx::shape::float_type, {1, 3, 4, 4}};
migraphx::shape weights_shape{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {2, 3, 0}, {2, 3, 0}}};
migraphx::shape weights_shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 3}, {2, 3}}};
auto input = mm->add_parameter("X", input_shape);
auto weights = mm->add_parameter("W", weights_shape);
......@@ -1839,7 +1833,7 @@ TEST_CASE(cos_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("cos"), input);
......@@ -1880,7 +1874,7 @@ TEST_CASE(cosh_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("cosh"), input);
......@@ -2071,7 +2065,7 @@ TEST_CASE(div_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 3}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {3}}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -2113,7 +2107,7 @@ TEST_CASE(elu_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
float alpha = 0.5;
......@@ -2184,7 +2178,7 @@ TEST_CASE(equal_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{6, 12, 9}};
std::vector<migraphx::shape::dynamic_dimension> dd{{6, 12, {9}}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto p0 = mm->add_parameter("l", s);
auto p1 = mm->add_parameter("r", s);
......@@ -2231,7 +2225,7 @@ TEST_CASE(erf_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("erf"), input);
......@@ -2272,7 +2266,7 @@ TEST_CASE(exp_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("exp"), input);
......@@ -2313,7 +2307,7 @@ TEST_CASE(floor_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{5, 12, 0};
migraphx::shape::dynamic_dimension dd{5, 12};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("floor"), input);
......@@ -2564,7 +2558,7 @@ TEST_CASE(gather_dyn_test0)
// Dynamic data, static indices
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::int32_type, {{2, 5}, {3, 3}}};
auto x = mm->add_parameter("x", s);
std::vector<int> indices{1, 2};
......@@ -2573,7 +2567,7 @@ TEST_CASE(gather_dyn_test0)
auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{2, 5, 0}, {1, 1, 0}, {2, 2, 0}}};
migraphx::shape sresult{migraphx::shape::int32_type, {{2, 5}, {1, 1}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::make_target("ref"));
......@@ -2599,15 +2593,15 @@ TEST_CASE(gather_dyn_test1)
// Dynamic data, dynamic indices
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 5, 0}, {4, 4, 0}}};
migraphx::shape s{migraphx::shape::int32_type, {{2, 5}, {4, 4}}};
auto x = mm->add_parameter("x", s);
migraphx::shape s_ind{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}}};
migraphx::shape s_ind{migraphx::shape::int32_type, {{1, 8, {7}}, {2, 3, {3}}}};
auto ind = mm->add_parameter("indices", s_ind);
mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), x, ind);
migraphx::shape sresult{migraphx::shape::int32_type, {{1, 8, 7}, {2, 3, 3}, {4, 4, 0}}};
migraphx::shape sresult{migraphx::shape::int32_type, {{1, 8, {7}}, {2, 3, {3}}, {4, 4}}};
EXPECT(p.get_output_shapes().back() == sresult);
p.compile(migraphx::make_target("ref"));
......@@ -2787,7 +2781,7 @@ TEST_CASE(gathernd_dynamic0)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 2, 2}, {3, 3, 0}, {1, 1, 0}}};
migraphx::shape ds{migraphx::shape::float_type, {{2, 2, {2}}, {3, 3}, {1, 1}}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
auto xdata = mm->add_parameter("X", ds);
......@@ -2824,7 +2818,7 @@ TEST_CASE(gathernd_dynamic1)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}};
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, {2}}, {1, 5}, {1, 5}}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2, 1}};
auto xdata = mm->add_parameter("X", ds);
......@@ -2860,8 +2854,8 @@ TEST_CASE(gathernd_dynamic2)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, 2}, {1, 5, 0}, {1, 5, 0}}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}};
migraphx::shape ds{migraphx::shape::float_type, {{2, 5, {2}}, {1, 5}, {1, 5}}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, {3}}, {2, 3, {3}}, {1, 1}}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
......@@ -2897,7 +2891,7 @@ TEST_CASE(gathernd_dynamic3)
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, 3}, {2, 3, 3}, {1, 1}}};
migraphx::shape is{migraphx::shape::int64_type, {{2, 5, {3}}, {2, 3, {3}}, {1, 1}}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
......@@ -2932,8 +2926,7 @@ TEST_CASE(gathernd_dynamic4)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type,
{migraphx::shape::dynamic_dimension({2, 2, 0})}};
migraphx::shape ds{migraphx::shape::float_type, {migraphx::shape::dynamic_dimension({2, 2})}};
migraphx::shape is{migraphx::shape::int64_type, {1}};
auto xdata = mm->add_parameter("X", ds);
......@@ -3034,9 +3027,8 @@ TEST_CASE(globalavgpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 6, 0}, {2, 6, 2}}};
auto x = mm->add_parameter("X", s);
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average}, {"dyn_global", true}}),
......@@ -3081,7 +3073,7 @@ TEST_CASE(globallppool_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 6, 2}, {2, 6, 2}}};
migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
......@@ -3126,7 +3118,7 @@ TEST_CASE(globalmaxpool_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto s =
migraphx::shape{migraphx::shape::float_type, {{1, 1, 0}, {3, 3, 0}, {2, 6, 2}, {2, 6, 2}}};
migraphx::shape{migraphx::shape::float_type, {{1, 1}, {3, 3}, {2, 6, {2}}, {2, 6, {2}}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(
migraphx::make_op("pooling",
......@@ -3198,7 +3190,7 @@ TEST_CASE(greater_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, 9}};
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, {9}}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
......@@ -3242,7 +3234,7 @@ TEST_CASE(identity_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 4, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 4}}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("identity"), input);
p.compile(migraphx::make_target("ref"));
......@@ -3488,7 +3480,7 @@ TEST_CASE(isnan_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {3, 8, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {3, 8}}};
auto input = mm->add_parameter("X", s);
auto nan_val = std::numeric_limits<float>::quiet_NaN();
mm->add_instruction(migraphx::make_op("isnan"), input);
......@@ -3807,7 +3799,7 @@ TEST_CASE(less_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, 9}};
std::vector<migraphx::shape::dynamic_dimension> dd{{8, 10, {9}}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
......@@ -3859,7 +3851,7 @@ TEST_CASE(log_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("log"), input);
......@@ -3904,7 +3896,7 @@ TEST_CASE(logical_and_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {4}}};
migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
......@@ -3955,7 +3947,7 @@ TEST_CASE(logical_or_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {4}}};
migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
......@@ -4006,7 +3998,7 @@ TEST_CASE(logical_xor_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 4}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, {4}}};
migraphx::shape s{migraphx::shape::bool_type, dd};
auto left = mm->add_parameter("l", s);
auto right = mm->add_parameter("r", s);
......@@ -4227,7 +4219,7 @@ TEST_CASE(lppool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}}};
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::lpnorm},
......@@ -4294,7 +4286,7 @@ TEST_CASE(max_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -4497,7 +4489,7 @@ TEST_CASE(maxpool_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}}};
auto s = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}}};
auto x = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::max},
......@@ -4540,7 +4532,7 @@ TEST_CASE(min_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -4586,7 +4578,7 @@ TEST_CASE(fmod_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -4651,7 +4643,7 @@ TEST_CASE(mod_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -4720,7 +4712,7 @@ TEST_CASE(mul_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -4790,7 +4782,7 @@ TEST_CASE(multibroadcast_2in_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 4, 0}, {2, 2, 0}}};
migraphx::shape a_shape{migraphx::shape::int32_type, {{2, 4}, {2, 2}}};
migraphx::shape b_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> b_data{-2, -3};
auto l1 = mm->add_parameter("a", a_shape);
......@@ -4882,7 +4874,7 @@ TEST_CASE(neg_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {3, 3}}};
auto input = mm->add_parameter("X", s);
auto ret = mm->add_instruction(migraphx::make_op("neg"), input);
mm->add_return({ret});
......@@ -4939,9 +4931,9 @@ TEST_CASE(nms_dyn_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 3, 0}, {6, 6, 0}, {4, 4, 0}}};
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 3}, {6, 6}, {4, 4}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 3, 0}, {1, 1, 0}, {6, 6, 0}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 3}, {1, 1}, {6, 6}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s);
......@@ -4985,9 +4977,9 @@ TEST_CASE(nms_dyn_boxes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1, 0}, {4, 20, 0}, {4, 4, 0}}};
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1}, {4, 20}, {4, 4}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1, 0}, {1, 1, 0}, {4, 20, 0}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1}, {1, 1}, {4, 20}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s);
......@@ -5028,9 +5020,9 @@ TEST_CASE(nms_dyn_classes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1, 0}, {6, 6, 0}, {4, 4, 0}}};
migraphx::shape boxes_s{migraphx::shape::float_type, {{1, 1}, {6, 6}, {4, 4}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1, 0}, {1, 3, 0}, {6, 6, 0}}};
migraphx::shape scores_s{migraphx::shape::float_type, {{1, 1}, {1, 3}, {6, 6}}};
auto boxes_p = mm->add_parameter("boxes", boxes_s);
auto scores_p = mm->add_parameter("scores", scores_s);
......@@ -5274,7 +5266,7 @@ TEST_CASE(not_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("not"), input);
......@@ -5363,7 +5355,7 @@ TEST_CASE(pad_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 2}, {2, 4, 2}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2}}, {2, 4, {2}}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("pad", {{"pads", {1, 1, 1, 1}}}), x);
p.compile(migraphx::make_target("ref"));
......@@ -5827,7 +5819,7 @@ TEST_CASE(prelu_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto slope = mm->add_parameter("slope", s);
......@@ -6028,7 +6020,7 @@ TEST_CASE(recip_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("recip"), input);
......@@ -6065,7 +6057,7 @@ TEST_CASE(reduce_max_dynamic_axis0)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 2}, {3, 5, 3}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2}}, {3, 5, {3}}}};
auto input = mm->add_parameter("X", s);
auto reduce_max_op = migraphx::make_op("reduce_max", {{"axes", {0}}});
mm->add_instruction(reduce_max_op, input);
......@@ -6357,7 +6349,7 @@ TEST_CASE(relu_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("relu"), input);
......@@ -6429,7 +6421,7 @@ TEST_CASE(reshape_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {24, 24, 0}, {1, 1, 0}, {1, 1, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 8, 3, 1};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
......@@ -6691,7 +6683,7 @@ TEST_CASE(round_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{4, 10, 0};
migraphx::shape::dynamic_dimension dd{4, 10};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("round"), input);
......@@ -6727,7 +6719,7 @@ TEST_CASE(rsqrt_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("rsqrt"), input);
......@@ -7466,10 +7458,10 @@ TEST_CASE(scatternd_reduction_dyn_test)
auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type;
migraphx::shape::dynamic_dimension dd{3, 6, 0};
migraphx::shape::dynamic_dimension dd{3, 6};
migraphx::shape ds{migraphx::shape::float_type, {dd, dd, dd}};
migraphx::shape is{itype, {2, 1}};
migraphx::shape us{dtype, {{2, 2, 0}, dd, dd}};
migraphx::shape us{dtype, {{2, 2}, dd, dd}};
auto xdata = mm->add_parameter("X", ds);
auto xindex = mm->add_parameter("I", is);
......@@ -7523,7 +7515,7 @@ TEST_CASE(sigmoid_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 2, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 2}}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sigmoid"), input);
p.compile(migraphx::make_target("ref"));
......@@ -7559,7 +7551,7 @@ TEST_CASE(sign_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sign"), input);
......@@ -7598,7 +7590,7 @@ TEST_CASE(sin_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("sin"), input);
......@@ -7639,7 +7631,7 @@ TEST_CASE(sinh_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{2, 4, 0}, {2, 4, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {2, 4}}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
mm->add_instruction(migraphx::make_op("sinh"), input);
......@@ -7709,11 +7701,11 @@ TEST_CASE(slice_dyn_test0)
// too large
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::int32_type, {{2, 3}, {2, 2}, {3, 3}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {0, 1}}, {"ends", {1, 6}}}), x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 3, 0}, {1, 1, 0}, {2, 2, 0}}};
migraphx::shape s2{migraphx::shape::int32_type, {{2, 3}, {1, 1}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref"));
......@@ -7740,14 +7732,14 @@ TEST_CASE(slice_dyn_test1)
// Slice all three dynamic dimensions
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::int32_type, {{2, 2}, {2, 2}, {3, 3}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}),
x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 2, 0}, {2, 2, 0}, {2, 2, 0}}};
migraphx::shape s2{migraphx::shape::int32_type, {{2, 2}, {2, 2}, {2, 2}}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::make_target("ref"));
migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
......@@ -7847,7 +7839,7 @@ TEST_CASE(softmax_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 10, 0}, {1, 3, 3}, {4, 4, 0}, {2, 2, 2}}};
{{1, 10}, {1, 3, {3}}, {4, 4}, {2, 2, {2}}}};
auto al = mm->add_parameter("a", a_shape);
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
p.compile(migraphx::make_target("ref"));
......@@ -7925,7 +7917,7 @@ TEST_CASE(sqdiff_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -7967,7 +7959,7 @@ TEST_CASE(sqrt_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{1.02481645, 0.85643062, 0.03404123, 0.92791926, 0.10569184};
......@@ -8031,8 +8023,7 @@ TEST_CASE(squeeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, 0}, {1, 1, 0}, {3, 3, 0}, {1, 1, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {1, 1}, {3, 3}, {1, 1}, {3, 3}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref"));
......@@ -8103,7 +8094,7 @@ TEST_CASE(sub_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6, 0}};
std::vector<migraphx::shape::dynamic_dimension> dd{{2, 6}};
migraphx::shape s{migraphx::shape::float_type, dd};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
......@@ -8145,7 +8136,7 @@ TEST_CASE(tan_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1, 0, 1};
......@@ -8186,7 +8177,7 @@ TEST_CASE(tanh_dynamic_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape::dynamic_dimension dd{3, 8, 0};
migraphx::shape::dynamic_dimension dd{3, 8};
migraphx::shape s{migraphx::shape::float_type, {dd}};
auto input = mm->add_parameter("X", s);
std::vector<float> input_data{-1.0, 2.0, -3.0, 4.0};
......@@ -8294,7 +8285,7 @@ TEST_CASE(transpose_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {2, 2}, {2, 2}, {3, 3}}};
auto l = mm->add_parameter("X", s);
std::vector<int64_t> perm = {0, 3, 1, 2};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), l);
......@@ -8349,7 +8340,7 @@ TEST_CASE(unsqueeze_dyn_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {3, 3, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref"));
......@@ -8394,8 +8385,8 @@ TEST_CASE(where_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {{2, 3, 0}, {2, 3, 0}}};
migraphx::shape sx{migraphx::shape::float_type, {{2, 3, 0}, {2, 3, 0}}};
migraphx::shape sb{migraphx::shape::bool_type, {{2, 3}, {2, 3}}};
migraphx::shape sx{migraphx::shape::float_type, {{2, 3}, {2, 3}}};
auto lb = mm->add_parameter("predicate", sb);
auto lx = mm->add_parameter("X", sx);
......
......@@ -41,22 +41,13 @@ TEST_CASE(test_shape_default)
TEST_CASE(test_dyn_4arg_constructor)
{
migraphx::shape s{migraphx::shape::float_type,
{
1,
4,
4,
},
{
4,
4,
4,
},
{0, 0, 0}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {
{1, 4, 0}, {4, 4, 0}, {4, 4, 0}};
EXPECT(s.dynamic());
EXPECT(s.dyn_dims() == expected_dyn_dims);
migraphx::shape s0{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {{}, {}, {}}};
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 4}, {4, 4, 4}, {}};
std::vector<migraphx::shape::dynamic_dimension> expected_dyn_dims = {{1, 4}, {4, 4}, {4, 4}};
EXPECT(s0.dynamic());
EXPECT(s0.dyn_dims() == expected_dyn_dims);
EXPECT(s1.dynamic());
EXPECT(s1.dyn_dims() == expected_dyn_dims);
}
TEST_CASE(test_shape_assign)
......@@ -99,12 +90,12 @@ TEST_CASE(test_shape_min_max_opt)
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.min_lens() == s.lens());
EXPECT(s.max_lens() == s.lens());
EXPECT(s.opt_lens() == s.lens());
EXPECT(s.opt_lens().empty());
}
TEST_CASE(test_shape_dynamic_fixed)
{
migraphx::shape s{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
migraphx::shape s{migraphx::shape::float_type, {{2, 2}, {2, 2}, {3, 3}}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
......@@ -115,7 +106,8 @@ TEST_CASE(test_shape_dynamic_fixed)
EXPECT(not s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.max_lens() == std::vector<std::size_t>{2, 2, 3});
EXPECT(s.opt_lens() == std::vector<std::size_t>{0, 0, 0});
std::vector<std::set<std::size_t>> e_opt_lens = {{}, {}, {}};
EXPECT(s.opt_lens() == e_opt_lens);
EXPECT(s.bytes() == 2 * 2 * 3 * sizeof(float));
}
......@@ -123,8 +115,8 @@ TEST_CASE(test_shape_dynamic_not_fixed)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{migraphx::shape::float_type, dims};
EXPECT(not s.standard());
EXPECT(not s.packed());
......@@ -136,18 +128,16 @@ TEST_CASE(test_shape_dynamic_not_fixed)
EXPECT(s.dyn_dims().at(0).has_optimal());
EXPECT(s.min_lens() == std::vector<std::size_t>{2, 2});
EXPECT(s.max_lens() == std::vector<std::size_t>{5, 8});
EXPECT(s.opt_lens() == std::vector<std::size_t>{2, 0});
EXPECT(s.opt_lens() == std::vector<std::set<std::size_t>>{{2}, {}});
EXPECT(s.bytes() == 5 * 8 * sizeof(float));
}
TEST_CASE(test_shape_dynamic_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto b = a;
auto c = shape::dynamic_dimension{2, 5, 2};
auto d = shape::dynamic_dimension{3, 8, 4};
EXPECT(a == b);
auto a = shape::dynamic_dimension{2, 5, {2}};
auto c = shape::dynamic_dimension{2, 5, {2}};
auto d = shape::dynamic_dimension{3, 8};
EXPECT(a == c);
EXPECT(a != d);
......@@ -172,13 +162,13 @@ TEST_CASE(test_shape_dynamic_compares)
TEST_CASE(dynamic_dimension_size_t_compares)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 2, 2};
auto a = shape::dynamic_dimension{2, 2, {2}};
EXPECT(a == 2);
EXPECT(a != 3);
EXPECT(static_cast<std::size_t>(2) == a);
EXPECT(static_cast<std::size_t>(3) != a);
auto b = shape::dynamic_dimension{2, 4, 0};
auto b = shape::dynamic_dimension{2, 4};
EXPECT(b != 2);
EXPECT(static_cast<std::size_t>(2) != b);
}
......@@ -186,25 +176,25 @@ TEST_CASE(dynamic_dimension_size_t_compares)
TEST_CASE(dynamic_dimension_add_sub_fixed)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, 2};
auto a = shape::dynamic_dimension{2, 5, {2}};
a += 3;
EXPECT(a == shape::dynamic_dimension{5, 8, 5});
EXPECT(a == shape::dynamic_dimension{5, 8, {5}});
a -= 3;
EXPECT(a == shape::dynamic_dimension{2, 5, 2});
EXPECT(a == shape::dynamic_dimension{2, 5, {2}});
auto b = shape::dynamic_dimension{3, 6, 3};
auto b = shape::dynamic_dimension{3, 6, {3}};
EXPECT((a + 1) == b);
EXPECT((1 + a) == b);
EXPECT((b - 1) == a);
auto c = shape::dynamic_dimension{4, 7, 4};
auto c = shape::dynamic_dimension{4, 7, {4}};
EXPECT((a + 2) == c);
EXPECT((2 + a) == c);
EXPECT((c - 2) == a);
auto d = shape::dynamic_dimension{4, 8, 0};
auto e = shape::dynamic_dimension{2, 6, 0};
auto d = shape::dynamic_dimension{4, 8};
auto e = shape::dynamic_dimension{2, 6};
EXPECT((d - 2) == e);
EXPECT((e + 2) == d);
EXPECT((2 + e) == d);
......@@ -214,8 +204,8 @@ TEST_CASE(test_shape_dynamic_errors)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims = {};
dims.push_back(shape::dynamic_dimension{2, 5, 2});
dims.push_back(shape::dynamic_dimension{2, 8, 0});
dims.push_back(shape::dynamic_dimension{2, 5, {2}});
dims.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s{shape::float_type, dims};
EXPECT(test::throws([&] { s.elements(); }));
EXPECT(test::throws([&] { s.index({0, 1}); }));
......@@ -229,13 +219,13 @@ TEST_CASE(test_shape_dynamic_serialize)
{
using migraphx::shape;
std::vector<shape::dynamic_dimension> dims1 = {};
dims1.push_back(shape::dynamic_dimension{2, 5, 2});
dims1.push_back(shape::dynamic_dimension{2, 8, 0});
dims1.push_back(shape::dynamic_dimension{2, 5, {2}});
dims1.push_back(shape::dynamic_dimension{2, 8});
migraphx::shape s1{shape::float_type, dims1};
auto v1 = migraphx::to_value(s1);
std::vector<shape::dynamic_dimension> dims2 = {};
dims2.push_back(shape::dynamic_dimension{2, 5, 2});
dims2.push_back(shape::dynamic_dimension{2, 5, {2}});
migraphx::shape s2{shape::uint64_type, dims2};
auto v2 = migraphx::to_value(s2);
EXPECT(v1 != v2);
......@@ -294,14 +284,13 @@ TEST_CASE(test_shape_ndim_static)
TEST_CASE(test_shape_ndim_dyn)
{
migraphx::shape s0{migraphx::shape::float_type, {{2, 2, 0}, {2, 2, 0}}};
migraphx::shape s0{migraphx::shape::float_type, {{2, 2}, {2, 2}}};
EXPECT(s0.ndim() == 2);
migraphx::shape s1{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s1{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(s1.ndim() == 4);
migraphx::shape s2{migraphx::shape::float_type,
{{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {1, 1, 1}, {3, 3, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {1, 1}, {3, 3}}};
EXPECT(s2.ndim() == 5);
}
......@@ -336,13 +325,13 @@ TEST_CASE(test_shape_static_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4, 4}};
migraphx::shape s1 = s0.to_dynamic();
migraphx::shape s2{migraphx::shape::float_type, {{1, 1, 0}, {2, 2, 0}, {4, 4, 0}, {4, 4, 0}}};
migraphx::shape s2{migraphx::shape::float_type, {{1, 1}, {2, 2}, {4, 4}, {4, 4}}};
EXPECT(s1 == s2);
}
TEST_CASE(test_shape_dyn_to_dynamic)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 1, 0}, {2, 4, 0}, {2, 4, 0}, {2, 4, 0}}};
migraphx::shape s0{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
migraphx::shape s1 = s0.to_dynamic();
EXPECT(s0 == s1);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment