Commit f5400730 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent e801e2f7
......@@ -565,15 +565,19 @@ struct find_inner_broadcast
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
return broadcast->get_shape();
});
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(broadcast_shapes),
[](auto broadcast) { return broadcast->get_shape(); });
std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
return common->get_shape();
});
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
return i->name() == "broadcast" or i->name() == "multibroadcast";}))
std::transform(op->inputs().begin(),
op->inputs().end(),
std::back_inserter(common_shapes),
[](auto common) { return common->get_shape(); });
if(broadcast_shapes == common_shapes and
std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) {
return i->name() == "broadcast" or i->name() == "multibroadcast";
}))
return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
......
......@@ -43,7 +43,7 @@ struct index
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT
template <class F>
__device__ void global_stride(index_int n, F f) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment