Commit 8949b9b1 authored by Khalique Ahmed

fixes to optimize_module pass

parent 7aee6388
@@ -40,7 +40,7 @@ bool skip_propogate(instruction_ref ins)
     if(ins->name() == "contiguous")
         return skip_propogate(ins->inputs().front());
     auto&& s = ins->get_shape();
-    if(s.broadcasted() and not s.scalar())
+    if(s.broadcasted() and not s.scalar() and not s.packed())
         return true;
     if(s.scalar() and s.elements() != 1)
         return true;
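
This hunk tightens the skip heuristic in skip_propogate: a broadcasted, non-scalar constant is now skipped only when its shape is also non-packed, so packed broadcasts become eligible for propagation. Below is a minimal standalone sketch of the old versus new predicate; it uses a toy Shape struct with plain boolean flags instead of the real migraphx::shape API, so the flag values are illustrative assumptions only.

#include <cstddef>
#include <iostream>

// Toy stand-in for the shape properties consulted by skip_propogate
// (illustrative only; not the migraphx::shape interface).
struct Shape
{
    bool broadcasted;
    bool scalar;
    bool packed;
    std::size_t elements;
};

// Old heuristic: skip any broadcasted, non-scalar shape.
bool skip_old(const Shape& s)
{
    if(s.broadcasted and not s.scalar)
        return true;
    return s.scalar and s.elements != 1;
}

// New heuristic: additionally require the shape to be non-packed before skipping.
bool skip_new(const Shape& s)
{
    if(s.broadcasted and not s.scalar and not s.packed)
        return true;
    return s.scalar and s.elements != 1;
}

int main()
{
    // A broadcasted but packed constant: previously skipped, now propagated.
    Shape packed_broadcast{true, false, true, 16};
    std::cout << skip_old(packed_broadcast) << " -> " << skip_new(packed_broadcast) << "\n"; // prints "1 -> 0"
}
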
@@ -101,9 +101,13 @@ void propagate_constant::apply(module& m) const
                 })(const_instrs_vec[i]);
                 m.debug_print(inss);
             }
-            assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
-            auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
-            m.replace_instruction(const_instrs_vec[i], l);
+            auto in_shape = const_instrs_vec[i]->get_shape();
+            assert(literals[i].get_shape() == in_shape);
+            literal l{in_shape, literals[i].data()};
+            if(const_instrs_vec[i]->outputs().front()->name() == "dot")
+                l = {{in_shape.type(), in_shape.lens()}, literals[i].data()};
+            auto l0 = m.add_literal(l);
+            m.replace_instruction(const_instrs_vec[i], l0);
         }
     }
 }
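
This hunk changes how the folded constant is materialized in propagate_constant::apply: when the constant feeds a dot, the literal is rebuilt from only the type and lengths of in_shape, which discards any non-default strides the constant carried. The sketch below illustrates that stride-dropping effect with plain STL types; the row-major stride computation is an assumption about the default layout a shape built from {type, lens} alone receives, not a call into the MIGraphX API.

#include <cstddef>
#include <iostream>
#include <vector>

// Default (row-major, packed) strides for a set of dimension lengths. A shape
// constructed from only {type, lens} is assumed to get strides like these,
// regardless of the strides the original broadcasted constant carried.
std::vector<std::size_t> default_strides(const std::vector<std::size_t>& lens)
{
    if(lens.empty())
        return {};
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size() - 1; i > 0; i--)
        strides[i - 1] = strides[i] * lens[i];
    return strides;
}

int main()
{
    // Hypothetical folded constant feeding a dot: lengths {4, 3}, but with a
    // zero stride left over from a broadcast.
    std::vector<std::size_t> lens{4, 3};
    std::vector<std::size_t> broadcast_strides{0, 1};

    std::cout << "broadcast strides: " << broadcast_strides[0] << " " << broadcast_strides[1] << "\n";

    // Rebuilding the literal from {type, lens} replaces those strides with the
    // packed defaults, so the dot consumer sees a standard layout.
    std::cout << "default strides:   ";
    for(auto s : default_strides(lens))
        std::cout << s << " "; // prints "3 1"
    std::cout << "\n";
}
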
@@ -543,6 +543,17 @@ struct find_inner_broadcast
             return 3;
         }));
         auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
+        std::vector<shape> broadcast_shapes;
+        std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
+            return broadcast->get_shape();
+        });
+        std::vector<shape> common_shapes;
+        std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
+            return common->get_shape();
+        });
+        if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
+            return i->name() == "broadcast" or i->name() == "multibroadcast";}))
+            return;
         m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };
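
This hunk adds an early-out to find_inner_broadcast: if the shapes of the original broadcasts already match the input shapes of the freshly inserted common op, and every one of those inputs is still a broadcast or multibroadcast, the matcher returns without rewriting, presumably because the replacement would just reproduce the same pattern. A standalone sketch of that guard is below; the Node struct and the rewrite_is_noop name are stand-ins for instruction_ref and the inlined check, not MIGraphX API.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Toy stand-in for an instruction: its operator name and output shape (as lengths).
struct Node
{
    std::string name;
    std::vector<std::size_t> lens;
};

// True when the rewrite would be a no-op: the original broadcasts and the
// common op's inputs produce identical shapes, and every input is still a
// broadcast/multibroadcast.
bool rewrite_is_noop(const std::vector<Node>& broadcasts, const std::vector<Node>& op_inputs)
{
    std::vector<std::vector<std::size_t>> broadcast_shapes;
    std::transform(broadcasts.begin(),
                   broadcasts.end(),
                   std::back_inserter(broadcast_shapes),
                   [](const Node& b) { return b.lens; });

    std::vector<std::vector<std::size_t>> common_shapes;
    std::transform(op_inputs.begin(),
                   op_inputs.end(),
                   std::back_inserter(common_shapes),
                   [](const Node& c) { return c.lens; });

    return broadcast_shapes == common_shapes and
           std::all_of(op_inputs.begin(), op_inputs.end(), [](const Node& i) {
               return i.name == "broadcast" or i.name == "multibroadcast";
           });
}

int main()
{
    std::vector<Node> broadcasts{{"broadcast", {2, 3, 4}}, {"multibroadcast", {2, 3, 4}}};
    std::vector<Node> op_inputs = broadcasts; // the common op ended up with the same broadcast inputs
    std::cout << rewrite_is_noop(broadcasts, op_inputs) << "\n"; // prints "1": skip the replacement
}
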