Commit b4bba9a0 authored by Paul's avatar Paul
Browse files

Format

parent 94d93226
......@@ -91,10 +91,12 @@ get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref s
return result;
}
static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
static void insert_params(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
auto n = sm->get_parameter_shapes().size();
for(auto input:ins->inputs())
for(auto input : ins->inputs())
{
if(contains(map_ins, input))
continue;
......@@ -103,7 +105,9 @@ static void insert_params(module_ref sm, instruction_ref ins, std::unordered_map
}
}
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
static auto insert_ins_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins, map_ins);
return sm->add_instructions({ins}, map_ins);
......@@ -115,30 +119,34 @@ static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
return insert_ins_in_submodule(sm, ins, map_ins);
}
static auto insert_module_in_submodule(module_ref sm, instruction_ref ins, std::unordered_map<instruction_ref, instruction_ref>& map_ins)
static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins, map_ins);
auto* m = ins->module_inputs().front();
auto* m = ins->module_inputs().front();
auto param_map = get_ins_param_map(ins->inputs(), m);
for(auto&& [input, param]:param_map)
for(auto&& [input, param] : param_map)
{
map_ins[param] = map_ins.at(input);
}
return sm->add_instructions(m, map_ins);
}
static std::vector<instruction_ref> find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param]:map_ins)
for(auto&& [input, param] : map_ins)
{
if(not sm->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
......@@ -182,22 +190,23 @@ struct find_pointwise_reduce
{
auto matcher() const
{
return match::name("fused_reduce")(match::any_of[match::inputs()](match::name("pointwise")(match::used_once()).bind("pointwise")));
return match::name("fused_reduce")(match::any_of[match::inputs()](
match::name("pointwise")(match::used_once()).bind("pointwise")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce = r.result;
auto pw = r.instructions["pointwise"];
auto reduce = r.result;
auto pw = r.instructions["pointwise"];
const auto* pm = pw->module_inputs().front();
// const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(pm->name() + ":reduce");
auto* rm = mpm.create_module(pm->name() + ":reduce");
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Insert pointwise
auto rins = insert_ins_in_submodule(rm, pw, map_ins).front();
auto rins = insert_ins_in_submodule(rm, pw, map_ins).front();
map_ins[pw] = rins;
// Insert fused_reduce
insert_module_in_submodule(rm, reduce, map_ins);
......@@ -217,7 +226,7 @@ struct find_reduce_pointwise
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto pw = r.result;
auto pw = r.result;
auto reduce = r.instructions["reduce"];
const auto* old_rm = reduce->module_inputs().front();
......@@ -235,7 +244,7 @@ struct find_reduce_pointwise
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
}
};
}
} // namespace
void fuse_reduce::apply(module_pass_manager& mpm) const
{
......
......@@ -264,7 +264,7 @@ std::string generate_reduce(const module& rm, const std::string& name)
{
module m = rm;
cpp_generator g;
auto ilens = rm.get_parameter_shapes().begin()->second.lens();
auto ilens = rm.get_parameter_shapes().begin()->second.lens();
std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
......@@ -277,23 +277,31 @@ std::string generate_reduce(const module& rm, const std::string& name)
i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
std::vector<instruction_ref> tensors;
std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(tensors), [&](auto input) {
return input->get_shape().lens() == ilens and not input->get_shape().broadcasted();
});
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(tensors),
[&](auto input) {
return input->get_shape().lens() == ilens and
not input->get_shape().broadcasted();
});
auto inner_names = names;
for(auto input:tensors)
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function = pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if (tensors.empty())
auto call_function =
pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if(tensors.empty())
return call_function;
const std::string inner_template = "r.inner([=](${params}) { return ${call}; })(${args})";
auto args = cpp_generator::to_args(tensors, names);
const std::string inner_template =
"r.inner([=](${params}) { return ${call}; })(${args})";
auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(params.begin(), params.end(), params.begin(), [](auto s) {
return "auto " + s;
});
return interpolate_string(inner_template, {{"params", join_strings(params, ", ")}, {"args", join_strings(args, ", ")}, {"call", call_function}});
std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template,
{{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment