Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ee39cf0c
Commit
ee39cf0c
authored
Feb 28, 2019
by
Khalique
Browse files
added pad and mean op
parent
ff009f50
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
0 deletions
+61
-0
src/tf/tf.cpp
src/tf/tf.cpp
+61
-0
No files found.
src/tf/tf.cpp
View file @
ee39cf0c
...
@@ -50,6 +50,19 @@ struct tf_parser
...
@@ -50,6 +50,19 @@ struct tf_parser
return
axes
;
return
axes
;
}
}
template
<
class
T
>
std
::
vector
<
T
>
parse_axes
(
std
::
vector
<
T
>
axes
)
const
{
std
::
vector
<
T
>
new_axes
;
if
(
is_nhwc
)
{
std
::
transform
(
axes
.
begin
(),
axes
.
end
(),
std
::
back_inserter
(
new_axes
),
[
&
](
size_t
axis
)
{
return
parse_axis
(
axis
);
});
}
return
new_axes
;
}
// tf stores certain attributes such as strides, dilations, as a 4D input.
// tf stores certain attributes such as strides, dilations, as a 4D input.
// The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
// The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
// This helper function reorders the data to store for the respective operator member variables.
// This helper function reorders the data to store for the respective operator member variables.
...
@@ -104,6 +117,8 @@ struct tf_parser
...
@@ -104,6 +117,8 @@ struct tf_parser
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
);
...
@@ -319,6 +334,52 @@ struct tf_parser
...
@@ -319,6 +334,52 @@ struct tf_parser
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
weights
});
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
weights
});
}
}
instruction_ref
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
());
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
if
(
axes
==
hw_axes
and
keep_dims
)
{
op
::
pooling
op
{
"average"
};
std
::
vector
<
size_t
>
input_dims
{
args
[
0
]
->
get_shape
().
lens
()};
op
.
lengths
[
0
]
=
input_dims
[
2
];
op
.
lengths
[
1
]
=
input_dims
[
3
];
return
prog
.
add_instruction
(
op
,
args
.
front
());
}
MIGRAPHX_THROW
(
"MIGraphX does not support mean outside of GlobalAvgPool transformation"
);
}
instruction_ref
parse_pad
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
size_t
ndims
=
args
.
front
()
->
get_shape
().
lens
().
size
();
// in tf, the paddings are arranged as a 2d shape (ndims, 2),
// the last dim contains the left padding and right padding respectively
std
::
vector
<
std
::
pair
<
int32_t
,
int32_t
>>
pad_per_dim
(
ndims
);
auto
tf_padding
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
for
(
size_t
i
=
0
;
i
<
2
*
ndims
;
i
+=
2
)
{
pad_per_dim
[
i
/
2
].
first
=
tf_padding
[
i
];
pad_per_dim
[
i
/
2
].
second
=
tf_padding
[
i
+
1
];
}
reorder_data
(
pad_per_dim
);
op
::
pad
op
;
std
::
vector
<
int64_t
>
pads
(
ndims
*
2
);
for
(
size_t
i
=
0
;
i
<
ndims
;
i
++
)
{
pads
[
i
]
=
pad_per_dim
[
i
].
first
;
pads
[
i
+
ndims
]
=
pad_per_dim
[
i
].
second
;
}
op
.
pads
=
pads
;
return
prog
.
add_instruction
(
op
,
args
.
front
());
}
instruction_ref
parse_pooling
(
const
std
::
string
&
name
,
instruction_ref
parse_pooling
(
const
std
::
string
&
name
,
attribute_map
attributes
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment