#include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { std::unordered_map create_output_names(const module& mod) { std::unordered_map mod_output_names{}; auto last = std::prev(mod.end()); if(last->name() == "@return") { const auto& prog_outputs = last->inputs(); std::vector outputs_alias(prog_outputs.size()); std::transform(prog_outputs.begin(), prog_outputs.end(), outputs_alias.begin(), [](const auto& i) { return instruction::get_output_alias(i); }); std::size_t index = 0; for(auto ins : outputs_alias) { mod_output_names[ins] = mod.name() + ":#output_" + std::to_string(index++); } } else { auto ins = instruction::get_output_alias(last); mod_output_names[ins] = "output"; } return mod_output_names; } void insert_submod_allocations(instruction_ref ins, module& mod, const allocation_model& model) { std::vector inputs = ins->inputs(); std::vector mod_args = ins->module_inputs(); std::map name_shapes; for(const auto& smod : mod_args) { auto ps = smod->get_parameter_shapes(); name_shapes.insert(ps.begin(), ps.end()); } for(auto& pn : name_shapes) { const auto& s = pn.second; instruction_ref output{}; output = mod.insert_instruction(ins, model.allocate(s)); inputs.push_back(output); } mod.replace_instruction(ins, ins->get_operator(), inputs, mod_args); } void replace_allocate::apply(module& m) const { auto mod_output_names = create_output_names(m); bool main_offload_copy = m.name() == "main" ? this->offload_copy : false; for(auto ins : iterator_for(m)) { auto op = ins->get_operator(); auto op_name = op.name(); // check if allocations from submodules need to be inserted // for now, only the "if" operator is affected if(op_name == "if") { insert_submod_allocations(ins, m, model); continue; } if(op_name != "allocate") continue; auto s = ins->get_shape(); if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins)) { auto out_param = m.add_parameter(mod_output_names[ins], s); m.replace_instruction(ins, out_param); continue; } m.replace_instruction( ins, m.insert_instruction(ins, make_op(model.name(), migraphx::value{{"shape", to_value(s)}}))); } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx