Unverified Commit c54167c7 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

create simplify_dyn_ops pass and add variable slice optimization (#2161)

New compiler pass that simplifies dynamic shapes related operators to their static versions if possible
Will normally be used after a split_single_dyn_dim pass
parent fbec1d58
...@@ -96,6 +96,7 @@ add_library(migraphx ...@@ -96,6 +96,7 @@ add_library(migraphx
serialize.cpp serialize.cpp
shape.cpp shape.cpp
simplify_algebra.cpp simplify_algebra.cpp
simplify_dyn_ops.cpp
simplify_reshapes.cpp simplify_reshapes.cpp
split_single_dyn_dim.cpp split_single_dyn_dim.cpp
target.cpp target.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYN_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYN_OPS_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Convert dynamic ops to their static version if possible.
* Should be run after the split_single_dyn_dims pass.
*/
struct MIGRAPHX_EXPORT simplify_dyn_ops
{
std::string name() const { return "simplify_dyn_ops"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct find_static_2in_broadcasts
{
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct find_const_3in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct find_const_4in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(4),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()),
match::arg(3)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) ...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max}; dds_it->max};
} }
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if` * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return // redirect to select_module operator and return
......
...@@ -48,6 +48,7 @@ ...@@ -48,6 +48,7 @@
#include <migraphx/rewrite_quantization.hpp> #include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_qdq.hpp> #include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp> #include <migraphx/split_single_dyn_dim.hpp>
...@@ -109,6 +110,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -109,6 +110,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
{ {
split_single_dyn_dim{}, split_single_dyn_dim{},
dead_code_elimination{}, dead_code_elimination{},
simplify_dyn_ops{},
dead_code_elimination{},
normalize_ops{}, normalize_ops{},
dead_code_elimination{}, dead_code_elimination{},
simplify_qdq{}, simplify_qdq{},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::simplify_dyn_ops{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(static_broadcast)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = m0.add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = m0.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", s.lens()}}), literal_ins);
auto add_ins = m0.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m0.add_return({add_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = m1.add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit =
m1.add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m1.add_return({add_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(static_multibroadcast)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m0.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}, {0}}};
auto literal_ins = m0.add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = m0.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), literal_ins);
auto add_ins = m0.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m0.add_return({add_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}, {0}}};
auto literal_ins = m1.add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit =
m1.add_instruction(migraphx::make_op("multibroadcast"), literal_ins, input);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), input, broadcast_lit);
m1.add_return({add_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(after_split_dyn_broadcast_match)
{
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", sm_shape.lens()}}),
literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
migraphx::run_passes(p1,
{migraphx::split_single_dyn_dim{},
migraphx::dead_code_elimination{},
migraphx::simplify_dyn_ops{}});
EXPECT(p0 == p1);
}
TEST_CASE(const_slice_3input)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m0.add_return({slice_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}), input, input_starts, input_ends);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_3input_dyn)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {{6, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m0.add_return({slice_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {{6, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
auto input = m1.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}}), input, input_starts, input_ends);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
TEST_CASE(const_slice_4input)
{
migraphx::module m0;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m0.add_parameter("data", s);
auto slice_ins = m0.add_instruction(
migraphx::make_op("slice", {{"starts", {0}}, {"ends", {3}}, {"axes", {0}}}), input);
m0.add_return({slice_ins});
}
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {6, 4, 4}};
auto input = m1.add_parameter("data", s);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto input_starts = m1.add_literal(migraphx::literal{s1, {0}});
auto input_ends = m1.add_literal(migraphx::literal{s1, {3}});
auto input_axes = m1.add_literal(migraphx::literal{s1, {0}});
auto slice_ins = m1.add_instruction(
migraphx::make_op("slice"), input, input_starts, input_ends, input_axes);
m1.add_return({slice_ins});
}
run_pass(m1);
EXPECT(m0 == m1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -50,8 +50,8 @@ TEST_CASE(dynamic_batch) ...@@ -50,8 +50,8 @@ TEST_CASE(dynamic_batch)
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}}; migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}}); auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction( auto broadcast_lit =
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins); submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add_ins = auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit); submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins}); submod->add_return({add_ins});
...@@ -107,8 +107,8 @@ TEST_CASE(multiple_outputs) ...@@ -107,8 +107,8 @@ TEST_CASE(multiple_outputs)
auto sm_input = submod->add_parameter("data", sm_shape); auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}}; migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {1}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}}); auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6}});
auto broadcast_lit = submod->add_instruction( auto broadcast_lit =
migraphx::make_op("multibroadcast", {{"out_lens", sm_shape.lens()}}), literal_ins); submod->add_instruction(migraphx::make_op("multibroadcast"), literal_ins, sm_input);
auto add0_ins = auto add0_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit); submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
auto add1_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, sm_input); auto add1_ins = submod->add_instruction(migraphx::make_op("add"), sm_input, sm_input);
...@@ -157,64 +157,4 @@ TEST_CASE(multiple_outputs) ...@@ -157,64 +157,4 @@ TEST_CASE(multiple_outputs)
EXPECT(p0 == p1); EXPECT(p0 == p1);
} }
TEST_CASE(broadcast_match)
{
// Slightly different from ref_ops_test in that the literal is copied over the submodules.
// A different compiler pass will pull the literals from the submodules to the main module.
migraphx::program p0;
{
auto* mm0 = p0.get_main_module();
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, const std::string& module_name) {
auto* submod = p0.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = submod->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = submod->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", sm_shape.lens()}}),
literal_ins);
auto add_ins =
submod->add_instruction(migraphx::make_op("add"), sm_input, broadcast_lit);
submod->add_return({add_ins});
return submod;
};
auto* dim1 = create_submodule(1, "dim_1");
auto* dim2 = create_submodule(2, "dim_2");
auto* dim3 = create_submodule(3, "dim_3");
auto* dim4 = create_submodule(4, "dim_4");
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input0 = mm0->add_parameter("data", s);
std::vector<migraphx::shape> sub_shapes = {};
sub_shapes.push_back(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 4}}});
migraphx::shape out_attr = migraphx::shape{sub_shapes};
auto sm_ins = mm0->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
{input0},
{dim1, dim2, dim3, dim4});
auto ret =
mm0->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), sm_ins);
mm0->add_return({ret});
}
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {{1, 4}, {4, 4}}};
auto input1 = mm1->add_parameter("data", s);
migraphx::shape lit_s{migraphx::shape{migraphx::shape::float_type, {4}}};
auto literal_ins = mm1->add_literal(migraphx::literal{lit_s, {6, 5, 4, 3}});
auto broadcast_lit = mm1->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}}), literal_ins, input1);
auto add_ins = mm1->add_instruction(migraphx::make_op("add"), input1, broadcast_lit);
mm1->add_return({add_ins});
}
run_pass(p1);
EXPECT(p0 == p1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment