Commit 395ccd2c authored by Khalique's avatar Khalique
Browse files

rewrite parse axis

parent b02a8685
...@@ -65,18 +65,19 @@ struct tf_parser ...@@ -65,18 +65,19 @@ struct tf_parser
template <class T> template <class T>
T parse_axis(const T& dim) const T parse_axis(const T& dim) const
{ {
T new_dim = dim;
if(is_nhwc) if(is_nhwc)
{ {
switch(dim) switch(dim)
{ {
case 0: return 0; case 0: new_dim = 0; break;
case 1: return 2; case 1: new_dim = 2; break;
case 2: return 3; case 2: new_dim = 3; break;
case 3: return 1; case 3: new_dim = 1; break;
default: return T{dim}; default: break;
} }
} }
return T{dim}; return new_dim;
} }
std::vector<int64_t> get_axes(size_t num_axes) const std::vector<int64_t> get_axes(size_t num_axes) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment