"vscode:/vscode.git/clone" did not exist on "7189e614d7763bad19c95ab0b65afe299121e900"
Unverified Commit c78ce73d authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into jit-layernorm-merge

parents 37ddce62 f2667056
......@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
if(not contains({0, 1}, stride))
return 1;
if(len == 1 and input.elements() > sizes.front())
return sizes.front();
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
// The len is divisible by the size and all the strides are divisible by
// the size
return (len % vsize) == 0 and
std::all_of(
input.strides().begin(), input.strides().end(), [&](auto i) {
return contains({0, 1}, i) or i % vsize == 0;
});
});
if(it != sizes.end())
return *it;
return 1;
......
......@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6>
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto min_val = info.add_literal(0.0f);
auto max_val = info.add_literal(6.0f);
shape::type_t output_type = args[0]->get_shape().type();
auto min_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {0.0f}});
auto max_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {6.0f}});
return info.add_common_op("clip", args[0], min_val, max_val);
}
......
......@@ -495,10 +495,10 @@ def relu6_test(g1):
@tf_test
def relu6_mismatch_test(g1):
def relu6_half_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float16,
shape=(1, 3, 13, 37),
shape=(1, 3, 16, 16),
name='0')
tf.nn.relu6(g1_input, 'relu6')
......@@ -708,7 +708,7 @@ if __name__ == '__main__':
pow_test()
relu_test()
relu6_test()
relu6_mismatch_test()
relu6_half_test()
reshape_test()
rsqrt_test()
shape_test()
......
......@@ -2,7 +2,7 @@
:
0 Placeholder*
dtype0*
shape: %
shape:

relu6Relu60*
T0"
\ No newline at end of file
......@@ -729,27 +729,23 @@ TEST_CASE(relu6_test)
EXPECT(p == prog);
}
TEST_CASE(relu6_mismatch_test)
TEST_CASE(relu6_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 3, 13, 37};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
auto l0_convert = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l0);
std::vector<size_t> input_lens{1, 3, 16, 16};
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::half_type, input_lens});
auto min_val =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.0f}});
auto max_val =
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {6.0f}});
min_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
min_val);
max_val = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
max_val);
mm->add_instruction(migraphx::make_op("clip"), l0_convert, min_val, max_val);
auto prog = optimize_tf("relu6_mismatch_test.pb", false);
mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
auto prog = optimize_tf("relu6_half_test.pb", false);
EXPECT(p == prog);
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_slice2 : verify_program<test_slice2>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 44, 57, 57}};
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 44, 57, 57}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 44, 56, 56}});
auto slice0 = mm->add_instruction(
migraphx::make_op(
"slice",
{{"axes", {0, 2, 3, 1}}, {"starts", {0, 1, 1, 0}}, {"ends", {1, 57, 57, 44}}}),
x);
mm->add_instruction(migraphx::make_op("add"), y, slice0);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment