Commit 3bc77083 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

add condition with just one broadcast

parent fc38213a
......@@ -521,15 +521,17 @@ struct find_inner_broadcast
}) < (lens.size() - 1);
}))
return;
if(broadcasts.size() > 1)
{
auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++)
{
for(auto j = 0; j < broadcasts.size(); j++)
for(const auto& broadcast : broadcasts)
{
if(broadcasts[j]->get_shape().strides()[i] == 0)
if(broadcast->get_shape().strides()[i] == 0)
common_axis[i]++;
}
}
......@@ -538,6 +540,7 @@ struct find_inner_broadcast
return num_common > 1;
}) == common_axis.end())
return;
}
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment