Commit 66858a44 authored by Paul's avatar Paul
Browse files

Merge branch 'fuse-horiz-contiguous' into jit-contiguous2

parents 2e5116bf dfc7bbac
...@@ -69,40 +69,58 @@ static bool try_compute_shape(instruction_ref ins, ...@@ -69,40 +69,58 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods); return try_compute_shape(ins, inputs, mods);
} }
void eliminate_contiguous::apply(module& m) const template <class F>
static void remove_contiguous(const std::string& op_name, module& m, F f)
{ {
auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
// return instruction should have inputs with standard shape // return instruction should have inputs with standard shape
if(ins->name() == "@return") if(ins->name() == "@return")
continue; continue;
if(ins != last and ins->outputs().empty())
continue;
if(not f(ins))
continue;
// Make a copy so we can modify it while we iterate // Make a copy so we can modify it while we iterate
auto args = ins->inputs(); auto args = ins->inputs();
auto new_args = args; auto new_args = args;
auto mod_args = ins->module_inputs(); auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs()) for(auto arg : ins->inputs())
{ {
if(arg->name() == op_name) if(arg->name() != op_name)
continue;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{ {
auto prev = arg->inputs().front(); auto c = op::contiguous{};
replace(new_args, arg, prev); auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{
auto c = op::contiguous{};
auto r = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto l = m.add_literal(r.get_shape(), r.data()); auto l = m.add_literal(r.get_shape(), r.data());
m.replace_instruction(arg, l); m.replace_instruction(arg, l);
}
} }
} }
} }
} }
void eliminate_contiguous::apply(module& m) const
{
// Skip contiguous from splits first
remove_contiguous(op_name, m, [](auto ins) {
if(ins->name() != "slice")
return true;
return (ins->inputs().front()->outputs().size() == 1);
});
remove_contiguous(op_name, m, [](auto) { return true; });
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -1118,4 +1118,106 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast) ...@@ -1118,4 +1118,106 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_slice)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 384, 36, 64}});
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {12}}}), x);
auto transpose1 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), slice1);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {12}}, {"ends", {24}}}), x);
auto transpose2 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), slice2);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {24}}, {"ends", {36}}}), x);
auto transpose3 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), slice3);
m1.add_return({transpose1, transpose2, transpose3});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1, 384, 36, 64}});
auto transpose =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
transpose);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
transpose);
m2.add_return({slice1, slice2, slice3});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_diff_perm)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 384, 36, 64}});
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {12}}}), x);
auto transpose1 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), slice1);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {12}}, {"ends", {24}}}), x);
auto transpose2 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), slice2);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {24}}, {"ends", {36}}}), x);
auto transpose3 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), slice3);
m1.add_return({transpose1, transpose2, transpose3});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1, 384, 36, 64}});
auto transpose =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
transpose);
auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), slice2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
transpose);
m2.add_return({slice1, transpose2, slice3});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_single_transpose)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 384, 36, 64}});
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {12}}}), x);
auto sqrt1 = m1.add_instruction(migraphx::make_op("sqrt"), slice1);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {12}}, {"ends", {24}}}), x);
auto transpose2 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), slice2);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {24}}, {"ends", {36}}}), x);
auto sqrt3 = m1.add_instruction(migraphx::make_op("sqrt"), slice3);
m1.add_return({sqrt1, transpose2, sqrt3});
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment