Unverified Commit 0c98c38e authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents 1612d8f3 64b306ab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(unsqueeze_test_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape() == s2);
}
TEST_CASE(unsqueeze_test_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape() == s2);
}
TEST_CASE(unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 3, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(unsqueeze_transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_trans);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_brcst =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 3, 3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_brcst);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4, 4}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0_slice);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}});
EXPECT(result == expected_result);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(where_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
migraphx::shape sx{migraphx::shape::float_type, {3, 3}};
std::vector<bool> b{true, true, true, false, false, false, true, false, true};
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
auto lb = mm->add_literal(migraphx::literal{sb, b});
auto lx = mm->add_literal(migraphx::literal{sx, x});
auto ly = mm->add_literal(migraphx::literal{sx, y});
auto w = mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
mm->add_return({w});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<float> gold(9);
for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold));
}
TEST_CASE(where_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {{2, 3}, {2, 3}}};
migraphx::shape sx{migraphx::shape::float_type, {{2, 3}, {2, 3}}};
auto lb = mm->add_parameter("predicate", sb);
auto lx = mm->add_parameter("X", sx);
auto ly = mm->add_parameter("Y", sx);
mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
p.compile(migraphx::make_target("ref"));
std::vector<char> b{1, 1, 1, 0, 0, 0, 1, 0, 1};
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3, 3}};
migraphx::shape input_fixed_shape1{migraphx::shape::uint8_type, {3, 3}};
params["X"] = migraphx::argument(input_fixed_shape0, x.data());
params["Y"] = migraphx::argument(input_fixed_shape0, y.data());
params["predicate"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(where_broadcasted_inputs_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
std::vector<bool> b{true, true, true, false, false, false, true, false, true};
auto lb = mm->add_literal(migraphx::literal{sb, b});
auto lx = mm->add_literal(migraphx::literal(1.0f));
auto ly = mm->add_literal(migraphx::literal(2.0f));
auto mbx = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), lx);
auto mby = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), ly);
auto w = mm->add_instruction(migraphx::make_op("where"), lb, mbx, mby);
mm->add_return({w});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<float> gold(9);
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold));
}
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
......@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", outer);
......@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
auto create_program = [&] {
migraphx::module m;
auto x = m.add_parameter("x", outer);
......@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE(simplify_inner_broadcast1)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
......@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE(simplify_inner_broadcast2)
{
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
......@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE(simplify_inner_broadcast_scalar)
{
auto b = migraphx::op::multibroadcast{{32, 384}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
......@@ -603,9 +603,10 @@ TEST_CASE(simplify_inner_broadcast_scalar)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y);
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
......@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto b = migraphx::op::multibroadcast{{2, 384, 768}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
......@@ -629,9 +630,10 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
......@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}};
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}});
auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
......@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::module m1;
{
auto r = migraphx::op::reshape{{3, 2, 4}};
auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
......@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
......@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
......@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -3033,6 +3035,36 @@ void reorder_slice_trans_diff_perm()
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_trans_multi_outputs)
{
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {8, 128, 1920}};
auto input = m1.add_parameter("input", s);
std::vector<int64_t> perm = {0, 2, 1};
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto dot = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
auto slc_cont = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
m1.add_return({slc_cont, dot});
};
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_ins_deps)
{
auto create_module = [] {
......
......@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
......@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_multibroadcasts2)
......@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 0);
}
TEST_CASE(concat_multibroadcasts3)
......@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
}
TEST_CASE(concat_multibroadcasts4)
......@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 3);
}
TEST_CASE(concat_transpose2)
......@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_transpose3)
......@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_transpose4)
......
......@@ -99,4 +99,29 @@ TEST_CASE(interpolate_string_custom3)
EXPECT(s == "****b****");
}
TEST_CASE(slit_string_simple1)
{
std::string input = "one,two,three";
auto resuts = migraphx::split_string(input, ',');
EXPECT(resuts.size() == 3);
EXPECT(resuts.front() == "one");
EXPECT(resuts.back() == "three");
}
TEST_CASE(slit_string_simple2)
{
std::string input = "one";
auto resuts = migraphx::split_string(input, ',');
EXPECT(resuts.size() == 1);
EXPECT(resuts.front() == "one");
}
TEST_CASE(slit_string_simple3)
{
std::string input = "one two three";
auto resuts = migraphx::split_string(input, ',');
EXPECT(resuts.size() == 1);
EXPECT(resuts.front() == "one two three");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/serialize.hpp>
......@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
mm->add_literal(migraphx::literal{s0, {1, 0}});
mm->add_literal(migraphx::literal{s0, {2, -1}});
migraphx::op::slice op;
op.starts = {1, 0};
op.ends = {3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
mm->add_instruction(op, l0);
mm->add_instruction(
migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0);
auto prog = optimize_tf("slice_test.pb", false);
EXPECT(p == prog);
......@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l2 = mm->add_instruction(op, l1);
auto l2 = mm->add_instruction(
migraphx::make_op(
"slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}),
l1);
auto shrink_axis = 1;
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
auto prog = optimize_tf("stridedslice_test.pb", true);
......@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{0, 1, 1, 0});
......@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1);
auto l2 = mm->add_instruction(
migraphx::make_op(
"slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}),
l1);
auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3});
......
......@@ -25,7 +25,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_literal : verify_program<gemm_literal>
{
......@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b);
mm->add_instruction(migraphx::make_op("dot"), a, b);
return p;
}
......
......@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt
# Add newer cmake to the path
export PATH="/opt/cmake/bin:$PATH"
export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w"
./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root
./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --build_wheel --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root
cd build/Linux/Release
#Add test launcher for onnxrt tests
......
......@@ -3,7 +3,7 @@ FROM registry.suse.com/suse/sle15:15.4
RUN sh -c 'echo -e "\
[rocm]\n\
name=rocm\n\
baseurl=https://repo.radeon.com/rocm/zyp/5.5/main\n\
baseurl=https://repo.radeon.com/rocm/zyp/5.6/main\n\
enabled=1\n\
gpgcheck=1\n\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\
......@@ -19,7 +19,12 @@ RUN zypper --gpg-auto-import-keys install -y \
gcc-c++ \
gdb \
git \
python3-pip
python3-pip \
rpm-build
#addition of repos for packages
RUN OPENSUSE_REPO=https://download.opensuse.org/repositories && \
zypper addrepo ${OPENSUSE_REPO}/devel:/languages:/perl/SLE_15_SP4/devel:languages:perl.repo
# Workaround broken rocm packages
RUN ln -s /opt/rocm-* /opt/rocm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment