Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
aa7b76b5
"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ec1ac8c0440202c501df405b1c8e4c5f16dfffbc"
Commit
aa7b76b5
authored
Aug 12, 2019
by
Paul
Browse files
Parse mean as reduce mean instead of pooling
parent
1fe84f2a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
15 deletions
+14
-15
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+5
-1
src/tf/tf.cpp
src/tf/tf.cpp
+9
-14
No files found.
src/include/migraphx/tensor_view.hpp
View file @
aa7b76b5
...
@@ -132,7 +132,11 @@ struct tensor_view
...
@@ -132,7 +132,11 @@ struct tensor_view
return
m_data
+
this
->
size
();
return
m_data
+
this
->
size
();
}
}
std
::
vector
<
T
>
to_vector
()
const
{
return
std
::
vector
<
T
>
(
this
->
begin
(),
this
->
end
());
}
template
<
class
U
=
T
>
std
::
vector
<
U
>
to_vector
()
const
{
return
std
::
vector
<
U
>
(
this
->
begin
(),
this
->
end
());
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
tensor_view
<
T
>&
x
)
{
{
...
...
src/tf/tf.cpp
View file @
aa7b76b5
...
@@ -574,23 +574,18 @@ struct tf_parser
...
@@ -574,23 +574,18 @@ struct tf_parser
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
// check if conditions for GlobalAvgPool are met
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
(),
lens
.
size
());
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
<
int64_t
>
(),
lens
.
size
());
if
(
axes
==
hw_axes
and
lens
.
size
()
==
4
)
{
op
::
pooling
op
{
"average"
};
op
.
lengths
[
0
]
=
lens
[
2
];
op
.
lengths
[
1
]
=
lens
[
3
];
auto
l0
=
prog
.
add_instruction
(
op
,
args
.
front
());
if
(
keep_dims
)
if
(
keep_dims
)
return
l0
;
{
return
prog
.
add_instruction
(
return
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
args
[
0
]);
op
::
squeeze
{
std
::
vector
<
int64_t
>
(
hw_axes
.
begin
(),
hw_axes
.
end
())},
l0
);
}
else
{
auto
ins
=
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
args
[
0
]);
return
prog
.
add_instruction
(
op
::
squeeze
{
axes
},
ins
);
}
}
MIGRAPHX_THROW
(
"MIGraphX does not support mean outside of GlobalAvgPool transformation"
);
}
}
instruction_ref
parse_pack
(
const
std
::
string
&
,
instruction_ref
parse_pack
(
const
std
::
string
&
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment