Unverified Commit f2667056 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Fix TF literal parsing for relu6 (#1370)

Fixes TF literal parsing for relu6.  previously always made a float type literal, breaks for float16 as an example
parent 60aa0e48
...@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6> ...@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6>
const tf_parser::node_info& info, const tf_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto min_val = info.add_literal(0.0f); shape::type_t output_type = args[0]->get_shape().type();
auto max_val = info.add_literal(6.0f); auto min_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {0.0f}});
auto max_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {6.0f}});
return info.add_common_op("clip", args[0], min_val, max_val); return info.add_common_op("clip", args[0], min_val, max_val);
} }
......
...@@ -495,10 +495,10 @@ def relu6_test(g1): ...@@ -495,10 +495,10 @@ def relu6_test(g1):
@tf_test @tf_test
def relu6_mismatch_test(g1): def relu6_half_test(g1):
with g1.as_default(): with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float16, g1_input = tf.compat.v1.placeholder(tf.float16,
shape=(1, 3, 13, 37), shape=(1, 3, 16, 16),
name='0') name='0')
tf.nn.relu6(g1_input, 'relu6') tf.nn.relu6(g1_input, 'relu6')
...@@ -708,7 +708,7 @@ if __name__ == '__main__': ...@@ -708,7 +708,7 @@ if __name__ == '__main__':
pow_test() pow_test()
relu_test() relu_test()
relu6_test() relu6_test()
relu6_mismatch_test() relu6_half_test()
reshape_test() reshape_test()
rsqrt_test() rsqrt_test()
shape_test() shape_test()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
: :
0 Placeholder* 0 Placeholder*
dtype0* dtype0*
shape: % shape:
 
relu6Relu60* relu6Relu60*
T0" T0"
\ No newline at end of file
...@@ -729,27 +729,23 @@ TEST_CASE(relu6_test) ...@@ -729,27 +729,23 @@ TEST_CASE(relu6_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(relu6_mismatch_test) TEST_CASE(relu6_half_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 3, 13, 37}; std::vector<size_t> input_lens{1, 3, 16, 16};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens});
auto min_val = mm->add_literal(0.0f); auto min_val =
auto max_val = mm->add_literal(6.0f); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.0f}});
auto max_val =
auto l0_convert = mm->add_instruction( mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {6.0f}});
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
min_val); min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
mm->add_instruction(migraphx::make_op("clip"), l0_convert, min_val, max_val); auto prog = optimize_tf("relu6_half_test.pb", false);
auto prog = optimize_tf("relu6_mismatch_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment