Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
981cf16f
Unverified
Commit
981cf16f
authored
May 20, 2019
by
Paul Fultz II
Committed by
GitHub
May 20, 2019
Browse files
Merge branch 'develop' into reshape
parents
7b6d6aee
3b284e9b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
8 deletions
+47
-8
src/tf/tf.cpp
src/tf/tf.cpp
+13
-8
test/tf/mean_test.pb
test/tf/mean_test.pb
+0
-0
test/tf/mean_test_nhwc.pb
test/tf/mean_test_nhwc.pb
+0
-0
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+34
-0
No files found.
src/tf/tf.cpp
View file @
981cf16f
...
@@ -53,15 +53,16 @@ struct tf_parser
...
@@ -53,15 +53,16 @@ struct tf_parser
template
<
class
T
>
template
<
class
T
>
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
)
const
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
)
const
{
{
std
::
vector
<
T
>
new_axes
;
if
(
is_nhwc
)
if
(
is_nhwc
)
{
{
std
::
vector
<
T
>
new_axes
;
std
::
transform
(
axes
.
begin
(),
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
axes
.
end
(),
std
::
back_inserter
(
new_axes
),
std
::
back_inserter
(
new_axes
),
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
});
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
});
return
new_axes
;
}
}
return
new_
axes
;
return
axes
;
}
}
// tf stores certain attributes such as strides, dilations, as a 4D input.
// tf stores certain attributes such as strides, dilations, as a 4D input.
...
@@ -428,17 +429,21 @@ struct tf_parser
...
@@ -428,17 +429,21 @@ struct tf_parser
instruction_ref
instruction_ref
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
());
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
());
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
if
(
axes
==
hw_axes
and
keep_dims
)
// check if conditions for GlobalAvgPool are met
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
if
(
axes
==
hw_axes
and
lens
.
size
()
==
4
)
{
{
op
::
pooling
op
{
"average"
};
op
::
pooling
op
{
"average"
};
std
::
vector
<
size_t
>
input_dims
{
args
[
0
]
->
get_shape
().
lens
()};
op
.
lengths
[
0
]
=
lens
[
2
];
op
.
lengths
[
0
]
=
input_dims
[
2
];
op
.
lengths
[
1
]
=
lens
[
3
];
op
.
lengths
[
1
]
=
input_dims
[
3
];
auto
l0
=
prog
.
add_instruction
(
op
,
args
.
front
());
return
prog
.
add_instruction
(
op
,
args
.
front
());
if
(
keep_dims
)
return
l0
;
return
prog
.
add_instruction
(
op
::
squeeze
{
std
::
vector
<
int64_t
>
(
hw_axes
.
begin
(),
hw_axes
.
end
())},
l0
);
}
}
MIGRAPHX_THROW
(
"MIGraphX does not support mean outside of GlobalAvgPool transformation"
);
MIGRAPHX_THROW
(
"MIGraphX does not support mean outside of GlobalAvgPool transformation"
);
}
}
...
...
test/tf/mean_test.pb
0 → 100644
View file @
981cf16f
File added
test/tf/mean_test_nhwc.pb
0 → 100644
View file @
981cf16f
File added
test/tf/tf_test.cpp
View file @
981cf16f
...
@@ -169,6 +169,40 @@ TEST_CASE(matmul_test)
...
@@ -169,6 +169,40 @@ TEST_CASE(matmul_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
mean_test
)
{
migraphx
::
program
p
;
migraphx
::
literal
l
{
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
}},
{
2
,
3
}};
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_literal
(
l
);
p
.
add_literal
(
l
);
migraphx
::
op
::
pooling
op
;
op
.
lengths
=
{
16
,
16
};
auto
l3
=
p
.
add_instruction
(
op
,
l0
);
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
2
,
3
}},
l3
);
p
.
add_instruction
(
op
,
l0
);
auto
prog
=
migraphx
::
parse_tf
(
"mean_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
mean_test_nhwc
)
{
migraphx
::
program
p
;
migraphx
::
literal
l
{
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
}},
{
1
,
2
}};
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_literal
(
l
);
p
.
add_literal
(
l
);
migraphx
::
op
::
pooling
op
;
op
.
lengths
=
{
16
,
16
};
auto
l3
=
p
.
add_instruction
(
op
,
l0
);
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
2
,
3
}},
l3
);
p
.
add_instruction
(
op
,
l0
);
auto
prog
=
migraphx
::
parse_tf
(
"mean_test_nhwc.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
mul_test
)
TEST_CASE
(
mul_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment