#include #include #include #include #include #include #include #include #include #include #include #include struct allocate_no_out : migraphx::auto_register_op { migraphx::shape s{}; template static auto reflect(Self& self, F f) { return migraphx::pack(f(self.s, "shape")); } std::string name() const { return "allocate_no_out"; } migraphx::shape compute_shape(const std::vector& inputs) const { migraphx::check_shapes{inputs, *this}.has(0); return s; } migraphx::argument compute(migraphx::context&, const migraphx::shape& output_shape, const std::vector&) const { return {output_shape}; } }; struct allocate_with_out : migraphx::auto_register_op { migraphx::shape s{}; template static auto reflect(Self& self, F f) { return migraphx::pack(f(self.s, "shape")); } std::string name() const { return "allocate_with_out"; } migraphx::shape compute_shape(const std::vector& inputs) const { migraphx::check_shapes{inputs, *this}.has(0); return s; } migraphx::argument compute(migraphx::context&, const migraphx::shape& output_shape, const std::vector&) const { return {output_shape}; } }; // allocation model that has no out params struct allocation_no_out_model { std::string name() const { return "allocate_no_out"; } migraphx::operation allocate(const migraphx::shape& s) const { return migraphx::make_op(name(), {{"shape", to_value(s)}}); } migraphx::operation preallocate(const migraphx::shape&, const std::string&) const { return {}; } std::string copy() const { return {}; } bool needs_out_params() const { return false; } }; // allocation model with out params struct allocation_with_out_model { std::string name() const { return "allocate_with_out"; } migraphx::operation allocate(const migraphx::shape& s) const { return migraphx::make_op(name(), {{"shape", to_value(s)}}); } migraphx::operation preallocate(const migraphx::shape&, const std::string&) const { return {}; } std::string copy() const { return {}; } bool needs_out_params() const { return true; } }; void run_pass(migraphx::module& m, migraphx::allocation_model model, bool offload_copy = false) { migraphx::run_passes(m, {migraphx::replace_allocate{std::move(model), offload_copy}, migraphx::dead_code_elimination{}}); } void run_pass(migraphx::program& p, migraphx::allocation_model model, bool offload_copy = false) { migraphx::run_passes(p, {migraphx::replace_allocate{std::move(model), offload_copy}, migraphx::dead_code_elimination{}}); } migraphx::module create_simple_program() { migraphx::module m; migraphx::shape s{migraphx::shape::float_type, {5}}; auto x = m.add_parameter("x", s); auto y = m.add_parameter("y", s); auto alloc = m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); m.add_instruction(pass_op{}, alloc, x, y); return m; } TEST_CASE(allocate_no_out) { migraphx::module m = create_simple_program(); run_pass(m, allocation_no_out_model{}); EXPECT(std::any_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { return migraphx::contains(ins.name(), "allocate_no_out"); })); } TEST_CASE(allocate_with_out_param) { migraphx::module m = create_simple_program(); run_pass(m, allocation_with_out_model{}); EXPECT(std::none_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { return migraphx::contains(ins.name(), "allocate"); })); } TEST_CASE(allocate_with_out_return) { migraphx::module m = create_simple_program(); m.add_return({std::prev(m.end())}); run_pass(m, allocation_with_out_model{}); EXPECT(std::none_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { return migraphx::contains(ins.name(), "allocate"); })); } TEST_CASE(allocate_with_out_no_params) { migraphx::module m; migraphx::shape s{migraphx::shape::float_type, {5}}; auto x = m.add_parameter("x", s); auto y = m.add_parameter("y", s); auto z = m.add_parameter("z", s); auto alloc = m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); auto pass1 = m.add_instruction(pass_op{}, alloc, x, y); auto alloc2 = m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); m.add_instruction(pass_op{}, alloc2, z, pass1); run_pass(m, allocation_with_out_model{}); EXPECT(std::any_of(m.begin(), m.end(), [](const migraphx::instruction& ins) { return migraphx::contains(ins.name(), "allocate_with_out"); })); } TEST_CASE(if_allocate) { migraphx::program p; auto* mm = p.get_main_module(); migraphx::shape cond_s{migraphx::shape::bool_type}; auto cond = mm->add_parameter("cond", cond_s); migraphx::shape s{migraphx::shape::float_type, {5}}; auto x = mm->add_parameter("x", s); auto y = mm->add_parameter("y", s); auto* then_mod = p.create_module("If_0_if"); auto alloc = then_mod->add_instruction( migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); auto a1 = then_mod->add_instruction(pass_op{}, alloc, x); then_mod->add_return({a1}); auto* else_mod = p.create_module("If_0_else"); auto alloc1 = else_mod->add_instruction( migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}})); auto a2 = else_mod->add_instruction(pass_op{}, alloc1, y); else_mod->add_return({a2}); mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); run_pass(p, allocation_with_out_model{}); EXPECT(std::any_of(mm->begin(), mm->end(), [](const migraphx::instruction& ins) { return migraphx::contains(ins.name(), "allocate_with_out"); })); } int main(int argc, const char* argv[]) { test::run(argc, argv); }