Commit e9a3e6c1 authored by Paul

Merge branch 'simplify-more-reshapes' into sd-opt

parents f6e22d56 5967d68d
......@@ -96,10 +96,41 @@ struct reshape
return {s0.type(), output_dyn_dims};
}
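    // Find the iterator to the last element of the prefix of [start, last)
    // whose product equals dim; return start when no such prefix exists.
    // E.g. over {2, 3, 4} with dim == 6 this stops at the 3 (2 * 3 == 6),
    // while dim == 5 returns start unchanged.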
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
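    // Check that the dims/strides form one packed block that can be collapsed
    // into a single dimension: each stride must equal the product of the dims
    // inside it times the innermost stride. E.g. dims {2, 3, 4} with strides
    // {12, 4, 1} merge (12 == 3*4*1, 4 == 4*1), but strides {24, 4, 1} do not.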
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
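    // Compute the static output shape. When the input is non-standard, the
    // output strides are derived axis by axis: equal dims keep the input
    // stride, several input dims collapsing into one output dim (squeeze) keep
    // the innermost stride when can_strides_merge allows it, and one input dim
    // splitting into several output dims (unsqueeze) gives each new dim the
    // input stride times the product of the new dims to its right. E.g. lens
    // {4, 24, 1, 1, 1} with strides {1, 4, 1, 1, 1} reshaped to {4, 1, 3, 4, 2}
    // yields strides {1, 96, 32, 8, 4}.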
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.standard();
auto&& idims = inputs.front().lens();
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
auto&& istrides = inputs.front().strides();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
......@@ -125,7 +156,81 @@ struct reshape
}
}
shape s{inputs.front().type(), rdims};
shape s;
if(inputs.front().standard())
{
s = shape{inputs.front().type(), rdims};
}
else
{
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
break;
auto n = it - start;
if((i + n) > istrides.size())
break;
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n))
break;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
break;
auto n = it - start;
if((r + n) > rdims.size())
break;
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
break;
rstrides.push_back(stride);
}
}
if(rdims.size() != rstrides.size())
MIGRAPHX_THROW("Reshape on axis that is not standard");
s = shape{inputs.front().type(), rdims, rstrides};
}
assert(s.bytes() == inputs.front().bytes());
if(s.elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " +
......
......@@ -49,7 +49,6 @@ const auto& reshaper_names()
static const std::unordered_set<std::string> names = {
"flatten",
"reshape",
"contiguous",
"squeeze",
"unsqueeze"
};
......@@ -89,38 +88,23 @@ struct find_reshaper
{
auto matcher() const
{
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
auto no_output_reshape = match::none_of[match::outputs()](match::name(reshaper_names()));
auto input_reshape =
match::arg(0)(match::skip(match::name("contiguous"))(match::name(reshaper_names())));
auto input = match::skip(match::name(reshaper_names()),
match::name("contiguous"))(match::arg(0).bind("x"));
return match::name(reshaper_names())(no_output_reshape, input_reshape, input);
}
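    // The matcher anchors on the last reshaper of a chain (one with no
    // reshaper consumers) and binds the first non-reshaper producer, skipping
    // any contiguous, as "x"; apply then collapses the whole chain into a
    // single reshape of x to the final dims.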
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
auto ins = mr.result;
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
m.replace_instruction(r.first, r.second);
}
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -603,14 +587,15 @@ struct find_reshape_cont
};
// match sequence of transpose --> contiguous --> reshaper_op; any extra
// matchers passed in further constrain the matched transpose
auto match_transpose_contiguous_reshaper()
template <class... Ms>
auto match_transpose_contiguous_reshaper(Ms... ms)
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins")))
.bind("cont_ins")))
match::args(match::name("contiguous")(
match::used_once(),
match::args(match::transpose_shape(ms...).bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};
......@@ -642,6 +627,45 @@ struct find_transpose_contiguous_reshaper_unary
}
};
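// Matches dot(reshaper(contiguous(transpose(mul_or_add(x, constant)))), constant)
// and sinks the constant pointwise op below the transpose/contiguous/reshape
// chain: the chain is rebuilt on each pointwise operand (including both inputs
// of an inner mul) so the mul/add is applied directly on the gemm input.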
struct find_mul_add_transpose_contiguous_reshaper_gemm
{
auto matcher() const
{
auto pw = match::name("mul", "add")(
match::used_once(),
match::either_arg(0, 1)(match::is_constant().bind("c"), match::any().bind("x")));
return match::name("dot")(match::either_arg(0, 1)(
match_transpose_contiguous_reshaper(match::args(pw.bind("pointwise"))),
match::is_constant()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto x_ins = r.instructions["x"];
auto c_ins = r.instructions["c"];
auto pw_ins = r.instructions["pointwise"];
auto insert_reshapes = [&](auto x) {
auto t = m.insert_instruction(ins, trans_ins->get_operator(), x);
auto c = m.insert_instruction(ins, make_op("contiguous"), t);
return m.insert_instruction(ins, reshaper_ins->get_operator(), c);
};
if(x_ins->name() == "mul")
{
x_ins = m.insert_instruction(
ins,
make_op("mul"),
{insert_reshapes(x_ins->inputs()[0]), insert_reshapes(x_ins->inputs()[1])});
}
auto y_ins =
m.insert_instruction(ins, pw_ins->get_operator(), {x_ins, insert_reshapes(c_ins)});
m.replace_instruction(reshaper_ins, y_ins);
}
};
struct find_slice_transpose
{
auto matcher() const
......@@ -797,6 +821,98 @@ struct find_transpose_slice
}
};
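// Matches reshape(dot(a, b)) where the reshape only adds leading batch
// dimensions (the last two dims are unchanged) and reshapes both gemm inputs
// instead, so the dot runs directly in the batched layout.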
struct find_reshape_gemm
{
auto matcher() const { return match::name("reshape")(match::arg(0)(match::name("dot"))); }
static bool is_batched_unsqueeze(instruction_ref ins)
{
auto input = ins->inputs().front()->get_shape().lens();
auto output = ins->get_shape().lens();
if(output.size() <= input.size())
return false;
if(not std::equal(input.end() - 2, input.end(), output.end() - 2, output.end()))
return false;
return true;
}
static operation make_reshape(std::vector<std::size_t> batches, instruction_ref ins)
{
batches.insert(
batches.end(), ins->get_shape().lens().end() - 2, ins->get_shape().lens().end());
return make_op("reshape", {{"dims", batches}});
}
void apply(module& m, const match::matcher_result& r) const
{
auto reshape_ins = r.result;
auto dot_ins = reshape_ins->inputs().front();
// TODO: Put this in the matcher
if(not is_batched_unsqueeze(reshape_ins))
return;
std::vector<std::size_t> batches;
std::copy(reshape_ins->get_shape().lens().begin(),
reshape_ins->get_shape().lens().end() - 2,
std::back_inserter(batches));
auto input0 = m.insert_instruction(
dot_ins, make_reshape(batches, dot_ins->inputs()[0]), dot_ins->inputs()[0]);
auto input1 = m.insert_instruction(
dot_ins, make_reshape(batches, dot_ins->inputs()[1]), dot_ins->inputs()[1]);
m.replace_instruction(dot_ins, make_op("dot"), input0, input1);
}
};
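// Matches a reshaper fed (optionally through a contiguous) by a broadcast with
// a single non-zero stride, and replaces it with one broadcast of "x" along
// the matching axis of the reshaped output, squeezing "x" to one dimension
// first when needed.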
struct find_broadcast_reshaper
{
auto matcher() const
{
auto broadcast =
match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x")));
return match::name(reshaper_names())(
match::args(match::skip(match::name("contiguous"))(broadcast.bind("broadcast"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"];
auto broadcast_shape = broadcast_ins->get_shape();
auto result_shape = ins->get_shape();
if(std::accumulate(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 0) !=
1)
return;
auto baxis =
std::find(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 1) -
broadcast_shape.strides().begin();
auto relements = result_shape.lens();
std::partial_sum(
relements.begin(), relements.end(), relements.begin(), std::multiplies<>{});
auto prefix_elements = std::accumulate(broadcast_shape.lens().begin(),
broadcast_shape.lens().begin() + baxis + 1,
1,
std::multiplies<>{});
auto axis =
std::find(relements.begin(), relements.end(), prefix_elements) - relements.begin();
if(axis >= relements.size())
return;
if(x_ins->get_shape().lens().size() > 1)
x_ins = m.insert_instruction(ins, make_op("squeeze"), x_ins);
m.replace_instruction(
ins,
make_op("broadcast", {{"axis", axis}, {"out_lens", ins->get_shape().lens()}}),
x_ins);
}
};
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 4; i++)
......@@ -804,9 +920,10 @@ void simplify_reshapes::apply(module& m) const
match::find_matches(m,
find_where_op{},
find_resize{},
find_reshape_cont{},
find_nop_reshapes{},
find_reshaper{},
find_broadcast_reshaper{},
// find_reshape_cont{},
find_transpose{},
find_concat_transpose{},
find_concat_multibroadcasts{},
......@@ -815,7 +932,9 @@ void simplify_reshapes::apply(module& m) const
find_nested_concat{},
find_transpose_slice{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
find_transpose_contiguous_reshaper_unary{},
find_mul_add_transpose_contiguous_reshaper_gemm{},
find_reshape_gemm{});
dead_code_elimination{}.apply(m);
}
}
......
......@@ -2181,6 +2181,16 @@ TEST_CASE(reshape_shape)
}
}
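// Splitting one axis of a non-standard (permuted) input should reshape without
// requiring a standard input; the result keeps the corresponding permuted layout.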
TEST_CASE(reshape_nonstandard_unsqueeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 24, 1, 1, 1}, {1, 4, 1, 1, 1}};
std::vector<std::size_t> lens = {4, 1, 3, 4, 2};
std::vector<int64_t> perm = {4, 0, 1, 2, 3};
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation(perm));
expect_shape(output, migraphx::make_op("reshape", {{"dims", lens}}), input);
}
TEST_CASE(reshape_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
......
......@@ -1503,4 +1503,30 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
EXPECT(m1.sort() == m2.sort());
}
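// A multibroadcast followed by transpose, contiguous, and reshape should
// simplify to a squeeze plus a single broadcast on axis 2.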
TEST_CASE(broadcast_transpose_reshape)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto broadcast = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), x);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), broadcast);
auto contiguous = m1.add_instruction(migraphx::make_op("contiguous"), transpose);
auto reshape = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 320}}}),
contiguous);
m1.add_return({reshape});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze"), x);
auto broadcast = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 4096, 320}}}), squeeze);
m2.add_return({broadcast});
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }