Commit 1e4e9c57 authored by Paul's avatar Paul
Browse files

Format

parent fa197568
...@@ -128,9 +128,10 @@ struct find_transpose ...@@ -128,9 +128,10 @@ struct find_transpose
{ {
auto matcher() const auto matcher() const
{ {
auto output_not_transpose = match::none_of( auto output_not_transpose =
match::skip_output(match::name("contiguous"))(match::name("transpose"))); match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
auto input_has_transpose = match::skip(match::name("contiguous"))(match::args(match::name("transpose"))); auto input_has_transpose =
match::skip(match::name("contiguous"))(match::args(match::name("transpose")));
return match::name("transpose")(output_not_transpose, input_has_transpose); return match::name("transpose")(output_not_transpose, input_has_transpose);
} }
...@@ -654,7 +655,8 @@ struct find_transpose_slice ...@@ -654,7 +655,8 @@ struct find_transpose_slice
{ {
assert(op.starts.size() == op.ends.size()); assert(op.starts.size() == op.ends.size());
std::vector<int64_t> result(op.starts.size()); std::vector<int64_t> result(op.starts.size());
std::transform(op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{}); std::transform(
op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
return result; return result;
} }
...@@ -662,53 +664,64 @@ struct find_transpose_slice ...@@ -662,53 +664,64 @@ struct find_transpose_slice
{ {
auto ins = r.result; auto ins = r.result;
auto slices = ins->outputs(); auto slices = ins->outputs();
if (slices.empty()) if(slices.empty())
return; return;
auto slice = any_cast<op::slice>(slices.front()->get_operator()); auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto sdistance = slice_distance(slice); auto sdistance = slice_distance(slice);
// Check all distances and axes are the same // Check all distances and axes are the same
if (std::any_of(slices.begin(), slices.end(), [&](auto sins) { if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator()); auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance; return s.axes != slice.axes or slice_distance(s) != sdistance;
})) }))
return; return;
// Check distances are divisible by axes // Check distances are divisible by axes
auto mod_by_distance = [&](const auto& v, auto f) { auto mod_by_distance = [&](const auto& v, auto f) {
return std::inner_product(v.begin(), v.end(), sdistance.begin(), 0, std::plus<>{}, [&](auto x, auto d) -> uint64_t { return std::inner_product(v.begin(),
if (d == 0) v.end(),
sdistance.begin(),
0,
std::plus<>{},
[&](auto x, auto d) -> uint64_t {
if(d == 0)
return 1; return 1;
return f(x) % d; return f(x) % d;
}); });
}; };
if (mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0) if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
return; return;
// TODO: Handle multiple axes // TODO: Handle multiple axes
if (sdistance.size() != 1) if(sdistance.size() != 1)
return; return;
auto axis = slice.axes.front(); auto axis = slice.axes.front();
// Skip if axis would be packed // Skip if axis would be packed
if (std::all_of(ins->get_shape().lens().begin(), ins->get_shape().lens().begin()+axis, [](auto x) { return x == 1; })) if(std::all_of(ins->get_shape().lens().begin(),
ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; }))
return; return;
// Make unsqeeze // Make unsqeeze
auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs()); auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs());
// Make transpose // Make transpose
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) { std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if (i >= axis) if(i >= axis)
return i + 1; return i + 1;
return i; return i;
}); });
perm.insert(perm.begin(), axis); perm.insert(perm.begin(), axis);
auto transpose = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze); auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and sqeeze // Slice and sqeeze
for(auto s:slices) for(auto s : slices)
{ {
auto op = any_cast<op::slice>(s->get_operator()); auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0}; op.axes = {0};
op.starts = {op.starts.front()/sdistance.front()}; op.starts = {op.starts.front() / sdistance.front()};
op.ends = {op.ends.front()/sdistance.front()}; op.ends = {op.ends.front() / sdistance.front()};
auto slice_ins = m.insert_instruction(ins, op, transpose); auto slice_ins = m.insert_instruction(ins, op, transpose);
auto squeeze = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins); auto squeeze =
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
m.replace_instruction(s, squeeze); m.replace_instruction(s, squeeze);
} }
} }
......
...@@ -1237,12 +1237,12 @@ TEST_CASE(transpose_slice_non_packed_axis) ...@@ -1237,12 +1237,12 @@ TEST_CASE(transpose_slice_non_packed_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x); auto unsqueeze =
auto transpose = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze); auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice); auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze); auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
m2.add_return({sqrt}); m2.add_return({sqrt});
...@@ -1274,22 +1274,20 @@ TEST_CASE(transpose_slice_non_packed_multi_axis) ...@@ -1274,22 +1274,20 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x); auto unsqueeze =
auto transpose = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze); auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction( auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1); auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction( auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transpose);
transpose);
auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2); auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto transpose2 = m2.add_instruction( auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2);
auto slice3 = m2.add_instruction( auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
transpose);
auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3); auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({squeeze1, transpose2, squeeze3}); m2.add_return({squeeze1, transpose2, squeeze3});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment