"profiler/vscode:/vscode.git/clone" did not exist on "85978e0201bb94bf6e59b325e1f5f19266845d08"
Unverified Commit 4bd3f4e3 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Simplify dimensions of (#2207)

Simplifies dimensions_of instructions to a literal when possible.
Intended to be used after the split_single_dyn_dim pass.
parent b9475653
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/simplify_dyn_ops.hpp> #include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -131,10 +132,53 @@ struct find_const_4in_slice ...@@ -131,10 +132,53 @@ struct find_const_4in_slice
} }
}; };
/**
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
{
auto matcher() const { return match::name("dimensions_of")(); }
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto input = ins->inputs().at(0);
auto dimensions_of_value = ins->get_operator().to_value();
auto start = dimensions_of_value.at("start").to<std::size_t>();
auto end = dimensions_of_value.at("end").to<std::size_t>();
if(input->get_shape().dynamic())
{
// check if dynamic dimensions from start to end are fixed
auto dds = input->get_shape().dyn_dims();
if(std::any_of(dds.begin() + start, dds.begin() + end, [](auto dd) {
return not dd.is_fixed();
}))
{
return;
}
}
std::size_t output_ndim = end - start;
std::vector<int64_t> vec_shape(output_ndim);
migraphx::shape s(migraphx::shape::int64_type, {output_ndim});
std::vector<std::size_t> input_lens = input->get_shape().to_static(1).lens();
std::transform(input_lens.begin() + start,
input_lens.begin() + end,
vec_shape.begin(),
[](auto i) { return int64_t(i); });
migraphx::shape output_shape{migraphx::shape::int64_type, {end - start}};
auto lit_ins = m.add_literal(migraphx::literal{output_shape, vec_shape});
m.replace_instruction(ins, lit_ins);
}
};
void simplify_dyn_ops::apply(module& m) const void simplify_dyn_ops::apply(module& m) const
{ {
match::find_matches( match::find_matches(m,
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{}); find_static_2in_broadcasts{},
find_static_dimensions_of{},
find_const_3in_slice{},
find_const_4in_slice{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -647,8 +647,8 @@ struct find_broadcast_transpose ...@@ -647,8 +647,8 @@ struct find_broadcast_transpose
{ {
auto transpose = r.result; auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens(); auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"]; auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front(); auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose // scalar transformation does not need extra transpose
if(not input->get_shape().scalar()) if(not input->get_shape().scalar())
{ {
......
...@@ -237,4 +237,86 @@ TEST_CASE(const_slice_4input) ...@@ -237,4 +237,86 @@ TEST_CASE(const_slice_4input)
EXPECT(m0 == m1); EXPECT(m0 == m1);
} }
TEST_CASE(static_dimensions_of0)
{
// dead_code_elimination will get rid of atan
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 4}};
auto input = m0.add_parameter("data", s);
auto atan_ins = m0.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins =
m0.add_instruction(migraphx::make_op("dimensions_of", {{"end", 3}}), atan_ins);
m0.add_return({dimensions_of_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4, 4}};
m1.add_parameter("data", s);
migraphx::shape lit_shape{migraphx::shape::int64_type, {3}};
std::vector<int64_t> lit_data = {2, 4, 4};
auto lit_ins = m1.add_literal(migraphx::literal{lit_shape, lit_data});
m1.add_return({lit_ins});
}
EXPECT(m0 == m1);
}
TEST_CASE(static_dimensions_of1)
{
// dead_code_elimination will get rid of atan
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 4}, {4, 4}}};
auto input = m0.add_parameter("data", s);
auto atan_ins = m0.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins = m0.add_instruction(
migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), atan_ins);
m0.add_return({dimensions_of_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 4}, {4, 4}}};
m1.add_parameter("data", s);
migraphx::shape lit_shape{migraphx::shape::int64_type, {2}};
std::vector<int64_t> lit_data = {4, 4};
auto lit_ins = m1.add_literal(migraphx::literal{lit_shape, lit_data});
m1.add_return({lit_ins});
}
EXPECT(m0 == m1);
}
// Does nothing because the dynamic_dimensions from start to end
// are not all fixed
TEST_CASE(static_dimensions_of_nonfixed)
{
// dead_code_elimination will get rid of atan
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 8}, {4, 8}}};
auto input = m0.add_parameter("data", s);
auto atan_ins = m0.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins = m0.add_instruction(
migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), atan_ins);
m0.add_return({dimensions_of_ins});
}
run_pass(m0);
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{2, 4, {2, 4}}, {4, 8}, {4, 8}}};
auto input = m1.add_parameter("data", s);
auto atan_ins = m1.add_instruction(migraphx::make_op("atan"), input);
auto dimensions_of_ins = m1.add_instruction(
migraphx::make_op("dimensions_of", {{"start", 1}, {"end", 3}}), atan_ins);
m1.add_return({dimensions_of_ins});
}
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment