Commit bb90f0eb authored by Shiv's avatar Shiv
Browse files

update conv dot fusion

parent e7ec374f
......@@ -896,21 +896,26 @@ struct find_conv_dot_horiz_fusion
int64_t offset = 0;
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
auto outputs = arg->outputs();
auto requires_contiguous = std::any_of(outputs.begin(), outputs.end(), [](auto o) {
return o->get_shape().standard();
});
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg,
auto slice = m.insert_instruction(
std::prev(arg),
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
if(requires_contiguous)
{
m.replace_instruction(arg, make_op("contiguous"), slice);
}
else
{
m.replace_instruction(arg, slice);
}
offset += len;
}
};
......
......@@ -2133,12 +2133,53 @@ TEST_CASE(simplify_dot_horiz)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
y = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_dot_horiz_nonstandard)
{
auto s1 = migraphx::shape{migraphx::shape::int32_type, {4, 24, 24}};
auto s2 = migraphx::shape{migraphx::shape::int32_type, {4, 24, 24}, {0, 1, 24}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s1);
auto a = m1.add_literal(migraphx::generate_literal(s2, 0));
auto b = m1.add_literal(migraphx::generate_literal(s2, 1));
auto c = m1.add_literal(migraphx::generate_literal(s2, 2));
auto zeros = m1.add_literal(0);
zeros = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 24, 24}}}),
zeros);
m1.add_instruction(migraphx::make_op("dot"), input, a);
auto y = m1.add_instruction(migraphx::make_op("dot"), input, b);
m1.add_instruction(migraphx::make_op("dot"), input, c);
auto sum = m1.add_instruction(migraphx::make_op("add"), y, zeros);
auto rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 24, 3, 8}}}), sum);
m1.add_instruction(pass_op{}, rsp);
}
run_pass(m1);
migraphx::module m2;
{
auto input = m2.add_parameter("input", s1);
auto a = m2.add_literal(migraphx::generate_literal(s2, 0));
auto b = m2.add_literal(migraphx::generate_literal(s2, 1));
auto c = m2.add_literal(migraphx::generate_literal(s2, 2));
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b, c);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {24}}, {"ends", {48}}}), dot);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {4, 24, 3, 8}}}), x);
m2.add_instruction(pass_op{}, rsp);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_dot_horiz_same_constant)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
......@@ -2163,6 +2204,8 @@ TEST_CASE(simplify_dot_horiz_same_constant)
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {4}}}), dot);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
y = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, sum);
}
......@@ -2219,10 +2262,11 @@ TEST_CASE(simplify_dot_horiz_reshape)
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {4}}, {"ends", {8}}}), dot);
auto x_cont = m2.add_instruction(migraphx::make_op("contiguous"), x);
auto y_cont = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2}}}), x_cont);
auto y_rsp =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y);
auto y_rsp = m2.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), y_cont);
auto sum = m2.add_instruction(migraphx::make_op("add"), {x_rsp, y_rsp});
m2.add_instruction(pass_op{}, sum);
}
......@@ -2257,6 +2301,8 @@ TEST_CASE(simplify_conv_horiz)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), conv);
auto y = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), conv);
x = m2.add_instruction(migraphx::make_op("contiguous"), x);
y = m2.add_instruction(migraphx::make_op("contiguous"), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, y);
m2.add_instruction(pass_op{}, sum);
}
......@@ -2333,12 +2379,16 @@ TEST_CASE(simplify_conv_horiz_grouped)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
convx = m2.add_instruction(migraphx::make_op("contiguous"), convx);
convy = m2.add_instruction(migraphx::make_op("contiguous"), convy);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
dotx = m2.add_instruction(migraphx::make_op("contiguous"), dotx);
doty = m2.add_instruction(migraphx::make_op("contiguous"), doty);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty);
auto sum3 = m2.add_instruction(migraphx::make_op("add"), sum1, sum2);
m2.add_instruction(pass_op{}, sum3);
......@@ -2391,12 +2441,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
convx = m2.add_instruction(migraphx::make_op("contiguous"), convx);
convy = m2.add_instruction(migraphx::make_op("contiguous"), convy);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
dotx = m2.add_instruction(migraphx::make_op("contiguous"), dotx);
doty = m2.add_instruction(migraphx::make_op("contiguous"), doty);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty);
auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sum3 = sqdiffx;
......@@ -2455,12 +2509,16 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {6}}}), conv);
auto convy = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {6}}, {"ends", {12}}}), conv);
convx = m2.add_instruction(migraphx::make_op("contiguous"), convx);
convy = m2.add_instruction(migraphx::make_op("contiguous"), convy);
auto sum1 = m2.add_instruction(migraphx::make_op("add"), convx, convy);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat2);
auto dotx = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {64}}}), dot);
auto doty = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {64}}, {"ends", {128}}}), dot);
dotx = m2.add_instruction(migraphx::make_op("contiguous"), dotx);
doty = m2.add_instruction(migraphx::make_op("contiguous"), doty);
auto sum2 = m2.add_instruction(migraphx::make_op("add"), dotx, doty);
auto sqdiffx = m2.add_instruction(migraphx::make_op("sqdiff"), input, e);
auto sqdiffy = m2.add_instruction(migraphx::make_op("sqdiff"), input, f);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment