Commit b4bba9a0 authored by Paul's avatar Paul
Browse files

Format

parent 94d93226
...@@ -91,10 +91,12 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s ...@@ -91,10 +91,12 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
return result; return result;
} }
static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins) static void insert_params(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
auto n = sm->get_parameter_shapes().size(); auto n = sm->get_parameter_shapes().size();
for(auto input:ins->inputs()) for(auto input : ins->inputs())
{ {
if(contains(map_ins, input)) if(contains(map_ins, input))
continue; continue;
...@@ -103,7 +105,9 @@ static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map ...@@ -103,7 +105,9 @@ static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map
} }
} }
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins) static auto insert_ins_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
insert_params(sm, ins, map_ins); insert_params(sm, ins, map_ins);
return sm->add_instructions({ins}, map_ins); return sm->add_instructions({ins}, map_ins);
...@@ -115,23 +119,27 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins) ...@@ -115,23 +119,27 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
return insert_ins_in_submodule(sm, ins, map_ins); return insert_ins_in_submodule(sm, ins, map_ins);
} }
static auto insert_module_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins) static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
insert_params(sm, ins, map_ins); insert_params(sm, ins, map_ins);
auto* m = ins->module_inputs().front(); auto* m = ins->module_inputs().front();
auto param_map = get_ins_param_map(ins->inputs(), m); auto param_map = get_ins_param_map(ins->inputs(), m);
for(auto&& [input, param]:param_map) for(auto&& [input, param] : param_map)
{ {
map_ins[param] = map_ins.at(input); map_ins[param] = map_ins.at(input);
} }
return sm->add_instructions(m, map_ins); return sm->add_instructions(m, map_ins);
} }
static std::vector<instruction_ref> find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins) static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names; std::map<std::string, instruction_ref> names;
for(auto&& [input, param]:map_ins) for(auto&& [input, param] : map_ins)
{ {
if(not sm->has_instruction(param)) if(not sm->has_instruction(param))
continue; continue;
...@@ -182,7 +190,8 @@ struct find_pointwise_reduce ...@@ -182,7 +190,8 @@ struct find_pointwise_reduce
{ {
auto matcher() const auto matcher() const
{ {
return match::name("fused_reduce")(match::any_of[match::inputs()](match::name("pointwise")(match::used_once()).bind("pointwise"))); return match::name("fused_reduce")(match::any_of[match::inputs()](
match::name("pointwise")(match::used_once()).bind("pointwise")));
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
...@@ -235,7 +244,7 @@ struct find_reduce_pointwise ...@@ -235,7 +244,7 @@ struct find_reduce_pointwise
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm}); mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
} }
}; };
} } // namespace
void fuse_reduce::apply(module_pass_manager& mpm) const void fuse_reduce::apply(module_pass_manager& mpm) const
{ {
......
...@@ -277,23 +277,31 @@ std::string generate_reduce(const module& rm, const std::string& name) ...@@ -277,23 +277,31 @@ std::string generate_reduce(const module& rm, const std::string& name)
i++; i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name); generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
std::vector<instruction_ref> tensors; std::vector<instruction_ref> tensors;
std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(tensors), [&](auto input) { std::copy_if(ins->inputs().begin(),
return input->get_shape().lens() == ilens and not input->get_shape().broadcasted(); ins->inputs().end(),
std::back_inserter(tensors),
[&](auto input) {
return input->get_shape().lens() == ilens and
not input->get_shape().broadcasted();
}); });
auto inner_names = names; auto inner_names = names;
for(auto input:tensors) for(auto input : tensors)
inner_names[input] += "_lambda_param"; inner_names[input] += "_lambda_param";
auto call_function = pointwise_name + "(" + auto call_function =
pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")"; join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if (tensors.empty()) if(tensors.empty())
return call_function; return call_function;
const std::string inner_template = "r.inner([=](${params}) { return ${call}; })(${args})"; const std::string inner_template =
"r.inner([=](${params}) { return ${call}; })(${args})";
auto args = cpp_generator::to_args(tensors, names); auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names); auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(params.begin(), params.end(), params.begin(), [](auto s) { std::transform(
return "auto " + s; params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
}); return interpolate_string(inner_template,
return interpolate_string(inner_template, {{"params", join_strings(params, ", ")}, {"args", join_strings(args, ", ")}, {"call", call_function}}); {{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
} }
MIGRAPHX_THROW("Unknown operator: " + ins->name()); MIGRAPHX_THROW("Unknown operator: " + ins->name());
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment