"tools/vscode:/vscode.git/clone" did not exist on "6902d16075482a4af3b8c3a3d21cf6fdf0c288f0"
Unverified Commit 0c98c38e authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents 1612d8f3 64b306ab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(unsqueeze_test_1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape() == s2);
}
TEST_CASE(unsqueeze_test_2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(4 * 3 * 3);
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}};
auto l0 = mm->add_literal(migraphx::literal{s1, data});
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
EXPECT(result.get_shape() == s2);
}
TEST_CASE(unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {3, 3}}};
auto p0 = mm->add_parameter("x", s1);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), p0);
p.compile(migraphx::make_target("ref"));
std::vector<float> input_data(4 * 3 * 3);
migraphx::parameter_map params0;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {4, 3, 3}};
params0["x"] = migraphx::argument(input_fixed_shape0, input_data.data());
auto result = p.eval(params0).back();
migraphx::shape s2{migraphx::shape::float_type, {4, 1, 3, 3}};
EXPECT(result.get_shape() == s2);
}
TEST_CASE(unsqueeze_transpose_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_trans);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {3, 4, 1, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_multibroadcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {4, 1, 3}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_brcst =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 3, 3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), l0_brcst);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {4, 4, 1, 3, 3}});
EXPECT(result == expected_result);
}
TEST_CASE(unsqueeze_slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s1{migraphx::shape::float_type, {2, 3, 4, 4}};
auto l0 = mm->add_literal(migraphx::generate_literal(s1));
auto l0_slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {3}}, {"starts", {2}}, {"ends", {3}}}), l0);
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l0_slice);
auto p_uncompiled = p;
auto* mm_uncompiled = p_uncompiled.get_main_module();
mm_uncompiled->add_instruction(migraphx::make_op("contiguous"),
std::prev(mm_uncompiled->end()));
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
auto expected_result = p_uncompiled.eval({}).back();
EXPECT(result.get_shape() == migraphx::shape{migraphx::shape::float_type, {2, 1, 3, 4, 1}});
EXPECT(result == expected_result);
}
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
TEST_CASE(where_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
migraphx::shape sx{migraphx::shape::float_type, {3, 3}};
std::vector<bool> b{true, true, true, false, false, false, true, false, true};
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
auto lb = mm->add_literal(migraphx::literal{sb, b});
auto lx = mm->add_literal(migraphx::literal{sx, x});
auto ly = mm->add_literal(migraphx::literal{sx, y});
auto w = mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
mm->add_return({w});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<float> gold(9);
for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold));
}
TEST_CASE(where_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {{2, 3}, {2, 3}}};
migraphx::shape sx{migraphx::shape::float_type, {{2, 3}, {2, 3}}};
auto lb = mm->add_parameter("predicate", sb);
auto lx = mm->add_parameter("X", sx);
auto ly = mm->add_parameter("Y", sx);
mm->add_instruction(migraphx::make_op("where"), lb, lx, ly);
p.compile(migraphx::make_target("ref"));
std::vector<char> b{1, 1, 1, 0, 0, 0, 1, 0, 1};
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
migraphx::parameter_map params;
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {3, 3}};
migraphx::shape input_fixed_shape1{migraphx::shape::uint8_type, {3, 3}};
params["X"] = migraphx::argument(input_fixed_shape0, x.data());
params["Y"] = migraphx::argument(input_fixed_shape0, y.data());
params["predicate"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(3 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 1, 1, 2, 2, 2, 1, 2, 1};
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(where_broadcasted_inputs_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::bool_type, {3, 3}};
std::vector<bool> b{true, true, true, false, false, false, true, false, true};
auto lb = mm->add_literal(migraphx::literal{sb, b});
auto lx = mm->add_literal(migraphx::literal(1.0f));
auto ly = mm->add_literal(migraphx::literal(2.0f));
auto mbx = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), lx);
auto mby = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), ly);
auto w = mm->add_instruction(migraphx::make_op("where"), lb, mbx, mby);
mm->add_return({w});
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<float> gold(9);
std::vector<float> x(9, 1.0);
std::vector<float> y(9, 2.0);
for(int i = 0; i < gold.size(); ++i)
gold[i] = b[i] ? x[i] : y[i];
EXPECT(migraphx::verify::verify_range(result_vec, gold));
}
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1) ...@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", outer); auto x = m1.add_parameter("x", outer);
...@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2) ...@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{ {
migraphx::shape inner{migraphx::shape::int32_type, {2}}; migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}}; migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
auto create_program = [&] { auto create_program = [&] {
migraphx::module m; migraphx::module m;
auto x = m.add_parameter("x", outer); auto x = m.add_parameter("x", outer);
...@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add) ...@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE(simplify_inner_broadcast1) TEST_CASE(simplify_inner_broadcast1)
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
...@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1) ...@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE(simplify_inner_broadcast2) TEST_CASE(simplify_inner_broadcast2)
{ {
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
...@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2) ...@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE(simplify_inner_broadcast_scalar) TEST_CASE(simplify_inner_broadcast_scalar)
{ {
auto b = migraphx::op::multibroadcast{{32, 384}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
...@@ -603,9 +603,10 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -603,9 +603,10 @@ TEST_CASE(simplify_inner_broadcast_scalar)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y); auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum); auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb); m2.add_instruction(pass_op{}, sumb);
...@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE(simplify_inner_broadcast_different_dims) TEST_CASE(simplify_inner_broadcast_different_dims)
{ {
auto b = migraphx::op::multibroadcast{{2, 384, 768}}; auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
...@@ -629,9 +630,10 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -629,9 +630,10 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y); auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum); auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb); m2.add_instruction(pass_op{}, sumb);
...@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE(simplify_inner_broadcast_different_broadcasts) TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{ {
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}});
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}}; auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}});
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}}); auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
...@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) ...@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast) ...@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) ...@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis) ...@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}}; auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s); auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s); auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1); auto one = m1.add_literal(1);
...@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis) ...@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m2.add_parameter("x", s); auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s); auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
...@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu) ...@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu) ...@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto r = migraphx::op::reshape{{3, 4}}; auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape) ...@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis) ...@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::module m1; migraphx::module m1;
{ {
auto r = migraphx::op::reshape{{3, 2, 4}}; auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice) ...@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
...@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice) ...@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
...@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice) ...@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) ...@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis) ...@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1); auto one = m2.add_literal(1);
auto two = m2.add_literal(2); auto two = m2.add_literal(2);
...@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes) ...@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}), migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
...@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) ...@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1) ...@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) ...@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1; migraphx::module m1;
{ {
auto b = migraphx::op::broadcast{1, {3, 1, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s); auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction( auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2) ...@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx::module m2; migraphx::module m2;
{ {
auto b = migraphx::op::broadcast{1, {3, 2, 4}}; auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s); auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction( auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input); migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
...@@ -3033,6 +3035,36 @@ void reorder_slice_trans_diff_perm() ...@@ -3033,6 +3035,36 @@ void reorder_slice_trans_diff_perm()
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_trans_multi_outputs)
{
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {8, 128, 1920}};
auto input = m1.add_parameter("input", s);
std::vector<int64_t> perm = {0, 2, 1};
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto dot = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
auto slc_cont = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
m1.add_return({slc_cont, dot});
};
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_ins_deps) TEST_CASE(reorder_slice_ins_deps)
{ {
auto create_module = [] { auto create_module = [] {
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1) ...@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_multibroadcasts2) TEST_CASE(concat_multibroadcasts2)
...@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2) ...@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 0);
} }
TEST_CASE(concat_multibroadcasts3) TEST_CASE(concat_multibroadcasts3)
...@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3) ...@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb); auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1); EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
} }
TEST_CASE(concat_multibroadcasts4) TEST_CASE(concat_multibroadcasts4)
...@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1) ...@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 3);
} }
TEST_CASE(concat_transpose2) TEST_CASE(concat_transpose2)
...@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2) ...@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_transpose3) TEST_CASE(concat_transpose3)
...@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3) ...@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto new_concat = auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; }); std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()}); EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1); EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
} }
TEST_CASE(concat_transpose4) TEST_CASE(concat_transpose4)
......
...@@ -99,4 +99,29 @@ TEST_CASE(interpolate_string_custom3) ...@@ -99,4 +99,29 @@ TEST_CASE(interpolate_string_custom3)
EXPECT(s == "****b****"); EXPECT(s == "****b****");
} }
TEST_CASE(slit_string_simple1)
{
std::string input = "one,two,three";
auto resuts = migraphx::split_string(input, ',');
EXPECT(resuts.size() == 3);
EXPECT(resuts.front() == "one");
EXPECT(resuts.back() == "three");
}
TEST_CASE(slit_string_simple2)
{
std::string input = "one";
auto resuts = migraphx::split_string(input, ',');
EXPECT(resuts.size() == 1);
EXPECT(resuts.front() == "one");
}
TEST_CASE(slit_string_simple3)
{
std::string input = "one two three";
auto resuts = migraphx::split_string(input, ',');
EXPECT(resuts.size() == 1);
EXPECT(resuts.front() == "one two three");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
...@@ -840,12 +839,8 @@ TEST_CASE(slice_test) ...@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
mm->add_literal(migraphx::literal{s0, {1, 0}}); mm->add_literal(migraphx::literal{s0, {1, 0}});
mm->add_literal(migraphx::literal{s0, {2, -1}}); mm->add_literal(migraphx::literal{s0, {2, -1}});
migraphx::op::slice op; mm->add_instruction(
op.starts = {1, 0}; migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0);
op.ends = {3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
mm->add_instruction(op, l0);
auto prog = optimize_tf("slice_test.pb", false); auto prog = optimize_tf("slice_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
...@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test) ...@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4; auto l2 = mm->add_instruction(
migraphx::op::slice op; migraphx::make_op(
op.starts = {0, 0, 0, 0}; "slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}),
op.ends = {1, 1, 1, 5}; l1);
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l2 = mm->add_instruction(op, l1);
auto shrink_axis = 1; auto shrink_axis = 1;
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
auto prog = optimize_tf("stridedslice_test.pb", true); auto prog = optimize_tf("stridedslice_test.pb", true);
...@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test) ...@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format) // add literals for starts, ends, and strides in tf (NHWC format)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{0, 1, 1, 0}); std::vector<int>{0, 1, 1, 0});
...@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test) ...@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
auto l1 = auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1); auto l2 = mm->add_instruction(
migraphx::make_op(
"slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}),
l1);
auto l3 = auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3}); mm->add_return({l3});
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/operators.hpp> #include <migraphx/make_op.hpp>
struct gemm_literal : verify_program<gemm_literal> struct gemm_literal : verify_program<gemm_literal>
{ {
...@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal> ...@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto a = mm->add_literal(migraphx::generate_literal(a_shape)); auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape); auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b); mm->add_instruction(migraphx::make_op("dot"), a, b);
return p; return p;
} }
......
...@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt ...@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt
# Add newer cmake to the path # Add newer cmake to the path
export PATH="/opt/cmake/bin:$PATH" export PATH="/opt/cmake/bin:$PATH"
export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w" export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w"
./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root ./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --build_wheel --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root
cd build/Linux/Release cd build/Linux/Release
#Add test launcher for onnxrt tests #Add test launcher for onnxrt tests
......
...@@ -3,7 +3,7 @@ FROM registry.suse.com/suse/sle15:15.4 ...@@ -3,7 +3,7 @@ FROM registry.suse.com/suse/sle15:15.4
RUN sh -c 'echo -e "\ RUN sh -c 'echo -e "\
[rocm]\n\ [rocm]\n\
name=rocm\n\ name=rocm\n\
baseurl=https://repo.radeon.com/rocm/zyp/5.5/main\n\ baseurl=https://repo.radeon.com/rocm/zyp/5.6/main\n\
enabled=1\n\ enabled=1\n\
gpgcheck=1\n\ gpgcheck=1\n\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\
...@@ -19,7 +19,12 @@ RUN zypper --gpg-auto-import-keys install -y \ ...@@ -19,7 +19,12 @@ RUN zypper --gpg-auto-import-keys install -y \
gcc-c++ \ gcc-c++ \
gdb \ gdb \
git \ git \
python3-pip python3-pip \
rpm-build
#addition of repos for packages
RUN OPENSUSE_REPO=https://download.opensuse.org/repositories && \
zypper addrepo ${OPENSUSE_REPO}/devel:/languages:/perl/SLE_15_SP4/devel:languages:perl.repo
# Workaround broken rocm packages # Workaround broken rocm packages
RUN ln -s /opt/rocm-* /opt/rocm RUN ln -s /opt/rocm-* /opt/rocm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment