#include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { static literal get_scalar(instruction_ref ins) { const auto& s = ins->get_shape(); if(not(s.elements() == 1 or s.scalar())) return {}; if(not ins->can_eval()) return {}; auto e = ins->eval(); literal r{}; e.visit_at([&](auto x) { r = literal{x}; }); return r; } static void create_pointwise_modules(module_pass_manager& mpm) { std::size_t n = 0; for(auto ins : iterator_for(mpm.get_module())) { if(not ins->get_operator().attributes().get("pointwise", false)) continue; auto* pm = mpm.create_module("pointwise" + std::to_string(n++)); pm->set_bypass(); std::unordered_map param_map; std::vector pointwise_inputs; for(auto input : ins->inputs()) { if(contains(param_map, input)) continue; auto scalar = get_scalar(input); if(scalar.empty()) { pointwise_inputs.push_back(input); param_map[input] = pm->add_parameter("x" + std::to_string(param_map.size()), shape{input->get_shape().type()}); } else { param_map[input] = pm->add_literal(scalar); } } std::vector inputs; std::transform(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto input) { return param_map[input]; }); auto r = pm->add_instruction(ins->get_operator(), inputs); pm->add_return({r}); mpm.get_module().replace_instruction(ins, make_op("pointwise"), pointwise_inputs, {pm}); } } static std::vector append_pointwise_module(instruction_ref ins, instruction_ref output) { module_ref pm = ins->module_inputs().at(0); module_ref xm = output->module_inputs().at(0); auto last = std::prev(pm->end()); assert(last->name() == "@return"); assert(last->inputs().size() == 1); std::vector inputs = ins->inputs(); std::unordered_map map_ins; std::unordered_map input_map; // Copy inputs to input_map for(auto i : range(inputs.size())) { auto input = inputs[i]; auto param = pm->get_parameter("x" + std::to_string(i)); input_map[input] = param; } // Add the new parameter and additional inputs for(auto i : range(output->inputs().size())) { auto input = output->inputs()[i]; auto param = xm->get_parameter("x" + std::to_string(i)); if(input == ins) { map_ins[param] = last->inputs().front(); input_map[input] = map_ins[param]; } // Avoid duplicate paramter inputs else if(contains(input_map, input)) { map_ins[param] = input_map[input]; } else { map_ins[param] = pm->add_parameter("x" + std::to_string(inputs.size()), {input->get_shape().type()}); inputs.push_back(input); input_map[input] = map_ins[param]; } } pm->replace_return(pm->insert_module_instructions(last, xm, map_ins)); return inputs; } static bool find_pointwise_modules(module& m) { bool changed = false; for(auto ins : iterator_for(m)) { if(ins->name() != "pointwise") continue; if(ins->outputs().empty()) continue; auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { return i->name() == "pointwise" and i->outputs().size() == 1; }); if(it == ins->inputs().end()) continue; auto new_inputs = append_pointwise_module(*it, ins); m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs()); m.replace_instruction(ins, *it); m.move_instruction(*it, ins); changed = true; } return changed; } void fuse_pointwise::apply(module_pass_manager& mpm) const { create_pointwise_modules(mpm); mpm.run_pass(dead_code_elimination{}); for(int i = 0; i < 8; i++) { if(not find_pointwise_modules(mpm.get_module())) break; mpm.run_pass(dead_code_elimination{}); } } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx