Commit 3bc77083 authored by Khalique Ahmed
Browse files

add condition with just one broadcast

parent fc38213a
...@@ -521,23 +521,26 @@ struct find_inner_broadcast ...@@ -521,23 +521,26 @@ struct find_inner_broadcast
}) < (lens.size() - 1); }) < (lens.size() - 1);
})) }))
return; return;
auto bcast_strides = broadcasts.front()->get_shape().strides().size(); if(broadcasts.size() > 1)
std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++)
{ {
for(auto j = 0; j < broadcasts.size(); j++) auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++)
{ {
if(broadcasts[j]->get_shape().strides()[i] == 0) for(const auto& broadcast : broadcasts)
common_axis[i]++; {
if(broadcast->get_shape().strides()[i] == 0)
common_axis[i]++;
}
} }
// if no common broadcast axis, transformation is not useful
if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) {
return num_common > 1;
}) == common_axis.end())
return;
} }
// if no common broadcast axis, transformation is not useful
if(std::find_if(common_axis.begin(), common_axis.end(), [](auto num_common) {
return num_common > 1;
}) == common_axis.end())
return;
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(), std::transform(broadcasts.begin(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment