Commit 3bc77083 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add condition with just one broadcast

parent fc38213a
...@@ -521,15 +521,17 @@ struct find_inner_broadcast ...@@ -521,15 +521,17 @@ struct find_inner_broadcast
}) < (lens.size() - 1); }) < (lens.size() - 1);
})) }))
return; return;
if(broadcasts.size() > 1)
{
auto bcast_strides = broadcasts.front()->get_shape().strides().size(); auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0); std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast, // go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension // keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++) for(auto i = 0; i < bcast_strides; i++)
{ {
for(auto j = 0; j < broadcasts.size(); j++) for(const auto& broadcast : broadcasts)
{ {
if(broadcasts[j]->get_shape().strides()[i] == 0) if(broadcast->get_shape().strides()[i] == 0)
common_axis[i]++; common_axis[i]++;
} }
} }
...@@ -538,6 +540,7 @@ struct find_inner_broadcast ...@@ -538,6 +540,7 @@ struct find_inner_broadcast
return num_common > 1; return num_common > 1;
}) == common_axis.end()) }) == common_axis.end())
return; return;
}
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(), std::transform(broadcasts.begin(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment