Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6336ed52
Commit
6336ed52
authored
Jul 08, 2019
by
Khalique
Browse files
change transpose to false, adjusted tests
parent
f8ec4fa7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
7 deletions
+9
-7
src/tf/tf.cpp
src/tf/tf.cpp
+5
-5
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+4
-2
No files found.
src/tf/tf.cpp
View file @
6336ed52
...
...
@@ -80,7 +80,7 @@ struct tf_parser
}
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
,
const
size_t
&
num_dims
)
const
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
,
const
size_t
num_dims
)
const
{
auto
attrs
=
attributes
.
at
(
s
).
list
().
i
();
std
::
vector
<
size_t
>
axes
;
...
...
@@ -95,7 +95,7 @@ struct tf_parser
}
template
<
class
T
>
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
,
const
size_t
&
num_dims
)
const
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
,
const
size_t
num_dims
)
const
{
if
(
is_nhwc
)
{
...
...
@@ -125,7 +125,7 @@ struct tf_parser
}
template
<
class
T
>
T
parse_axis
(
const
T
&
dim
,
const
size_t
&
num_dims
)
const
T
parse_axis
(
const
T
&
dim
,
const
size_t
num_dims
)
const
{
T
new_dim
=
dim
;
if
(
is_nhwc
and
num_dims
>=
4
)
...
...
@@ -166,7 +166,7 @@ struct tf_parser
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"ExpandDims"
,
&
tf_parser
::
parse_expanddims
);
add_mem_op
(
"ExpandDims"
,
&
tf_parser
::
parse_expanddims
,
false
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
...
...
@@ -498,7 +498,7 @@ struct tf_parser
std
::
vector
<
size_t
>
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
new_dims
(
input_dims
.
begin
(),
input_dims
.
end
());
size_t
num_dims
=
input_dims
.
size
();
int32_t
dim
=
parse_axis
(
args
[
1
]
->
eval
().
at
<
int32_t
>
()
,
num_dims
)
;
int32_t
dim
=
args
[
1
]
->
eval
().
at
<
int32_t
>
();
if
(
dim
<
0
)
{
...
...
test/tf/tf_test.cpp
View file @
6336ed52
...
...
@@ -164,8 +164,9 @@ TEST_CASE(expanddims_test)
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}});
p
.
add_literal
(
0
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
2
,
3
,
4
}},
l0
);
auto
prog
=
optimize_tf
(
"expanddims_test.pb"
,
tru
e
);
auto
prog
=
optimize_tf
(
"expanddims_test.pb"
,
fals
e
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -176,8 +177,9 @@ TEST_CASE(expanddims_test_neg_dims)
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}});
p
.
add_literal
(
-
1
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
2
,
3
,
4
,
1
}},
l0
);
auto
prog
=
optimize_tf
(
"expanddims_neg_test.pb"
,
tru
e
);
auto
prog
=
optimize_tf
(
"expanddims_neg_test.pb"
,
fals
e
);
EXPECT
(
p
==
prog
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment