Unverified Commit 45ccd757 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Remove contiguous from passes for reshapes (#2319)

Remove contiguous from simplify algebra passes for reshape op
Remove contiguous from fuse_pointwise for reshape op
parent fa0dc35f
...@@ -219,9 +219,8 @@ struct find_pointwise_reshape_pointwise ...@@ -219,9 +219,8 @@ struct find_pointwise_reshape_pointwise
auto reshape_input = [&](const auto& ins_to_insert) { auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) { return [&](auto input) {
auto c = m.insert_instruction(ins_to_insert, make_op("contiguous"), input);
return m.insert_instruction( return m.insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), c); ins_to_insert, make_op("reshape", {{"dims", cd.dims}}), input);
}; };
}; };
auto x_inputs = x_ins->inputs(); auto x_inputs = x_ins->inputs();
......
...@@ -941,15 +941,6 @@ struct find_splits ...@@ -941,15 +941,6 @@ struct find_splits
{ {
auto split = i->inputs()[split_idx]; auto split = i->inputs()[split_idx];
assert(split->name() == "slice"); assert(split->name() == "slice");
// Insert contiguous for reshapes
auto outputs = i->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
m.replace_instruction(i, split->get_operator(), c); m.replace_instruction(i, split->get_operator(), c);
} }
...@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion ...@@ -1181,13 +1172,6 @@ struct find_conv_dot_horiz_fusion
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
auto outputs = arg->outputs(); auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction( m.replace_instruction(
...@@ -1487,11 +1471,6 @@ struct find_split_reshape ...@@ -1487,11 +1471,6 @@ struct find_split_reshape
slc_axis_len; slc_axis_len;
}); });
// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction( auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......
...@@ -103,8 +103,6 @@ struct find_reshaper ...@@ -103,8 +103,6 @@ struct find_reshaper
auto input = mr.instructions["x"]; auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens(); auto dims = ins->get_shape().lens();
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input); m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
} }
}; };
...@@ -475,9 +473,8 @@ struct find_resize ...@@ -475,9 +473,8 @@ struct find_resize
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = m.insert_instruction( auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), mb_rsp);
} }
}; };
...@@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -626,9 +623,8 @@ struct find_transpose_contiguous_reshaper_unary
auto cont_ins = r.instructions["cont_ins"]; auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name(); auto unary_op_name = ins->get_operator().name();
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination // older cont and reshape are removed by deadcode elimination
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); m.replace_instruction(ins, reshaper_ins->get_operator(), unary_ins);
} }
}; };
......
...@@ -414,8 +414,8 @@ TEST_CASE(add_reshape_add_nonstandard) ...@@ -414,8 +414,8 @@ TEST_CASE(add_reshape_add_nonstandard)
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto c = mm->add_instruction(migraphx::make_op("contiguous"), add1); auto reshape =
auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), c); mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1);
auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z);
mm->add_return({add2}); mm->add_return({add2});
} }
...@@ -426,10 +426,8 @@ TEST_CASE(add_reshape_add_nonstandard) ...@@ -426,10 +426,8 @@ TEST_CASE(add_reshape_add_nonstandard)
auto x = mm->add_parameter("x", s1); auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y); auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), cy);
auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z);
auto fadd = auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) { add_pointwise(p2, "main:pointwise0", {x2, y2, z2}, [=](auto* pm, const auto& inputs) {
...@@ -466,10 +464,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) ...@@ -466,10 +464,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard)
auto x = mm->add_parameter("x", s1); auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s1); auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s2); auto z = mm->add_parameter("z", s2);
auto cx = mm->add_instruction(migraphx::make_op("contiguous"), x); auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x);
auto cy = mm->add_instruction(migraphx::make_op("contiguous"), y); auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y);
auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cx);
auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), cy);
auto fadd = auto fadd =
add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
......
...@@ -1897,12 +1897,17 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1897,12 +1897,17 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto concatb = m2.add_instruction(b, concat); auto concatb = m2.add_instruction(b, concat);
auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb); auto sum = m2.add_instruction(migraphx::make_op("add"), input, concatb);
auto relu = m2.add_instruction(migraphx::make_op("relu"), sum); auto relu = m2.add_instruction(migraphx::make_op("relu"), sum);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 8}}}), relu);
auto slc1 = m2.add_instruction( auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), rsp); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), relu);
auto rsp1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4}}}), slc1);
auto slc2 = m2.add_instruction( auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {8}}}), rsp); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {2}}}), relu);
auto add = m2.add_instruction(migraphx::make_op("add"), slc1, slc2);
auto rsp2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4}}}), slc2);
auto add = m2.add_instruction(migraphx::make_op("add"), rsp1, rsp2);
m2.add_instruction(pass_op{}, add); m2.add_instruction(pass_op{}, add);
} }
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
...@@ -2323,9 +2328,7 @@ TEST_CASE(simplify_dot_horiz_reshape) ...@@ -2323,9 +2328,7 @@ TEST_CASE(simplify_dot_horiz_reshape)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {4}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {4}}}), dot);
auto y = m2.add_instruction( auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot);
auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x); auto x_rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont);
auto y_rsp = auto y_rsp =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y); m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp}); auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
...@@ -2688,12 +2691,8 @@ void reorder_reshape_slice() ...@@ -2688,12 +2691,8 @@ void reorder_reshape_slice()
{ {
s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}, {165120, 1, 128}}; s = migraphx::shape{migraphx::shape::float_type, {BS, 128, 1920}, {165120, 1, 128}};
} }
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto rsp_input = input; auto rsp_input = input;
if(TransposeInput)
{
rsp_input = m2.add_instruction(migraphx::make_op("contiguous"), {input});
}
std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 30, 64}; std::vector<int64_t> lens = {static_cast<int64_t>(BS), 128, 30, 64};
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), rsp_input); auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), rsp_input);
...@@ -2976,9 +2975,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp) ...@@ -2976,9 +2975,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp)
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto t1 = m2.add_instruction( auto t1 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input); migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), input);
auto c_t1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp1 = auto rsp1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), c_t1); m2.add_instruction(migraphx::make_op("reshape", {{"dims", {384, 128, 80}}}), t1);
auto slc0 = m2.add_instruction( auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {256}}, {"ends", {384}}}), rsp1); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {256}}, {"ends", {384}}}), rsp1);
auto slc1 = m2.add_instruction( auto slc1 = m2.add_instruction(
...@@ -2993,9 +2991,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp) ...@@ -2993,9 +2991,8 @@ TEST_CASE(reorder_reshape_slice_multi_rsp)
auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1); auto dot = m2.add_instruction(migraphx::make_op("dot"), slc2, c_t_slc1);
auto c_t1_1 = m2.add_instruction(migraphx::make_op("contiguous"), t1);
auto rsp2 = auto rsp2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), c_t1_1); m2.add_instruction(migraphx::make_op("reshape", {{"dims", {12, 32, 128, 80}}}), t1);
auto slc2_1 = m2.add_instruction( auto slc2_1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), rsp2); migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), rsp2);
...@@ -3372,9 +3369,8 @@ TEST_CASE(dot_fusion_reshape) ...@@ -3372,9 +3369,8 @@ TEST_CASE(dot_fusion_reshape)
auto s1 = m2.add_instruction( auto s1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {320}}, {"ends", {640}}}), d); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {320}}, {"ends", {640}}}), d);
auto cont0 = m2.add_instruction(migraphx::make_op("contiguous"), s0);
auto r0 = auto r0 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 8, 40}}}), cont0); m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 8, 40}}}), s0);
m2.add_return({r0, s1}); m2.add_return({r0, s1});
}; };
......
...@@ -888,9 +888,8 @@ TEST_CASE(optimize_resize) ...@@ -888,9 +888,8 @@ TEST_CASE(optimize_resize)
std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3}; std::vector<int64_t> mb_dims = {1, 2, 2, 2, 2, 3};
auto mbx = auto mbx =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mb_dims}}), rspx);
auto std_mb = m.add_instruction(migraphx::make_op("contiguous"), mbx);
std::vector<int64_t> orig_dims = {1, 2, 4, 6}; std::vector<int64_t> orig_dims = {1, 2, 4, 6};
auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), std_mb); auto rmb = m.add_instruction(migraphx::make_op("reshape", {{"dims", orig_dims}}), mbx);
auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), rmb); auto r = m.add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), rmb);
m.add_return({r}); m.add_return({r});
...@@ -1300,10 +1299,9 @@ TEST_CASE(transpose_contiguous_reshape_unary) ...@@ -1300,10 +1299,9 @@ TEST_CASE(transpose_contiguous_reshape_unary)
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x); m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction( auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1); migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins); auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu);
auto reshape_ins2 = auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins); m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), relu);
m2.add_instruction(pass_op{}, reshape_ins2); m2.add_instruction(pass_op{}, reshape_ins2);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
...@@ -1328,8 +1326,7 @@ TEST_CASE(transpose_contiguous_squeeze_unary) ...@@ -1328,8 +1326,7 @@ TEST_CASE(transpose_contiguous_squeeze_unary)
auto transpose_ins = auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins); auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt); auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), rsqrt);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
m2.add_instruction(pass_op{}, sq_ins); m2.add_instruction(pass_op{}, sq_ins);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
...@@ -1355,9 +1352,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary) ...@@ -1355,9 +1352,7 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
auto transpose_ins = auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins); auto round = m2.add_instruction(migraphx::make_op("nearbyint"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round); auto unsq_ins = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), round);
auto unsq_ins =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
m2.add_instruction(pass_op{}, unsq_ins); m2.add_instruction(pass_op{}, unsq_ins);
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment