Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
499e7938
Commit
499e7938
authored
Aug 16, 2019
by
Khalique
Browse files
add function for axis mask
parent
63410264
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
19 deletions
+17
-19
src/tf/tf.cpp
src/tf/tf.cpp
+17
-19
No files found.
src/tf/tf.cpp
View file @
499e7938
...
@@ -148,6 +148,21 @@ struct tf_parser
...
@@ -148,6 +148,21 @@ struct tf_parser
return
axes
;
return
axes
;
}
}
std
::
vector
<
int64_t
>
get_axes_from_mask
(
const
size_t
num_axes
,
const
uint32_t
mask
)
{
uint32_t
bitwise_compare
=
1
;
std
::
vector
<
int64_t
>
axes
;
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if
(((
mask
>>
i
)
&
bitwise_compare
)
==
1
)
axes
.
push_back
(
1
);
else
axes
.
push_back
(
0
);
}
return
axes
;
}
tf_parser
()
tf_parser
()
{
{
add_generic_op
(
"All"
,
op
::
identity
{});
add_generic_op
(
"All"
,
op
::
identity
{});
...
@@ -837,8 +852,6 @@ struct tf_parser
...
@@ -837,8 +852,6 @@ struct tf_parser
uint32_t
end_mask
=
0
;
uint32_t
end_mask
=
0
;
uint32_t
shrink_axis_mask
=
0
;
uint32_t
shrink_axis_mask
=
0
;
uint32_t
bitwise_compare
=
1
;
uint32_t
bitwise_compare
=
1
;
std
::
vector
<
int64_t
>
begin_axes
;
std
::
vector
<
int64_t
>
end_axes
;
std
::
vector
<
int64_t
>
squeeze_axes
;
std
::
vector
<
int64_t
>
squeeze_axes
;
if
(
contains
(
attributes
,
"begin_mask"
))
if
(
contains
(
attributes
,
"begin_mask"
))
...
@@ -850,23 +863,8 @@ struct tf_parser
...
@@ -850,23 +863,8 @@ struct tf_parser
if
(
contains
(
attributes
,
"shrink_axis_mask"
))
if
(
contains
(
attributes
,
"shrink_axis_mask"
))
shrink_axis_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"shrink_axis_mask"
).
i
());
shrink_axis_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"shrink_axis_mask"
).
i
());
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
std
::
vector
<
int64_t
>
begin_axes
=
get_axes_from_mask
(
num_axes
,
begin_mask
);
{
std
::
vector
<
int64_t
>
end_axes
=
get_axes_from_mask
(
num_axes
,
end_mask
);
// the LSB corresponds to axis 0 when determining which axes to begin
if
(((
begin_mask
>>
i
)
&
bitwise_compare
)
==
1
)
begin_axes
.
push_back
(
1
);
else
begin_axes
.
push_back
(
0
);
}
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
// the LSB corresponds to axis 0 when determining which axes to end
if
(((
end_mask
>>
i
)
&
bitwise_compare
)
==
1
)
end_axes
.
push_back
(
1
);
else
end_axes
.
push_back
(
0
);
}
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment