"src/vscode:/vscode.git/clone" did not exist on "c6d92ecc94c8f2fb8d12fd2b4a5219df275cfdf5"
Commit fa197568 authored by Paul's avatar Paul
Browse files

Adjust transpose slice

parent c96b88b7
......@@ -128,8 +128,10 @@ struct find_transpose
{
auto matcher() const
{
return match::name("transpose")(match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
auto output_not_transpose = match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose")));
auto input_has_transpose = match::skip(match::name("contiguous"))(match::args(match::name("transpose")));
return match::name("transpose")(output_not_transpose, input_has_transpose);
}
void apply(module& m, const match::matcher_result& mr) const
......@@ -641,9 +643,80 @@ struct find_slice_transpose
}
};
struct find_transpose_slice
{
auto matcher() const
{
return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
}
static std::vector<int64_t> slice_distance(const op::slice& op)
{
assert(op.starts.size() == op.ends.size());
std::vector<int64_t> result(op.starts.size());
std::transform(op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
return result;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slices = ins->outputs();
if (slices.empty())
return;
auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto sdistance = slice_distance(slice);
// Check all distances and axes are the same
if (std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance;
}))
return;
// Check distances are divisible by axes
auto mod_by_distance = [&](const auto& v, auto f) {
return std::inner_product(v.begin(), v.end(), sdistance.begin(), 0, std::plus<>{}, [&](auto x, auto d) -> uint64_t {
if (d == 0)
return 1;
return f(x) % d;
});
};
if (mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
return;
// TODO: Handle multiple axes
if (sdistance.size() != 1)
return;
auto axis = slice.axes.front();
// Skip if axis would be packed
if (std::all_of(ins->get_shape().lens().begin(), ins->get_shape().lens().begin()+axis, [](auto x) { return x == 1; }))
return;
// Make unsqeeze
auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs());
// Make transpose
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if (i >= axis)
return i + 1;
return i;
});
perm.insert(perm.begin(), axis);
auto transpose = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and sqeeze
for(auto s:slices)
{
auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0};
op.starts = {op.starts.front()/sdistance.front()};
op.ends = {op.ends.front()/sdistance.front()};
auto slice_ins = m.insert_instruction(ins, op, transpose);
auto squeeze = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
m.replace_instruction(s, squeeze);
}
}
};
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 2; i++)
for(int i = 0; i < 4; i++)
{
match::find_matches(m,
find_where_op{},
......@@ -656,6 +729,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_convert{},
find_nested_slice{},
find_nested_concat{},
find_transpose_slice{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m);
......
......@@ -1220,4 +1220,80 @@ TEST_CASE(transpose_slice_single_transpose)
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_non_packed_axis)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice);
m1.add_return({sqrt});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
auto transpose =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
m2.add_return({sqrt});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_non_packed_multi_axis)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
transpose);
auto transpose2 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), slice2);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
transpose);
m1.add_return({slice1, transpose2, slice3});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
auto transpose =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}),
transpose);
auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}),
transpose);
auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({squeeze1, transpose2, squeeze3});
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment