Unverified Commit d3e5a5c0 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Simplify constant of shape (#2448)

Simplify what we parse the ONNX operator ConstantOfShape into a literal for a static shape
parent 6f032fc0
...@@ -261,6 +261,7 @@ struct find_static_dimensions_of ...@@ -261,6 +261,7 @@ struct find_static_dimensions_of
/** /**
* Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1 * Simplify allocate into 2 argument reshape that has constant output dimensions into a static 1
* argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes. * argument reshape. Intended to simplify what ONNX parse_reshape creates for dynamic reshapes.
* This matcher can be generalized to matching reshape(data, static_shape_output_tensor).
* From: * From:
* x = allocate(constant_output_dims) -> reshape(data, x) * x = allocate(constant_output_dims) -> reshape(data, x)
* To: * To:
...@@ -289,6 +290,34 @@ struct find_const_alloc_reshapes ...@@ -289,6 +290,34 @@ struct find_const_alloc_reshapes
} }
}; };
/**
* Simplify allocate into fill operator that has constant output dimensions and constant value.
* The allocate into fill instructions is what is produced when parsing the ONNX
* ConstantOfShape operator. This replacement could be handled with propagate_constant, but
* would rather have the simplification happen earlier during compiling.
* This matcher can be generalized to matching fill(constant_value, static_shape_output_tensor).
* From:
* x = allocate(constant_ouptut_dims) -> fill(constant_value, x)
* To:
* literal
*/
struct find_const_alloc_fill
{
auto matcher() const
{
return match::name("fill")(match::arg(0)(match::is_constant()),
match::arg(1)(match::name("allocate")(match::is_constant())));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto fill_ins = mr.result;
auto fill_arg = fill_ins->eval(false);
auto l = m.add_literal(fill_arg.get_shape(), fill_arg.data());
m.replace_instruction(fill_ins, l);
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches(m, match::find_matches(m,
...@@ -297,7 +326,8 @@ void simplify_dyn_ops::apply(module& m) const ...@@ -297,7 +326,8 @@ void simplify_dyn_ops::apply(module& m) const
find_static_2in_broadcasts{}, find_static_2in_broadcasts{},
find_const_2in_slice{}, find_const_2in_slice{},
find_const_3in_slice{}, find_const_3in_slice{},
find_const_4in_slice{}); find_const_4in_slice{},
find_const_alloc_fill{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -544,4 +544,31 @@ TEST_CASE(static_dimensions_of_to_constant_alloc_reshape) ...@@ -544,4 +544,31 @@ TEST_CASE(static_dimensions_of_to_constant_alloc_reshape)
EXPECT(m0 == m1); EXPECT(m0 == m1);
} }
TEST_CASE(const_alloc_fill)
{
migraphx::module m0;
{
migraphx::shape val_shape{migraphx::shape::int64_type, {1}, {0}};
std::vector<int64_t> lit_data = {3};
auto value_lit = m0.add_literal(migraphx::literal{val_shape, lit_data});
migraphx::shape lit_s{migraphx::shape::int64_type, {3}};
auto output_dim_lit = m0.add_literal(migraphx::literal{lit_s, {3, 4, 4}});
auto alloc_ins = m0.add_instruction(
migraphx::make_op("allocate", {{"buf_type", migraphx::shape::int64_type}}),
output_dim_lit);
auto ret = m0.add_instruction(migraphx::make_op("fill"), value_lit, alloc_ins);
m0.add_return({ret});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape lit_shape{migraphx::shape::int64_type, {3, 4, 4}};
std::vector<int64_t> lit_data(3 * 4 * 4, 3);
auto ret = m1.add_literal(migraphx::literal{lit_shape, lit_data});
m1.add_return({ret});
}
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment