Commit 499e7938 authored by Khalique's avatar Khalique
Browse files

add function for axis mask

parent 63410264
......@@ -148,6 +148,21 @@ struct tf_parser
return axes;
}
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
tf_parser()
{
add_generic_op("All", op::identity{});
......@@ -837,8 +852,6 @@ struct tf_parser
uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1;
std::vector<int64_t> begin_axes;
std::vector<int64_t> end_axes;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask"))
......@@ -850,23 +863,8 @@ struct tf_parser
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((begin_mask >> i) & bitwise_compare) == 1)
begin_axes.push_back(1);
else
begin_axes.push_back(0);
}
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to end
if(((end_mask >> i) & bitwise_compare) == 1)
end_axes.push_back(1);
else
end_axes.push_back(0);
}
std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
for(size_t i = 0; i < num_axes; i++)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment