#ifndef MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP #define MIGRAPHX_GUARD_OPERATORS_IF_OP_HPP #include #include #include #include #include #include #include #include #include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { struct if_op { std::string name() const { return "if"; } shape compute_shape(const std::vector& inputs, std::vector mods) const { check_shapes{inputs, *this}.standard(); if(mods.size() != 2) { MIGRAPHX_THROW("IF: operator should have two submodules."); } auto out_shapes0 = mods[0]->get_output_shapes(); auto out_shapes1 = mods[1]->get_output_shapes(); if(not std::equal( out_shapes1.begin(), out_shapes1.end(), out_shapes0.begin(), out_shapes0.end())) { MIGRAPHX_THROW("IF: output shapes of submodules must be the same."); } return out_shapes0.front(); } argument compute( const std::vector& args, const std::vector& mods, const std::function( module_ref& mdl, const std::unordered_map& inputs)>& run) const { auto cond = args.front().at(); module_ref mod = cond ? mods[0] : mods[1]; std::unordered_map params; std::set pnames; for(const auto& smod : mods) { auto names = smod->get_parameter_names(); pnames.insert(names.begin(), names.end()); } assert(pnames.size() < args.size()); std::transform(pnames.begin(), pnames.end(), args.begin() + 1, std::inserter(params, params.end()), [](auto&& name, auto&& arg) { return std::make_pair(name, arg); }); auto results = run(mod, params); return results[0]; } }; } // namespace op } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif