Commit 8949b9b1 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

fixes to optimize_module pass

parent 7aee6388
...@@ -40,7 +40,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -40,7 +40,7 @@ bool skip_propogate(instruction_ref ins)
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front()); return skip_propogate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar()) if(s.broadcasted() and not s.scalar() and not s.packed())
return true; return true;
if(s.scalar() and s.elements() != 1) if(s.scalar() and s.elements() != 1)
return true; return true;
...@@ -101,9 +101,13 @@ void propagate_constant::apply(module& m) const ...@@ -101,9 +101,13 @@ void propagate_constant::apply(module& m) const
})(const_instrs_vec[i]); })(const_instrs_vec[i]);
m.debug_print(inss); m.debug_print(inss);
} }
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape()); auto in_shape = const_instrs_vec[i]->get_shape();
auto l = m.add_literal(literals[i].get_shape(), literals[i].data()); assert(literals[i].get_shape() == in_shape);
m.replace_instruction(const_instrs_vec[i], l); literal l{in_shape, literals[i].data()};
if(const_instrs_vec[i]->outputs().front()->name() == "dot")
l = {{in_shape.type(), in_shape.lens()}, literals[i].data()};
auto l0 = m.add_literal(l);
m.replace_instruction(const_instrs_vec[i], l0);
} }
} }
} }
......
...@@ -543,6 +543,17 @@ struct find_inner_broadcast ...@@ -543,6 +543,17 @@ struct find_inner_broadcast
return 3; return 3;
})); }));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs); auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
return broadcast->get_shape();
});
std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
return common->get_shape();
});
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
return i->name() == "broadcast" or i->name() == "multibroadcast";}))
return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment