"vscode:/vscode.git/clone" did not exist on "abbfd1dc58bf532a8301c7c0d25caec0e8e82189"
Commit 499e7938 authored by Khalique's avatar Khalique
Browse files

add function for axis mask

parent 63410264
...@@ -148,6 +148,21 @@ struct tf_parser ...@@ -148,6 +148,21 @@ struct tf_parser
return axes; return axes;
} }
std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
{
uint32_t bitwise_compare = 1;
std::vector<int64_t> axes;
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if(((mask >> i) & bitwise_compare) == 1)
axes.push_back(1);
else
axes.push_back(0);
}
return axes;
}
tf_parser() tf_parser()
{ {
add_generic_op("All", op::identity{}); add_generic_op("All", op::identity{});
...@@ -837,8 +852,6 @@ struct tf_parser ...@@ -837,8 +852,6 @@ struct tf_parser
uint32_t end_mask = 0; uint32_t end_mask = 0;
uint32_t shrink_axis_mask = 0; uint32_t shrink_axis_mask = 0;
uint32_t bitwise_compare = 1; uint32_t bitwise_compare = 1;
std::vector<int64_t> begin_axes;
std::vector<int64_t> end_axes;
std::vector<int64_t> squeeze_axes; std::vector<int64_t> squeeze_axes;
if(contains(attributes, "begin_mask")) if(contains(attributes, "begin_mask"))
...@@ -850,23 +863,8 @@ struct tf_parser ...@@ -850,23 +863,8 @@ struct tf_parser
if(contains(attributes, "shrink_axis_mask")) if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i()); shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());
for(size_t i = 0; i < num_axes; i++) std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
{ std::vector<int64_t> end_axes = get_axes_from_mask(num_axes, end_mask);
// the LSB corresponds to axis 0 when determining which axes to begin
if(((begin_mask >> i) & bitwise_compare) == 1)
begin_axes.push_back(1);
else
begin_axes.push_back(0);
}
for(size_t i = 0; i < num_axes; i++)
{
// the LSB corresponds to axis 0 when determining which axes to end
if(((end_mask >> i) & bitwise_compare) == 1)
end_axes.push_back(1);
else
end_axes.push_back(0);
}
for(size_t i = 0; i < num_axes; i++) for(size_t i = 0; i < num_axes; i++)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment