Commit 7b101e49 authored by Paul's avatar Paul
Browse files

Format

parent c5de9766
...@@ -52,15 +52,15 @@ struct fused_reduce ...@@ -52,15 +52,15 @@ struct fused_reduce
{ {
MIGRAPHX_THROW("should have one submodule."); MIGRAPHX_THROW("should have one submodule.");
} }
auto* sm = mods.front(); auto* sm = mods.front();
check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims(); check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims();
auto s = inputs.at(0); auto s = inputs.at(0);
auto lens = s.lens(); auto lens = s.lens();
for(const auto& axis : axes) for(const auto& axis : axes)
{ {
lens[axis] = 1; lens[axis] = 1;
} }
if (sm->get_output_shapes().size() != 1) if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported"); MIGRAPHX_THROW("Only one output supported");
return inputs[0].with_lens(sm->get_output_shapes().front().type(), lens); return inputs[0].with_lens(sm->get_output_shapes().front().type(), lens);
} }
...@@ -75,15 +75,16 @@ static void create_reduce_modules(module_pass_manager& mpm) ...@@ -75,15 +75,16 @@ static void create_reduce_modules(module_pass_manager& mpm)
{ {
if(not ins->get_operator().attributes().get("reduce", false)) if(not ins->get_operator().attributes().get("reduce", false))
continue; continue;
if (ins->inputs().size() != 1) if(ins->inputs().size() != 1)
continue; continue;
auto* rm = mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++)); auto* rm =
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass(); rm->set_bypass();
// TODO: Ensure standard shape // TODO: Ensure standard shape
auto x0 = rm->add_parameter("x0", ins->inputs().front()->get_shape()); auto x0 = rm->add_parameter("x0", ins->inputs().front()->get_shape());
auto r = rm->add_instruction(ins->get_operator(), x0); auto r = rm->add_instruction(ins->get_operator(), x0);
rm->add_return({r}); rm->add_return({r});
// TODO: Set axes // TODO: Set axes
...@@ -91,22 +92,27 @@ static void create_reduce_modules(module_pass_manager& mpm) ...@@ -91,22 +92,27 @@ static void create_reduce_modules(module_pass_manager& mpm)
} }
} }
static std::unordered_map<instruction_ref, instruction_ref> get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm) static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{ {
std::unordered_map<instruction_ref, instruction_ref> result; std::unordered_map<instruction_ref, instruction_ref> result;
auto names = sm->get_parameter_names(); auto names = sm->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
assert(names.size() == inputs.size()); assert(names.size() == inputs.size());
std::transform(names.begin(), names.end(), inputs.begin(), std::inserter(result, result.end()), [&](const auto& name, auto input) { std::transform(names.begin(),
return std::make_pair(input, sm->get_parameter(name)); names.end(),
}); inputs.begin(),
std::inserter(result, result.end()),
[&](const auto& name, auto input) {
return std::make_pair(input, sm->get_parameter(name));
});
return result; return result;
} }
static std::vector<instruction_ref> get_returns(module& m) static std::vector<instruction_ref> get_returns(module& m)
{ {
auto last = std::prev(m.end()); auto last = std::prev(m.end());
if (last->name() == "@return") if(last->name() == "@return")
return last->inputs(); return last->inputs();
return {last}; return {last};
} }
...@@ -115,31 +121,33 @@ struct find_reduce_pointwise ...@@ -115,31 +121,33 @@ struct find_reduce_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return match::name("pointwise")(match::any_of[match::inputs()](match::name("fused_reduce")(match::used_once()).bind("reduce"))); return match::name("pointwise")(match::any_of[match::inputs()](
match::name("fused_reduce")(match::used_once()).bind("reduce")));
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto reduce = r.instructions["reduce"]; auto reduce = r.instructions["reduce"];
auto* old_rm = reduce->module_inputs().front(); auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":pointwise"); auto* rm = mpm.create_module(old_rm->name() + ":pointwise");
// Copy module // Copy module
*rm = *old_rm; *rm = *old_rm;
auto map_ins = get_ins_param_map(reduce->inputs(), rm); auto map_ins = get_ins_param_map(reduce->inputs(), rm);
auto new_inputs = reduce->inputs(); auto new_inputs = reduce->inputs();
for(auto input:ins->inputs()) for(auto input : ins->inputs())
{ {
if(contains(map_ins, input)) if(contains(map_ins, input))
continue; continue;
if (input == reduce) if(input == reduce)
{ {
map_ins[input] = get_returns(*rm).front(); map_ins[input] = get_returns(*rm).front();
} }
else else
{ {
map_ins[input] = rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape()); map_ins[input] =
rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape());
new_inputs.push_back(input); new_inputs.push_back(input);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment