Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
86e8a8e6
Commit
86e8a8e6
authored
May 10, 2019
by
Khalique
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into parse_mean_fix
parents
6c27b962
b93f5320
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
100 additions
and
0 deletions
+100
-0
src/tf/tf.cpp
src/tf/tf.cpp
+58
-0
test/tf/depthwise_conv_test.pb
test/tf/depthwise_conv_test.pb
+0
-0
test/tf/relu6_test.pb
test/tf/relu6_test.pb
+8
-0
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+34
-0
No files found.
src/tf/tf.cpp
View file @
86e8a8e6
...
...
@@ -108,6 +108,7 @@ struct tf_parser
{
add_generic_op
(
"Identity"
,
op
::
identity
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Relu6"
,
op
::
clip
{
6.0
,
0.0
});
add_binary_op
(
"Add"
,
op
::
add
{});
add_binary_op
(
"Mul"
,
op
::
mul
{});
...
...
@@ -117,6 +118,7 @@ struct tf_parser
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
);
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
...
...
@@ -339,6 +341,62 @@ struct tf_parser
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
weights
});
}
instruction_ref
parse_depthwiseconv
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
op
::
convolution
op
;
size_t
num_channels
=
args
[
0
]
->
get_shape
().
lens
()[
1
];
op
.
group
=
num_channels
;
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
if
(
pad_mode
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
padding_mode_t
::
same
;
}
}
if
(
contains
(
attributes
,
"strides"
))
{
std
::
vector
<
size_t
>
stride
;
copy
(
attributes
.
at
(
"strides"
).
list
().
i
(),
std
::
back_inserter
(
stride
));
reorder_data
(
stride
);
if
(
stride
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"strides should have 4 values"
);
}
op
.
stride
[
0
]
=
stride
[
2
];
op
.
stride
[
1
]
=
stride
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param"
)
{
if
(
is_nhwc
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]);
}
else
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
}
}
std
::
vector
<
int64_t
>
new_weights_shape
;
copy
(
weights
->
get_shape
().
lens
(),
std
::
back_inserter
(
new_weights_shape
));
// weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
// out_channels is equal to the multiplier. Adjust by inserting a reshape and
// setting in_channels to 1
int64_t
multiplier
=
new_weights_shape
[
0
];
int64_t
out_channels
=
num_channels
*
multiplier
;
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
1
]
=
1
;
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
weights
);
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
new_weights
});
}
instruction_ref
parse_matmul
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
test/tf/depthwise_conv_test.pb
0 → 100644
View file @
86e8a8e6
File added
test/tf/relu6_test.pb
0 → 100644
View file @
86e8a8e6
:
0Placeholder*
dtype0*
shape:
relu6Relu60*
T0"
\ No newline at end of file
test/tf/tf_test.cpp
View file @
86e8a8e6
...
...
@@ -119,6 +119,30 @@ TEST_CASE(conv_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
depthwiseconv_test
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
std
::
vector
<
float
>
weight_data
(
3
*
3
*
3
*
1
);
std
::
fill
(
weight_data
.
begin
(),
weight_data
.
end
(),
1.0
f
);
auto
l1
=
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
3
,
3
,
1
}},
weight_data
);
migraphx
::
op
::
convolution
op
;
op
.
padding_mode
=
migraphx
::
op
::
padding_mode_t
::
same
;
op
.
stride
=
{
1
,
1
};
op
.
dilation
=
{
1
,
1
};
op
.
group
=
3
;
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l3
);
p
.
add_instruction
(
op
,
l0
,
l4
);
auto
prog
=
migraphx
::
parse_tf
(
"depthwise_conv_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
identity_test
)
{
migraphx
::
program
p
;
...
...
@@ -229,6 +253,16 @@ TEST_CASE(relu_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
relu6_test
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_instruction
(
migraphx
::
op
::
clip
{
6.0
,
0.0
},
l0
);
auto
prog
=
migraphx
::
parse_tf
(
"relu6_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
reshape_test
)
{
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment