Commit bb90f0eb authored by Shiv's avatar Shiv
Browse files

update conv dot fusion

parent e7ec374f
...@@ -896,21 +896,26 @@ struct find_conv_dot_horiz_fusion ...@@ -896,21 +896,26 @@ struct find_conv_dot_horiz_fusion
int64_t offset = 0; int64_t offset = 0;
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
auto outputs = arg->outputs(); auto outputs = arg->outputs();
for(auto output : outputs) auto requires_contiguous = std::any_of(outputs.begin(), outputs.end(), [](auto o) {
{ return o->get_shape().standard();
if(output->name() != "reshape") });
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg, auto slice = m.insert_instruction(
std::prev(arg),
make_op("slice", make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused); fused);
if(requires_contiguous)
{
m.replace_instruction(arg, make_op("contiguous"), slice);
}
else
{
m.replace_instruction(arg, slice);
}
offset += len; offset += len;
} }
}; };
......
...@@ -2133,12 +2133,53 @@ TEST_CASE(simplify_dot_horiz) ...@@ -2133,12 +2133,53 @@ TEST_CASE(simplify_dot_horiz)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = m2.add_instruction( auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
y = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, sum); m2.add_instruction(pass_op{}, sum);
} }
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
} }
TEST_CASE(simplify_dot_horiz_nonstandard)
{
auto s1 = migraphx::shape{migraphx::shape::int32_type, {4, 24, 24}};
auto s2 = migraphx::shape{migraphx::shape::int32_type, {4, 24, 24}, {0, 1, 24}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s1);
auto a = m1.add_literal(migraphx::generate_literal(s2, 0));
auto b = m1.add_literal(migraphx::generate_literal(s2, 1));
auto c = m1.add_literal(migraphx::generate_literal(s2, 2));
auto zeros = m1.add_literal(0);
zeros = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 24, 24}}}),
zeros);
m1.add_instruction(migraphx::make_op("dot"), input, a);
auto y = m1.add_instruction(migraphx::make_op("dot"), input, b);
m1.add_instruction(migraphx::make_op("dot"), input, c);
auto sum = m1.add_instruction(migraphx::make_op("add"), y, zeros);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 24, 3, 8}}}), sum);
m1.add_instruction(pass_op{}, rsp);
}
run_pass(m1);
migraphx::module m2;
{
auto input = m2.add_parameter("input", s1);
auto a = m2.add_literal(migraphx::generate_literal(s2, 0));
auto b = m2.add_literal(migraphx::generate_literal(s2, 1));
auto c = m2.add_literal(migraphx::generate_literal(s2, 2));
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b, c);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {24}}, {"ends", {48}}}), dot);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 24, 3, 8}}}), x);
m2.add_instruction(pass_op{}, rsp);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_dot_horiz_same_constant) TEST_CASE(simplify_dot_horiz_same_constant)
{ {
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
...@@ -2163,6 +2204,8 @@ TEST_CASE(simplify_dot_horiz_same_constant) ...@@ -2163,6 +2204,8 @@ TEST_CASE(simplify_dot_horiz_same_constant)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = m2.add_instruction( auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
y = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, sum); m2.add_instruction(pass_op{}, sum);
} }
...@@ -2219,10 +2262,11 @@ TEST_CASE(simplify_dot_horiz_reshape) ...@@ -2219,10 +2262,11 @@ TEST_CASE(simplify_dot_horiz_reshape)
auto y = m2.add_instruction( auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot); migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot);
auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x); auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x);
auto y_cont = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto x_rsp = auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont); m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont);
auto y_rsp = auto y_rsp = m2.add_instruction(
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y); migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y_cont);
auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp}); auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
m2.add_instruction(pass_op{}, sum); m2.add_instruction(pass_op{}, sum);
} }
...@@ -2257,6 +2301,8 @@ TEST_CASE(simplify_conv_horiz) ...@@ -2257,6 +2301,8 @@ TEST_CASE(simplify_conv_horiz)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv);
auto y = m2.add_instruction( auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
y = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y); auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, sum); m2.add_instruction(pass_op{}, sum);
} }
...@@ -2333,12 +2379,16 @@ TEST_CASE(simplify_conv_horiz_grouped) ...@@ -2333,12 +2379,16 @@ TEST_CASE(simplify_conv_horiz_grouped)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = m2.add_instruction( auto convy = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
convx = m2.add_instruction(migraphx::make_op("contiguous"), convx);
convy = m2.add_instruction(migraphx::make_op("contiguous"), convy);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy); auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2); auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = m2.add_instruction( auto dotx = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = m2.add_instruction( auto doty = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
dotx = m2.add_instruction(migraphx::make_op("contiguous"), dotx);
doty = m2.add_instruction(migraphx::make_op("contiguous"), doty);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty); auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty);
auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2); auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2);
m2.add_instruction(pass_op{}, sum3); m2.add_instruction(pass_op{}, sum3);
...@@ -2391,12 +2441,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1) ...@@ -2391,12 +2441,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = m2.add_instruction( auto convy = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
convx = m2.add_instruction(migraphx::make_op("contiguous"), convx);
convy = m2.add_instruction(migraphx::make_op("contiguous"), convy);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy); auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2); auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = m2.add_instruction( auto dotx = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = m2.add_instruction( auto doty = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
dotx = m2.add_instruction(migraphx::make_op("contiguous"), dotx);
doty = m2.add_instruction(migraphx::make_op("contiguous"), doty);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty); auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty);
auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e); auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sum3 = sqdiffx; auto sum3 = sqdiffx;
...@@ -2455,12 +2509,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2) ...@@ -2455,12 +2509,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = m2.add_instruction( auto convy = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
convx = m2.add_instruction(migraphx::make_op("contiguous"), convx);
convy = m2.add_instruction(migraphx::make_op("contiguous"), convy);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy); auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2); auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = m2.add_instruction( auto dotx = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = m2.add_instruction( auto doty = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot); migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
dotx = m2.add_instruction(migraphx::make_op("contiguous"), dotx);
doty = m2.add_instruction(migraphx::make_op("contiguous"), doty);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty); auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty);
auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e); auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sqdiffy = m2.add_instruction(migraphx::make_op("sqdiff"), input, f); auto sqdiffy = m2.add_instruction(migraphx::make_op("sqdiff"), input, f);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment