Unverified Commit 8640d392 authored by kahmed10, committed by GitHub

fix tf fusion regressions (#755)

* fix relu6
* add more transposes
* add tests
* formatting

Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent c42452e5
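The parser-side fix is small: each of the three TF op parsers below gains a `bool transpose() const` override returning `true`, which opts the op into the TF frontend's NHWC-to-NCHW layout handling that these ops lost in the fusion regression. A minimal standalone sketch of the opt-in pattern follows; the `op_parser_base` class and virtual dispatch are illustrative stand-ins for MIGraphX's CRTP `op_parser`, not the real internals:

// Illustrative sketch only: MIGraphX's real op_parser is a CRTP template,
// but the opt-in works the same way as this virtual-dispatch stand-in.
#include <iostream>
#include <string>
#include <vector>

struct op_parser_base
{
    virtual ~op_parser_base() = default;
    // Default: no NHWC<->NCHW transposes are inserted around the op.
    virtual bool transpose() const { return false; }
    virtual std::vector<std::string> operators() const = 0;
};

struct parse_relu6_sketch : op_parser_base
{
    // The fix: opt in, so TF's NHWC tensors are transposed to NCHW
    // before parsing and back afterwards.
    bool transpose() const override { return true; }
    std::vector<std::string> operators() const override { return {"Relu6"}; }
};

int main()
{
    parse_relu6_sketch p;
    std::cout << std::boolalpha << "needs layout transposes: " << p.transpose() << '\n';
}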
@@ -9,6 +9,7 @@ namespace tf {
 struct parse_binary_op : op_parser<parse_binary_op>
 {
+    bool transpose() const { return true; }
     std::vector<op_desc> operators() const
     {
         return {{"Add", "add"},
...
@@ -9,6 +9,7 @@ namespace tf {
 struct parse_generic_op : op_parser<parse_generic_op>
 {
+    bool transpose() const { return true; }
     std::vector<op_desc> operators() const
     {
         return {{"All", "identity"},
...
@@ -10,6 +10,7 @@ namespace tf {
 struct parse_relu6 : op_parser<parse_relu6>
 {
+    bool transpose() const { return true; }
     std::vector<op_desc> operators() const { return {{"Relu6"}}; }
     instruction_ref parse(const op_desc& /*opd*/,
...
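For the relu6 half of the fix, the expected lowering is visible in the new C++ test later in this diff: Relu6 becomes clip(x, 0, 6), with the scalar bounds multibroadcast to the input shape. Below is a minimal sketch of building that graph with the module API the tests use; the includes and the added parameter are assumptions for self-containedness, while the instruction sequence mirrors conv_relu6_test:

// Sketch: relu6(x) lowered as clip(x, 0, 6), matching the graph that
// conv_relu6_test expects. Shape is hardcoded for illustration.
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <vector>

migraphx::program build_relu6()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    std::vector<std::size_t> lens{1, 32, 16, 16};
    // Hypothetical input parameter standing in for the conv output in the test.
    auto x       = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, lens});
    auto min_val = mm->add_literal(0.0f);
    auto max_val = mm->add_literal(6.0f);
    // Broadcast the scalar bounds to the input shape, then clip elementwise.
    min_val = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", lens}}), min_val);
    max_val = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"output_lens", lens}}), max_val);
    mm->add_instruction(migraphx::make_op("clip"), x, min_val, max_val);
    return p;
}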
@@ -196,6 +196,23 @@ def conv_test(g1):
         tf.nn.conv2d(g1_input, g1_weights, [1, 1, 1, 1], "SAME", name='conv1')
 
 
+@tf_test
+def conv_add_test(g1):
+    with g1.as_default():
+        g1_input = tf.compat.v1.placeholder(tf.float32,
+                                            shape=(1, 16, 16, 3),
+                                            name='0')
+        g1_weights = tf.constant(value=1.0,
+                                 dtype=tf.float32,
+                                 shape=(3, 3, 3, 32),
+                                 name='1')
+        conv = tf.nn.conv2d(g1_input,
+                            g1_weights, [1, 1, 1, 1],
+                            "SAME",
+                            name='conv1')
+        tf.add(conv, conv, name='add1')
+
+
 @tf_test
 def conv_nchw_test(g1):
     with g1.as_default():
@@ -213,6 +230,40 @@ def conv_nchw_test(g1):
                                name='conv1')
 
 
+@tf_test
+def conv_relu_test(g1):
+    with g1.as_default():
+        g1_input = tf.compat.v1.placeholder(tf.float32,
+                                            shape=(1, 16, 16, 3),
+                                            name='0')
+        g1_weights = tf.constant(value=1.0,
+                                 dtype=tf.float32,
+                                 shape=(3, 3, 3, 32),
+                                 name='1')
+        conv = tf.nn.conv2d(g1_input,
+                            g1_weights, [1, 1, 1, 1],
+                            "SAME",
+                            name='conv1')
+        tf.nn.relu(conv, name='relu1')
+
+
+@tf_test
+def conv_relu6_test(g1):
+    with g1.as_default():
+        g1_input = tf.compat.v1.placeholder(tf.float32,
+                                            shape=(1, 16, 16, 3),
+                                            name='0')
+        g1_weights = tf.constant(value=1.0,
+                                 dtype=tf.float32,
+                                 shape=(3, 3, 3, 32),
+                                 name='1')
+        conv = tf.nn.conv2d(g1_input,
+                            g1_weights, [1, 1, 1, 1],
+                            "SAME",
+                            name='conv1')
+        tf.nn.relu6(conv, name='relu1')
+
+
 @tf_test
 def depthwiseconv_test(g1):
     with g1.as_default():
@@ -582,7 +633,10 @@ if __name__ == '__main__':
     concat_test()
     const_test()
     conv_test()
+    conv_add_test()
     conv_nchw_test()
+    conv_relu_test()
+    conv_relu6_test()
     depthwiseconv_test()
     expanddims_test()
     gather_test()
...
@@ -303,6 +303,17 @@ TEST_CASE(conv_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(conv_add_test)
+{
+    migraphx::program p = create_conv();
+    auto* mm = p.get_main_module();
+    auto l0  = std::prev(mm->end());
+    mm->add_instruction(migraphx::make_op("add"), l0, l0);
+    auto prog = optimize_tf("conv_add_test.pb", true);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(conv_nchw_test)
 {
     migraphx::program p = create_conv();
@@ -311,6 +322,35 @@ TEST_CASE(conv_nchw_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(conv_relu_test)
+{
+    migraphx::program p = create_conv();
+    auto* mm = p.get_main_module();
+    auto l0  = std::prev(mm->end());
+    mm->add_instruction(migraphx::make_op("relu"), l0);
+    auto prog = optimize_tf("conv_relu_test.pb", true);
+    EXPECT(p == prog);
+}
+
+TEST_CASE(conv_relu6_test)
+{
+    migraphx::program p = create_conv();
+    auto* mm = p.get_main_module();
+    std::vector<size_t> input_lens{1, 32, 16, 16};
+    auto l0      = std::prev(mm->end());
+    auto min_val = mm->add_literal(0.0f);
+    auto max_val = mm->add_literal(6.0f);
+    min_val      = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
+    max_val = mm->add_instruction(
+        migraphx::make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
+    mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
+    auto prog = optimize_tf("conv_relu6_test.pb", true);
+    EXPECT(p == prog);
+}
+
 TEST_CASE(depthwiseconv_test)
 {
     migraphx::program p;
...
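Each TEST_CASE builds the expected graph by hand and compares it, via operator==, against the program the tests' optimize_tf helper returns for the generated .pb file. A hedged sketch of how such a helper reaches the public API follows; the only assumption is that the helper's boolean argument maps to tf_options::is_nhwc, and any extra optimization passes the real helper runs are omitted:

// Sketch: parse a frozen TF protobuf the way the tf_test helper plausibly
// does. parse_tf and tf_options are the real public entry points.
#include <migraphx/tf.hpp>
#include <string>

migraphx::program parse_pb(const std::string& name, bool is_nhwc)
{
    migraphx::tf_options options;
    options.is_nhwc = is_nhwc; // true: inputs use TF's NHWC layout
    return migraphx::parse_tf(name, options);
}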