Commit b2efe895 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into nonstd_NMS

parents e4759983 c0398ded
...@@ -471,6 +471,15 @@ def relu6_test(g1): ...@@ -471,6 +471,15 @@ def relu6_test(g1):
tf.nn.relu6(g1_input, 'relu6') tf.nn.relu6(g1_input, 'relu6')
@tf_test
def relu6_mismatch_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float16,
shape=(1, 3, 13, 37),
name='0')
tf.nn.relu6(g1_input, 'relu6')
@tf_test @tf_test
def reshape_test(g1): def reshape_test(g1):
with g1.as_default(): with g1.as_default():
...@@ -676,6 +685,7 @@ if __name__ == '__main__': ...@@ -676,6 +685,7 @@ if __name__ == '__main__':
pow_test() pow_test()
relu_test() relu_test()
relu6_test() relu6_test()
relu6_mismatch_test()
reshape_test() reshape_test()
rsqrt_test() rsqrt_test()
shape_test() shape_test()
......
:
0 Placeholder*
dtype0*
shape: %

relu6Relu60*
T0"
\ No newline at end of file
...@@ -706,6 +706,31 @@ TEST_CASE(relu6_test) ...@@ -706,6 +706,31 @@ TEST_CASE(relu6_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(relu6_mismatch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 3, 13, 37};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
auto l0_convert = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
max_val);
mm->add_instruction(migraphx::make_op("clip"), l0_convert, min_val, max_val);
auto prog = optimize_tf("relu6_mismatch_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(reshape_test) TEST_CASE(reshape_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -28,6 +28,8 @@ struct allocation_model ...@@ -28,6 +28,8 @@ struct allocation_model
operation allocate(const shape& s) const; operation allocate(const shape& s) const;
/// Create a preallocated operator for the given shape /// Create a preallocated operator for the given shape
operation preallocate(const shape& s, const std::string& id) const; operation preallocate(const shape& s, const std::string& id) const;
/// Check if outputs are to be inserted
bool needs_out_params() const;
}; };
#else #else
...@@ -37,7 +39,8 @@ interface('allocation_model', ...@@ -37,7 +39,8 @@ interface('allocation_model',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('copy', returns='std::string', const=True), virtual('copy', returns='std::string', const=True),
virtual('allocate', s='const shape&', returns='operation', const=True), virtual('allocate', s='const shape&', returns='operation', const=True),
virtual('preallocate', s='const shape&', id='std::string', returns='operation', const=True) virtual('preallocate', s='const shape&', id='std::string', returns='operation', const=True),
virtual('needs_out_params', returns='bool', const=True)
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment