Commit 6bc30259 authored by Paul's avatar Paul
Browse files

Format

parent 8bc67132
......@@ -60,7 +60,7 @@ struct fused_reduce
{
lens[axis] = 1;
}
if (sm->get_output_shapes().size() != 1)
if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported");
return inputs[0].with_lens(sm->get_output_shapes().front().type(), lens);
}
......@@ -75,10 +75,11 @@ static void create_reduce_modules(module_pass_manager& mpm)
{
if(not ins->get_operator().attributes().get("reduce", false))
continue;
if (ins->inputs().size() != 1)
if(ins->inputs().size() != 1)
continue;
auto* rm = mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
auto* rm =
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass();
// TODO: Ensure standard shape
......@@ -91,13 +92,18 @@ static void create_reduce_modules(module_pass_manager& mpm)
}
}
static std::unordered_map<instruction_ref, instruction_ref> get_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
static std::unordered_map<instruction_ref, instruction_ref>
get_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
std::unordered_map<instruction_ref, instruction_ref> result;
auto names = sm->get_parameter_names();
std::sort(names.begin(), names.end());
assert(names.size() == inputs.size());
std::transform(names.begin(), names.end(), inputs.begin(), std::inserter(result, result.end()), [&](const auto& name, auto input) {
std::transform(names.begin(),
names.end(),
inputs.begin(),
std::inserter(result, result.end()),
[&](const auto& name, auto input) {
return std::make_pair(input, sm->get_parameter(name));
});
return result;
......@@ -106,7 +112,7 @@ static std::unordered_map<instruction_ref, instruction_ref> get_param_map(const
static std::vector<instruction_ref> get_returns(module& m)
{
auto last = std::prev(m.end());
if (last->name() == "@return")
if(last->name() == "@return")
return last->inputs();
return {last};
}
......@@ -115,7 +121,8 @@ struct find_reduce_pointwise
{
auto matcher() const
{
return match::name("pointwise")(match::any_of[match::inputs()](match::name("fused_reduce")(match::used_once()).bind("reduce")));
return match::name("pointwise")(match::any_of[match::inputs()](
match::name("fused_reduce")(match::used_once()).bind("reduce")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
......@@ -129,15 +136,16 @@ struct find_reduce_pointwise
*rm = *old_rm;
auto map_ins = get_param_map(reduce->inputs(), rm);
auto new_inputs = reduce->inputs();
for(auto input:ins->inputs())
for(auto input : ins->inputs())
{
if(contains(map_ins, input))
continue;
if (input == reduce)
if(input == reduce)
{
map_ins[input] = rm->
}
map_ins[input] = rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape());
map_ins[input] =
rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape());
new_inputs.push_back(input);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment