Unverified Commit 118ad5d0 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Simplify dynamic reshape (#2429)

Converts dynamic reshapes with constant output dimensions to a static reshape instruction
Intended to be used after split_single_dyn_dim pass and after find_static_dimensions_of makes shape calculations doable at compile-time.
parent 225873aa
......@@ -33,6 +33,10 @@ inline namespace MIGRAPHX_INLINE_NS {
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
* From:
* broadcast_op(argument_with_static_shape, argument_with_static_shape)
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/
struct find_static_2in_broadcasts
{
......@@ -172,11 +176,43 @@ struct find_static_dimensions_of
}
};
/**
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* From:
* x = allocate(constant_output_dims) -> reshape(data, x)
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct find_const_alloc_reshapes
{
auto matcher() const
{
return match::name("reshape")(match::nargs(2),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto reshape_ins = mr.result;
auto reshape_inputs = reshape_ins->inputs();
auto alloc_ins = reshape_inputs.at(1);
argument output_dims_arg = alloc_ins->inputs().at(0)->eval(false);
std::vector<int64_t> output_dims_vec;
output_dims_arg.visit(
[&](auto output) { output_dims_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
reshape_ins, make_op("reshape", {{"dims", output_dims_vec}}), reshape_inputs.at(0));
// have dead_code_elimination remove the previous allocate
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(m,
find_static_2in_broadcasts{},
find_static_dimensions_of{},
find_const_alloc_reshapes{},
find_static_2in_broadcasts{},
find_const_3in_slice{},
find_const_4in_slice{});
}
......
......@@ -319,4 +319,71 @@ TEST_CASE(static_dimensions_of_nonfixed)
EXPECT(m0 == m1);
}
TEST_CASE(constant_alloc_reshape)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {3, 32}};
auto input = m0.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape::int64_type, {3}};
auto literal_ins = m0.add_literal(migraphx::literal{lit_s, {3, 4, 8}});
auto alloc_ins = m0.add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}),
literal_ins);
auto reshape_ins = m0.add_instruction(migraphx::make_op("reshape"), input, alloc_ins);
m0.add_return({reshape_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {3, 32}};
auto input = m1.add_parameter("data", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 8}}}), input);
m1.add_return({reshape_ins});
}
EXPECT(m0 == m1);
}
// A more contrived example to test static dimensions_of and constant reshape
TEST_CASE(static_dimensions_of_to_constant_alloc_reshape)
{
migraphx::module m0;
{
migraphx::shape input_shape{migraphx::shape::float_type, {3, 4, 8}};
auto x_param = m0.add_parameter("x", input_shape);
auto dimensions_of_ins =
m0.add_instruction(migraphx::make_op("dimensions_of", {{"end", 3}}), x_param);
migraphx::shape lit_shape{migraphx::shape::int64_type, {1}};
auto lit0 = m0.add_literal(migraphx::literal{lit_shape, {0}});
auto gather_ins =
m0.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), dimensions_of_ins, lit0);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {1}}, {"ends", {3}}, {"axes", {0}}}),
dimensions_of_ins);
auto reduce_ins =
m0.add_instruction(migraphx::make_op("reduce_prod", {{"axes", {0}}}), slice_ins);
auto concat_ins =
m0.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), gather_ins, reduce_ins);
auto alloc_ins = m0.add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::float_type}}), concat_ins);
auto reshape_ins = m0.add_instruction(migraphx::make_op("reshape"), x_param, alloc_ins);
m0.add_return({reshape_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {3, 4, 8}};
auto x_param = m1.add_parameter("x", s);
auto reshape_ins =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {3, 32}}}), x_param);
m1.add_return({reshape_ins});
}
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment