Unverified Commit 421a5621 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Update tf_parser to have add_common_op() for parse_relu6 (#1241)



* [#935] Update tf_parser to have add_common_op() for parse_relu6

Similar to that of the onnx_parser.cpp add a add_common_op template and functionality to support clip based operations. This is done so clip operations can be guarenteed to have the same dimensions.

* fixup! [#935] Update tf_parser to have add_common_op() for parse_relu6

* fixup! fixup! [#935] Update tf_parser to have add_common_op() for parse_relu6

* fixup! fixup! fixup! [#935] Update tf_parser to have add_common_op() for parse_relu6

* fixup! fixup! fixup! fixup! [#935] Update tf_parser to have add_common_op() for parse_relu6

* Formatting

* fixup! Formatting
Co-authored-by: default avatarUmang Yadav <29876643+umangyadav@users.noreply.github.com>
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent add6fb3b
......@@ -33,6 +33,16 @@ struct tf_parser
instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const;
......
......@@ -18,15 +18,10 @@ struct parse_relu6 : op_parser<parse_relu6>
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_lens = args[0]->get_shape().lens();
auto min_val = info.add_literal(0.0f);
auto max_val = info.add_literal(6.0f);
min_val =
info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), min_val);
max_val =
info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}), max_val);
return info.add_instruction(make_op("clip"), args.front(), min_val, max_val);
return info.add_common_op("clip", args[0], min_val, max_val);
}
};
......
......@@ -79,7 +79,13 @@ instruction_ref tf_parser::node_info::add_broadcastable_binary_op(const std::str
instruction_ref arg0,
instruction_ref arg1) const
{
return add_common_op(*mm, make_op(op_name), {arg0, arg1});
return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref tf_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mm, make_op(op_name), std::move(inputs));
}
int64_t tf_parser::parse_axis(const int64_t dim, const size_t num_dims) const
......
......@@ -471,6 +471,15 @@ def relu6_test(g1):
tf.nn.relu6(g1_input, 'relu6')
@tf_test
def relu6_mismatch_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float16,
shape=(1, 3, 13, 37),
name='0')
tf.nn.relu6(g1_input, 'relu6')
@tf_test
def reshape_test(g1):
with g1.as_default():
......@@ -676,6 +685,7 @@ if __name__ == '__main__':
pow_test()
relu_test()
relu6_test()
relu6_mismatch_test()
reshape_test()
rsqrt_test()
shape_test()
......
:
0 Placeholder*
dtype0*
shape: %

relu6Relu60*
T0"
\ No newline at end of file
......@@ -706,6 +706,31 @@ TEST_CASE(relu6_test)
EXPECT(p == prog);
}
TEST_CASE(relu6_mismatch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 3, 13, 37};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
auto l0_convert = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
max_val);
mm->add_instruction(migraphx::make_op("clip"), l0_convert, min_val, max_val);
auto prog = optimize_tf("relu6_mismatch_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(reshape_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment