Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
68d5b22b
Commit
68d5b22b
authored
Jun 25, 2019
by
Khalique
Browse files
add expanddims plus tests
parent
15eb1987
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
19 deletions
+50
-19
src/tf/tf.cpp
src/tf/tf.cpp
+36
-19
test/tf/expanddims_test.pb
test/tf/expanddims_test.pb
+0
-0
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+14
-0
No files found.
src/tf/tf.cpp
View file @
68d5b22b
...
@@ -37,7 +37,7 @@ struct tf_parser
...
@@ -37,7 +37,7 @@ struct tf_parser
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
)
const
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
,
const
size_t
&
num_dims
)
const
{
{
auto
attrs
=
attributes
.
at
(
s
).
list
().
i
();
auto
attrs
=
attributes
.
at
(
s
).
list
().
i
();
std
::
vector
<
size_t
>
axes
;
std
::
vector
<
size_t
>
axes
;
...
@@ -45,14 +45,14 @@ struct tf_parser
...
@@ -45,14 +45,14 @@ struct tf_parser
if
(
is_nhwc
)
if
(
is_nhwc
)
{
{
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
axes
.
begin
(),
[
&
](
size_t
axis
)
{
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
axes
.
begin
(),
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
return
parse_axis
(
axis
,
num_dims
);
});
});
}
}
return
axes
;
return
axes
;
}
}
template
<
class
T
>
template
<
class
T
>
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
)
const
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
,
const
size_t
&
num_dims
)
const
{
{
if
(
is_nhwc
)
if
(
is_nhwc
)
{
{
...
@@ -60,7 +60,7 @@ struct tf_parser
...
@@ -60,7 +60,7 @@ struct tf_parser
std
::
transform
(
axes
.
begin
(),
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
axes
.
end
(),
std
::
back_inserter
(
new_axes
),
std
::
back_inserter
(
new_axes
),
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
});
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
,
num_dims
);
});
return
new_axes
;
return
new_axes
;
}
}
return
axes
;
return
axes
;
...
@@ -75,17 +75,17 @@ struct tf_parser
...
@@ -75,17 +75,17 @@ struct tf_parser
std
::
vector
<
T
>
new_data
(
prev_data
.
size
());
std
::
vector
<
T
>
new_data
(
prev_data
.
size
());
for
(
size_t
i
=
0
;
i
<
new_data
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
new_data
.
size
();
i
++
)
{
{
auto
new_idx
=
parse_axis
(
i
);
auto
new_idx
=
parse_axis
(
i
,
new_data
.
size
()
);
new_data
.
at
(
new_idx
)
=
prev_data
.
at
(
i
);
new_data
.
at
(
new_idx
)
=
prev_data
.
at
(
i
);
}
}
prev_data
=
new_data
;
prev_data
=
new_data
;
}
}
template
<
class
T
>
template
<
class
T
>
T
parse_axis
(
const
T
&
dim
)
const
T
parse_axis
(
const
T
&
dim
,
const
size_t
&
num_dims
)
const
{
{
T
new_dim
=
dim
;
T
new_dim
=
dim
;
if
(
is_nhwc
)
if
(
is_nhwc
and
num_dims
>=
4
)
{
{
switch
(
dim
)
switch
(
dim
)
{
{
...
@@ -121,6 +121,7 @@ struct tf_parser
...
@@ -121,6 +121,7 @@ struct tf_parser
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"ExpandDims"
,
&
tf_parser
::
parse_expanddims
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
...
@@ -251,7 +252,7 @@ struct tf_parser
...
@@ -251,7 +252,7 @@ struct tf_parser
{
{
// get index for axis within args
// get index for axis within args
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
size_t
axis
=
parse_axis
(
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
());
size_t
axis
=
parse_axis
(
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
()
,
args
[
0
]
->
get_shape
().
lens
().
size
()
);
op
::
concat
op
{
axis
};
op
::
concat
op
{
axis
};
// return only first N arguments (assuming last index is the axis value)
// return only first N arguments (assuming last index is the axis value)
return
prog
.
add_instruction
(
return
prog
.
add_instruction
(
...
@@ -470,6 +471,24 @@ struct tf_parser
...
@@ -470,6 +471,24 @@ struct tf_parser
return
prog
.
add_instruction
(
op
,
{
l0
,
new_weights
});
return
prog
.
add_instruction
(
op
,
{
l0
,
new_weights
});
}
}
instruction_ref
parse_expanddims
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
std
::
vector
<
size_t
>
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
new_dims
(
input_dims
.
begin
(),
input_dims
.
end
());
size_t
num_dims
=
input_dims
.
size
();
int32_t
dim
=
parse_axis
(
args
[
1
]
->
eval
().
at
<
int32_t
>
(),
num_dims
);
if
(
dim
<
0
)
{
new_dims
.
insert
(
new_dims
.
begin
()
+
(
num_dims
+
dim
+
1
),
1
);
}
else
{
new_dims
.
insert
(
new_dims
.
begin
()
+
dim
,
1
);
}
return
prog
.
add_instruction
(
op
::
reshape
{
new_dims
},
args
[
0
]);
}
instruction_ref
instruction_ref
parse_matmul
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_matmul
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
...
@@ -499,11 +518,12 @@ struct tf_parser
...
@@ -499,11 +518,12 @@ struct tf_parser
instruction_ref
instruction_ref
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
());
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
// check if conditions for GlobalAvgPool are met
// check if conditions for GlobalAvgPool are met
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
(),
lens
.
size
());
if
(
axes
==
hw_axes
and
lens
.
size
()
==
4
)
if
(
axes
==
hw_axes
and
lens
.
size
()
==
4
)
{
{
op
::
pooling
op
{
"average"
};
op
::
pooling
op
{
"average"
};
...
@@ -534,8 +554,7 @@ struct tf_parser
...
@@ -534,8 +554,7 @@ struct tf_parser
" must be smaller than input size "
+
to_string
(
input_size
));
" must be smaller than input size "
+
to_string
(
input_size
));
}
}
// check if input arg needs axis to be converted to NCHW
// check if input arg needs axis to be converted to NCHW
if
(
input_size
>=
4
)
axis
=
parse_axis
(
axis
,
input_size
);
axis
=
parse_axis
(
axis
);
std
::
transform
(
std
::
transform
(
args
.
begin
(),
args
.
begin
(),
...
@@ -676,14 +695,15 @@ struct tf_parser
...
@@ -676,14 +695,15 @@ struct tf_parser
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
{
{
op
::
squeeze
op
;
op
::
squeeze
op
;
auto
axes
=
parse_axes
(
attributes
,
"squeeze_dims"
);
auto
input_dims
=
args
[
0
]
->
get_shape
().
lens
();
auto
axes
=
parse_axes
(
attributes
,
"squeeze_dims"
,
input_dims
.
size
());
copy
(
axes
,
std
::
back_inserter
(
op
.
axes
));
copy
(
axes
,
std
::
back_inserter
(
op
.
axes
));
auto
args0_dims
=
args
[
0
]
->
get_shape
().
lens
();
if
(
op
.
axes
.
empty
())
// no squeeze_dims provided, remove any dim that equals 1
if
(
op
.
axes
.
empty
())
// no squeeze_dims provided, remove any dim that equals 1
{
{
for
(
size_t
i
=
0
;
i
<
args0
_dims
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
input
_dims
.
size
();
i
++
)
{
{
if
(
args0
_dims
.
at
(
i
)
==
1
)
if
(
input
_dims
.
at
(
i
)
==
1
)
{
{
op
.
axes
.
push_back
(
i
);
op
.
axes
.
push_back
(
i
);
}
}
...
@@ -723,10 +743,7 @@ struct tf_parser
...
@@ -723,10 +743,7 @@ struct tf_parser
if
(((
shrink_axis_mask
>>
i
)
&
bitwise_compare
)
==
1
)
if
(((
shrink_axis_mask
>>
i
)
&
bitwise_compare
)
==
1
)
squeeze_axes
.
push_back
(
i
);
squeeze_axes
.
push_back
(
i
);
}
}
if
(
num_axes
>=
4
)
squeeze_axes
=
parse_axes
(
squeeze_axes
,
num_axes
);
{
squeeze_axes
=
parse_axes
(
squeeze_axes
);
}
auto
l0
=
prog
.
add_instruction
(
op
,
args
[
0
]);
auto
l0
=
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
);
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
);
...
...
test/tf/expanddims_test.pb
0 → 100644
View file @
68d5b22b
File added
test/tf/tf_test.cpp
View file @
68d5b22b
...
@@ -146,6 +146,20 @@ TEST_CASE(depthwiseconv_test)
...
@@ -146,6 +146,20 @@ TEST_CASE(depthwiseconv_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
expanddims_test
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}});
p
.
add_literal
(
-
1
);
p
.
add_literal
(
0
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
2
,
3
,
4
,
1
}},
l0
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
2
,
3
,
4
}},
l0
);
auto
prog
=
migraphx
::
parse_tf
(
"expanddims_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
identity_test
)
TEST_CASE
(
identity_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment