"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2b5c5f5ecb794aba8279bbdf13f2d20c282e6776"
Commit 1e4e9c57 authored by Paul's avatar Paul
Browse files

Format

parent fa197568
......@@ -128,9 +128,10 @@ struct find_transpose
{
auto matcher() const
{
auto output_not_transpose = match::none_of(
match::skip_output(match::name("contiguous"))(match::name("transpose")));
auto input_has_transpose = match::skip(match::name("contiguous"))(match::args(match::name("transpose")));
auto output_not_transpose =
match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
auto input_has_transpose =
match::skip(match::name("contiguous"))(match::args(match::name("transpose")));
return match::name("transpose")(output_not_transpose, input_has_transpose);
}
......@@ -654,61 +655,73 @@ struct find_transpose_slice
{
assert(op.starts.size() == op.ends.size());
std::vector<int64_t> result(op.starts.size());
std::transform(op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
std::transform(
op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
return result;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins = r.result;
auto slices = ins->outputs();
if (slices.empty())
if(slices.empty())
return;
auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto sdistance = slice_distance(slice);
// Check all distances and axes are the same
if (std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance;
}))
if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance;
}))
return;
// Check distances are divisible by axes
auto mod_by_distance = [&](const auto& v, auto f) {
return std::inner_product(v.begin(), v.end(), sdistance.begin(), 0, std::plus<>{}, [&](auto x, auto d) -> uint64_t {
if (d == 0)
return 1;
return f(x) % d;
});
return std::inner_product(v.begin(),
v.end(),
sdistance.begin(),
0,
std::plus<>{},
[&](auto x, auto d) -> uint64_t {
if(d == 0)
return 1;
return f(x) % d;
});
};
if (mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
return;
// TODO: Handle multiple axes
if (sdistance.size() != 1)
if(sdistance.size() != 1)
return;
auto axis = slice.axes.front();
// Skip if axis would be packed
if (std::all_of(ins->get_shape().lens().begin(), ins->get_shape().lens().begin()+axis, [](auto x) { return x == 1; }))
if(std::all_of(ins->get_shape().lens().begin(),
ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; }))
return;
// Make unsqeeze
auto unsqueeze = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs());
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", slice.axes}, {"steps", sdistance}}), ins->inputs());
// Make transpose
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if (i >= axis)
if(i >= axis)
return i + 1;
return i;
});
perm.insert(perm.begin(), axis);
auto transpose = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and sqeeze
for(auto s:slices)
for(auto s : slices)
{
auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0};
op.starts = {op.starts.front()/sdistance.front()};
op.ends = {op.ends.front()/sdistance.front()};
auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0};
op.starts = {op.starts.front() / sdistance.front()};
op.ends = {op.ends.front() / sdistance.front()};
auto slice_ins = m.insert_instruction(ins, op, transpose);
auto squeeze = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
auto squeeze =
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
m.replace_instruction(s, squeeze);
}
}
......
......@@ -1237,14 +1237,14 @@ TEST_CASE(transpose_slice_non_packed_axis)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
auto transpose =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
transpose);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
m2.add_return({sqrt});
}
EXPECT(m1 == m2);
......@@ -1274,22 +1274,20 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
auto transpose =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {1, 0, 3, 2, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
transpose);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}),
transpose);
auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transpose);
auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}),
transpose);
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({squeeze1, transpose2, squeeze3});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment