Unverified Commit ed7973d1 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Insert contiguous for reshape as necessary (#1351)

The reshape op requires its input to have a standard shape. The simplify_algebra pass inserts reshapes without checking this requirement.
parent 349635ce
...@@ -615,10 +615,9 @@ struct find_splits ...@@ -615,10 +615,9 @@ struct find_splits
auto outputs = i->outputs(); auto outputs = i->outputs();
for(auto output : outputs) for(auto output : outputs)
{ {
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name())) if(output->name() != "reshape")
continue; continue;
auto x = auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.insert_instruction(output, make_op("contiguous"), output->inputs());
m.replace_instruction(output, output->get_operator(), x); m.replace_instruction(output, output->get_operator(), x);
} }
...@@ -808,7 +807,7 @@ struct find_conv_dot_horiz_fusion ...@@ -808,7 +807,7 @@ struct find_conv_dot_horiz_fusion
auto y = j->inputs()[1]->get_shape().lens(); auto y = j->inputs()[1]->get_shape().lens();
if(x.size() != y.size()) if(x.size() != y.size())
return false; return false;
// Check that non-axises match // Check that non-axes match
int axis = 1; int axis = 1;
if(i->name() == "dot") if(i->name() == "dot")
{ {
...@@ -844,13 +843,22 @@ struct find_conv_dot_horiz_fusion ...@@ -844,13 +843,22 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args) for(auto arg : args)
m.move_instructions(arg, input); m.move_instructions(arg, input);
// TODO: Check if axises match // TODO: Check if axes match
auto concat = auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = m.insert_instruction(std::next(input), op, input, concat); auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0; int64_t offset = 0;
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction( m.replace_instruction(
arg, arg,
...@@ -993,7 +1001,11 @@ struct find_split_reshape ...@@ -993,7 +1001,11 @@ struct find_split_reshape
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction // insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction( auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::module& m) void run_pass(migraphx::module& m)
...@@ -1528,6 +1527,48 @@ TEST_CASE(simplify_dot_horiz_flipped) ...@@ -1528,6 +1527,48 @@ TEST_CASE(simplify_dot_horiz_flipped)
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
// test if contiguous is added as necessary for reshapes
TEST_CASE(simplify_dot_horiz_reshape)
{
    // Two dots share the same left-hand operand; one result is consumed by a
    // reshape and the other by an unsqueeze. The expected module below has a
    // contiguous in front of the reshape only.
    const migraphx::shape operand_shape{migraphx::shape::int32_type, {3, 4, 4}};

    // Operators used identically by the input module and the expected module.
    const auto reshape_op   = migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}});
    const auto unsqueeze_op = migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}});

    // Module that is handed to the pass.
    migraphx::module m1;
    {
        auto data     = m1.add_parameter("input", operand_shape);
        auto weights0 = m1.add_literal(migraphx::generate_literal(operand_shape, 0));
        auto weights1 = m1.add_literal(migraphx::generate_literal(operand_shape, 1));
        auto dot0     = m1.add_instruction(migraphx::make_op("dot"), data, weights0);
        auto dot1     = m1.add_instruction(migraphx::make_op("dot"), data, weights1);
        auto reshaped = m1.add_instruction(reshape_op, dot0);
        auto expanded = m1.add_instruction(unsqueeze_op, dot1);
        auto total    = m1.add_instruction(migraphx::make_op("add"), {reshaped, expanded});
        m1.add_instruction(pass_op{}, total);
    }
    run_pass(m1);

    // Expected module: a single dot against the concatenated literals, sliced
    // back into the two original results.
    migraphx::module m2;
    {
        auto data     = m2.add_parameter("input", operand_shape);
        auto weights0 = m2.add_literal(migraphx::generate_literal(operand_shape, 0));
        auto weights1 = m2.add_literal(migraphx::generate_literal(operand_shape, 1));
        auto joined =
            m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), weights0, weights1);
        auto fused_dot = m2.add_instruction(migraphx::make_op("dot"), data, joined);
        auto first_half = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {4}}}),
            fused_dot);
        auto second_half = m2.add_instruction(
            migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}),
            fused_dot);
        // The slice feeding the reshape goes through contiguous first ...
        auto first_packed = m2.add_instruction(migraphx::make_op("contiguous"), first_half);
        auto reshaped     = m2.add_instruction(reshape_op, first_packed);
        // ... while the slice feeding the unsqueeze is consumed directly.
        auto expanded = m2.add_instruction(unsqueeze_op, second_half);
        auto total    = m2.add_instruction(migraphx::make_op("add"), {reshaped, expanded});
        m2.add_instruction(pass_op{}, total);
    }

    EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_conv_horiz) TEST_CASE(simplify_conv_horiz)
{ {
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}}; auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
...@@ -1833,13 +1874,19 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion) ...@@ -1833,13 +1874,19 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
} }
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
TEST_CASE(reorder_reshape_slice)
template <std::size_t BS, bool TransposeInput>
void reorder_reshape_slice()
{ {
std::vector<int64_t> perm0 = {0, 2, 1, 3}; std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1}; std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto create_m1 = [&](std::size_t batch_size) {
migraphx::module m1; migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; {
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
if(TransposeInput)
{
s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}, {165120, 1, 128}};
}
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction( auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
...@@ -1854,7 +1901,7 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1854,7 +1901,7 @@ TEST_CASE(reorder_reshape_slice)
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64}; std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 10, 64};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
...@@ -1866,16 +1913,23 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1866,16 +1913,23 @@ TEST_CASE(reorder_reshape_slice)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
m1.add_return({ret}); m1.add_return({ret});
return m1;
}; };
auto create_m2 = [&](std::size_t batch_size) {
migraphx::module m2; migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; {
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
if(TransposeInput)
{
s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}, {165120, 1, 128}};
}
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64}; auto rsp_input = input;
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); if(TransposeInput)
{
rsp_input = m2.add_instruction(migraphx::make_op("contiguous"), {input});
}
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 30, 64};
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), rsp_input);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {10}}}), r);
...@@ -1894,27 +1948,25 @@ TEST_CASE(reorder_reshape_slice) ...@@ -1894,27 +1948,25 @@ TEST_CASE(reorder_reshape_slice)
auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
m2.add_return({ret}); m2.add_return({ret});
return m2;
}; };
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
run_pass(m1); run_pass(m1);
auto m2 = create_m2(batch_size);
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
};
test(1);
test(4);
test(8);
} }
TEST_CASE(reorder_reshape_slice_move_axis1) TEST_CASE_REGISTER(reorder_reshape_slice<1, true>); // test if contiguous is added as necessary if
// input is transposed
TEST_CASE_REGISTER(reorder_reshape_slice<4, true>);
TEST_CASE_REGISTER(reorder_reshape_slice<8, true>);
TEST_CASE_REGISTER(reorder_reshape_slice<1, false>);
TEST_CASE_REGISTER(reorder_reshape_slice<4, false>);
TEST_CASE_REGISTER(reorder_reshape_slice<8, false>);
template <std::size_t BS>
void reorder_reshape_slice_move_axis1()
{ {
auto create_m1 = [](std::size_t batch_size) {
migraphx::module m1; migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}}; {
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3}; std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1}; std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
...@@ -1929,7 +1981,7 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1929,7 +1981,7 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32}; std::vector<int64_t> lens = {static_cast<int64_t>(BS), 64, 4, 32};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
...@@ -1941,50 +1993,45 @@ TEST_CASE(reorder_reshape_slice_move_axis1) ...@@ -1941,50 +1993,45 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
m1.add_return({ret}); m1.add_return({ret});
return m1;
}; };
auto create_m2 = [](std::size_t batch_size) { migraphx::module m2;
migraphx::module m; {
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}}; auto s = migraphx::shape{migraphx::shape::float_type, {BS, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3}; std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1}; std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = m.add_parameter("input", s); auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96}; std::vector<int64_t> lens = {static_cast<int64_t>(BS), 64, 4, 96};
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto t0 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0); auto t0 =
auto slc1 = m.add_instruction( m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc0);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto t1 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1); auto t1 =
auto slc2 = m.add_instruction( m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm0}}), slc1);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto t2 = m.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2); auto t2 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm1}}), slc2);
auto sum = m.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m.add_instruction(migraphx::make_op("dot"), sum, t2);
m.add_return({ret});
return m; auto sum = m2.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m2.add_instruction(migraphx::make_op("dot"), sum, t2);
m2.add_return({ret});
}; };
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
auto m2 = create_m2(batch_size);
run_pass(m1); run_pass(m1);
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
};
test(4);
test(8);
} }
TEST_CASE_REGISTER(reorder_reshape_slice_move_axis1<4>);
TEST_CASE_REGISTER(reorder_reshape_slice_move_axis1<8>);
TEST_CASE(reorder_reshape_slice_move_axis2) TEST_CASE(reorder_reshape_slice_move_axis2)
{ {
auto create_m1 = [] {
migraphx::module m1; migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {128, 96}}; migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction( auto slc0 = m1.add_instruction(
...@@ -2006,32 +2053,26 @@ TEST_CASE(reorder_reshape_slice_move_axis2) ...@@ -2006,32 +2053,26 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1); auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret}); m1.add_return({ret});
return m1;
}; };
auto create_m2 = [] { migraphx::module m2;
migraphx::module m; {
auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}}; auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}};
auto input = m.add_parameter("input", s); auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 16, 8, 96}; std::vector<int64_t> lens = {1, 16, 8, 96};
auto rsp = m.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input); auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {32}}}), rsp);
auto slc1 = m.add_instruction( auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {32}}, {"ends", {64}}}), rsp);
auto slc2 = m.add_instruction( auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {96}}}), rsp);
auto sum = m.add_instruction(migraphx::make_op("add"), slc0, slc1); auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m.add_instruction(migraphx::make_op("mul"), sum, slc2); auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m.add_return({ret}); m2.add_return({ret});
return m;
}; };
auto m1 = create_m1();
auto m2 = create_m2();
run_pass(m1); run_pass(m1);
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
...@@ -2071,13 +2112,12 @@ TEST_CASE(reorder_reshape_slice_not_apply) ...@@ -2071,13 +2112,12 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
TEST_CASE(reorder_reshape_slice_diff_dims) template <std::size_t BS>
void reorder_reshape_slice_diff_dims()
{ {
auto create_m1 = [](std::size_t batch_size) {
migraphx::module m1; migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}}; {
std::vector<int64_t> perm0 = {0, 2, 1, 3}; auto s = migraphx::shape{migraphx::shape::float_type, {BS, 96, 96}};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction( auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {32}}}), input);
...@@ -2090,34 +2130,31 @@ TEST_CASE(reorder_reshape_slice_diff_dims) ...@@ -2090,34 +2130,31 @@ TEST_CASE(reorder_reshape_slice_diff_dims)
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1); auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2); auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 32, 3, 32}; std::vector<int64_t> lens = {static_cast<int64_t>(BS), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32}; std::vector<int64_t> lens1 = {static_cast<int64_t>(BS), 48, 2, 32};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0); auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1); auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2); auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens1}}), c2);
m1.add_return({r0, r1, r2}); m1.add_return({r0, r1, r2});
return m1;
}; };
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
auto m2 = m1; auto m2 = m1;
run_pass(m1); run_pass(m1);
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
};
test(4);
test(8);
} }
TEST_CASE(reorder_slice_trans) TEST_CASE_REGISTER(reorder_reshape_slice_diff_dims<4>);
TEST_CASE_REGISTER(reorder_reshape_slice_diff_dims<8>);
template <std::size_t BS>
void reorder_slice_trans()
{ {
std::vector<int64_t> perm = {0, 2, 1}; std::vector<int64_t> perm = {0, 2, 1};
auto create_m1 = [&](std::size_t batch_size) {
migraphx::module m1; migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; {
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction( auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
...@@ -2135,13 +2172,11 @@ TEST_CASE(reorder_slice_trans) ...@@ -2135,13 +2172,11 @@ TEST_CASE(reorder_slice_trans)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
m1.add_return({ret}); m1.add_return({ret});
return m1;
}; };
auto create_m2 = [&](std::size_t batch_size) {
migraphx::module m2; migraphx::module m2;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; {
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input); auto r = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), input);
...@@ -2155,26 +2190,21 @@ TEST_CASE(reorder_slice_trans) ...@@ -2155,26 +2190,21 @@ TEST_CASE(reorder_slice_trans)
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1); auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2); auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret}); m2.add_return({ret});
return m2;
}; };
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
run_pass(m1); run_pass(m1);
auto m2 = create_m2(batch_size);
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
};
test(1);
test(8);
} }
TEST_CASE(reorder_slice_trans_diff_perm) TEST_CASE_REGISTER(reorder_slice_trans<1>);
TEST_CASE_REGISTER(reorder_slice_trans<8>);
template <std::size_t BS>
void reorder_slice_trans_diff_perm()
{ {
auto create_m1 = [](std::size_t batch_size) {
migraphx::module m1; migraphx::module m1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}}; {
auto s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}};
std::vector<int64_t> perm0 = {0, 2, 1}; std::vector<int64_t> perm0 = {0, 2, 1};
std::vector<int64_t> perm1 = {0, 1, 2}; std::vector<int64_t> perm1 = {0, 1, 2};
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
...@@ -2197,21 +2227,16 @@ TEST_CASE(reorder_slice_trans_diff_perm) ...@@ -2197,21 +2227,16 @@ TEST_CASE(reorder_slice_trans_diff_perm)
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1); auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2); auto ret = m1.add_instruction(migraphx::make_op("dot"), sum, t2);
m1.add_return({ret}); m1.add_return({ret});
return m1;
}; };
auto test = [&](std::size_t batch_size) {
auto m1 = create_m1(batch_size);
run_pass(m1); run_pass(m1);
auto m2 = m1; auto m2 = m1;
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
};
test(1);
test(4);
} }
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_ins_deps) TEST_CASE(reorder_slice_ins_deps)
{ {
auto create_module = [] { auto create_module = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment