Commit f5400730 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent e801e2f7
...@@ -565,15 +565,19 @@ struct find_inner_broadcast ...@@ -565,15 +565,19 @@ struct find_inner_broadcast
})); }));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs); auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes; std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){ std::transform(broadcasts.begin(),
return broadcast->get_shape(); broadcasts.end(),
}); std::back_inserter(broadcast_shapes),
[](auto broadcast) { return broadcast->get_shape(); });
std::vector<shape> common_shapes; std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){ std::transform(op->inputs().begin(),
return common->get_shape(); op->inputs().end(),
}); std::back_inserter(common_shapes),
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){ [](auto common) { return common->get_shape(); });
return i->name() == "broadcast" or i->name() == "multibroadcast";})) if(broadcast_shapes == common_shapes and
std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) {
return i->name() == "broadcast" or i->name() == "multibroadcast";
}))
return; return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
} }
......
...@@ -43,7 +43,7 @@ struct index ...@@ -43,7 +43,7 @@ struct index
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT __device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT __device__ index_int nlocal() const { return blockDim.x; } // NOLINT
template <class F> template <class F>
__device__ void global_stride(index_int n, F f) const __device__ void global_stride(index_int n, F f) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment