"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "13dfa914a5632b951cd686fbf18dcbd4652d7c89"
Unverified Commit c78ce73d authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into jit-layernorm-merge

parents 37ddce62 f2667056
...@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[&](const auto& input) -> std::size_t { [&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis]; auto stride = input.strides()[axis];
auto len = input.lens()[axis]; auto len = input.lens()[axis];
if(stride != 0 and stride != 1) if(not contains({0, 1}, stride))
return 1; return 1;
if(len == 1 and input.elements() > sizes.front()) if(len == 1 and input.elements() > sizes.front())
return sizes.front(); return sizes.front();
auto it = std::find_if( auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; }); // The len is divisible by the size and all the strides are divisible by
// the size
return (len % vsize) == 0 and
std::all_of(
input.strides().begin(), input.strides().end(), [&](auto i) {
return contains({0, 1}, i) or i % vsize == 0;
});
});
if(it != sizes.end()) if(it != sizes.end())
return *it; return *it;
return 1; return 1;
......
...@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6> ...@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6>
const tf_parser::node_info& info, const tf_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto min_val = info.add_literal(0.0f); shape::type_t output_type = args[0]->get_shape().type();
auto max_val = info.add_literal(6.0f); auto min_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {0.0f}});
auto max_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {6.0f}});
return info.add_common_op("clip", args[0], min_val, max_val); return info.add_common_op("clip", args[0], min_val, max_val);
} }
......
...@@ -495,10 +495,10 @@ def relu6_test(g1): ...@@ -495,10 +495,10 @@ def relu6_test(g1):
@tf_test @tf_test
def relu6_mismatch_test(g1): def relu6_half_test(g1):
with g1.as_default(): with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float16, g1_input = tf.compat.v1.placeholder(tf.float16,
shape=(1, 3, 13, 37), shape=(1, 3, 16, 16),
name='0') name='0')
tf.nn.relu6(g1_input, 'relu6') tf.nn.relu6(g1_input, 'relu6')
...@@ -708,7 +708,7 @@ if __name__ == '__main__': ...@@ -708,7 +708,7 @@ if __name__ == '__main__':
pow_test() pow_test()
relu_test() relu_test()
relu6_test() relu6_test()
relu6_mismatch_test() relu6_half_test()
reshape_test() reshape_test()
rsqrt_test() rsqrt_test()
shape_test() shape_test()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
: :
0 Placeholder* 0 Placeholder*
dtype0* dtype0*
shape: % shape:
 
relu6Relu60* relu6Relu60*
T0" T0"
\ No newline at end of file
...@@ -729,27 +729,23 @@ TEST_CASE(relu6_test) ...@@ -729,27 +729,23 @@ TEST_CASE(relu6_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(relu6_mismatch_test) TEST_CASE(relu6_half_test)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 3, 13, 37}; std::vector<size_t> input_lens{1, 3, 16, 16};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens});
auto min_val = mm->add_literal(0.0f); auto min_val =
auto max_val = mm->add_literal(6.0f); mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.0f}});
auto max_val =
auto l0_convert = mm->add_instruction( mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {6.0f}});
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
min_val); min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
max_val); max_val);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
mm->add_instruction(migraphx::make_op("clip"), l0_convert, min_val, max_val); auto prog = optimize_tf("relu6_half_test.pb", false);
auto prog = optimize_tf("relu6_mismatch_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_slice2 : verify_program<test_slice2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 44, 57, 57}};
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 44, 57, 57}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 44, 56, 56}});
auto slice0 = mm->add_instruction(
migraphx::make_op(
"slice",
{{"axes", {0, 2, 3, 1}}, {"starts", {0, 1, 1, 0}}, {"ends", {1, 57, 57, 44}}}),
x);
mm->add_instruction(migraphx::make_op("add"), y, slice0);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment