Commit dd26f1aa authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_optimization

parents 4e3d06ab 4a3e493c
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
......
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/add.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
......
#include <migraphx/constant_propagate.hpp> #include <migraphx/constant_propagate.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/add.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
......
...@@ -51,7 +51,7 @@ void matmul_test() ...@@ -51,7 +51,7 @@ void matmul_test()
p.add_instruction(migraphx::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<T> results_vector(12); std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
...@@ -100,7 +100,7 @@ void matmul_test_ex() ...@@ -100,7 +100,7 @@ void matmul_test_ex()
p.add_instruction(migraphx::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, al, bl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<T> results_vector(12); std::vector<T> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector)); EXPECT(migraphx::verify_range(c, results_vector));
} }
...@@ -416,165 +416,6 @@ TEST_CASE(gemm_mutli_3args) ...@@ -416,165 +416,6 @@ TEST_CASE(gemm_mutli_3args)
TEST_CASE(gemm_3args) TEST_CASE(gemm_3args)
{ {
{
migraphx::program p;
std::vector<float> a = {-0.86217194,
-1.04129542,
-0.64850364,
-0.97078327,
-0.40516386,
0.83136927,
0.37717502,
0.42271939,
1.10062165,
-0.92239359,
0.40403076,
-0.43935377};
std::vector<float> b = {0.76084386,
1.89201125,
1.73218067,
0.7148568,
-0.55578914,
0.05799101,
-1.24090721,
-0.51151978,
1.13255803,
0.21540723,
-1.10459009,
0.45580331};
std::vector<float> c = {-0.80473623,
0.35154171,
-2.73077756,
-0.09093885,
-1.88850472,
-0.03375556,
-0.41798276,
2.87368099,
2.11031439};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type};
auto cl = p.add_literal(migraphx::literal{c_shape, {1}});
p.add_instruction(migraphx::op::dot{}, al, bl, cl);
std::vector<float> gold = {
0.195264, 1.35154, -1.73078, 0.909061, -0.888505, 0.966244, 0.582017, 3.87368, 3.11031};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {-0.86217194,
-1.04129542,
-0.64850364,
-0.97078327,
-0.40516386,
0.83136927,
0.37717502,
0.42271939,
1.10062165,
-0.92239359,
0.40403076,
-0.43935377};
std::vector<float> b = {0.76084386,
1.89201125,
1.73218067,
0.7148568,
-0.55578914,
0.05799101,
-1.24090721,
-0.51151978,
1.13255803,
0.21540723,
-1.10459009,
0.45580331};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3, 1}};
std::vector<float> vec_c(3, 2.0f);
auto cl = p.add_literal(migraphx::literal{c_shape, vec_c});
p.add_instruction(migraphx::op::dot{}, al, bl, cl);
std::vector<float> gold = {
1.19526,
2.35154,
-0.730778,
1.90906,
0.111495,
1.96624,
1.58202,
4.87368,
4.11031,
};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{
migraphx::program p;
std::vector<float> a = {-0.86217194,
-1.04129542,
-0.64850364,
-0.97078327,
-0.40516386,
0.83136927,
0.37717502,
0.42271939,
1.10062165,
-0.92239359,
0.40403076,
-0.43935377};
std::vector<float> b = {0.76084386,
1.89201125,
1.73218067,
0.7148568,
-0.55578914,
0.05799101,
-1.24090721,
-0.51151978,
1.13255803,
0.21540723,
-1.10459009,
0.45580331};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
migraphx::shape c_shape{migraphx::shape::float_type, {3}};
std::vector<float> vec_c(3, 2.0f);
auto cl = p.add_literal(migraphx::literal{c_shape, vec_c});
p.add_instruction(migraphx::op::dot{}, al, bl, cl);
std::vector<float> gold = {
1.19526,
2.35154,
-0.730778,
1.90906,
0.111495,
1.96624,
1.58202,
4.87368,
4.11031,
};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
{ {
migraphx::program p; migraphx::program p;
std::vector<float> a = {-0.86217194, std::vector<float> a = {-0.86217194,
...@@ -657,9 +498,11 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -657,9 +498,11 @@ TEST_CASE(matmul_vv_inner_product)
-0.2342857}; -0.2342857};
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {8}}; migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, ual, ubl);
std::vector<float> gold = {-1.43461}; std::vector<float> gold = {-1.43461};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -690,8 +533,10 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -690,8 +533,10 @@ TEST_CASE(matmul_vv_inner_product)
migraphx::shape b_shape{migraphx::shape::float_type, {8}}; migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.32f; float alpha = 0.32f;
p.add_instruction(migraphx::op::dot{alpha}, al, bl); p.add_instruction(migraphx::op::dot{alpha}, ual, ubl);
std::vector<float> gold = {-0.4590752}; std::vector<float> gold = {-0.4590752};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -722,10 +567,11 @@ TEST_CASE(matmul_vm) ...@@ -722,10 +567,11 @@ TEST_CASE(matmul_vm)
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002, 1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929}; -0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, ual, bl);
std::vector<float> gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326}; std::vector<float> gold = {-3.78111, -3.40007, -2.1972, -3.31448, -3.80326};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -754,11 +600,12 @@ TEST_CASE(matmul_vm) ...@@ -754,11 +600,12 @@ TEST_CASE(matmul_vm)
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002, 1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929}; -0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
float alpha = 0.5f; float alpha = 0.5f;
p.add_instruction(migraphx::op::dot{alpha}, al, bl); p.add_instruction(migraphx::op::dot{alpha}, ual, bl);
std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163}; std::vector<float> gold = {-1.89056, -1.70003, -1.0986, -1.65724, -1.90163};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -787,10 +634,12 @@ TEST_CASE(matmul_vm) ...@@ -787,10 +634,12 @@ TEST_CASE(matmul_vm)
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321}; -0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}}; migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, bual, bl);
std::vector<float> gold = {1.22914, std::vector<float> gold = {1.22914,
-1.17896, -1.17896,
2.28596, 2.28596,
...@@ -829,10 +678,12 @@ TEST_CASE(matmul_vm) ...@@ -829,10 +678,12 @@ TEST_CASE(matmul_vm)
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321}; -0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}}; migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{0.21f}, al, bl); p.add_instruction(migraphx::op::dot{0.21f}, bual, bl);
std::vector<float> gold = {0.25812, std::vector<float> gold = {0.25812,
-0.247582, -0.247582,
0.480051, 0.480051,
...@@ -878,8 +729,9 @@ TEST_CASE(matmul_mv) ...@@ -878,8 +729,9 @@ TEST_CASE(matmul_mv)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}}; migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, al, ubl);
std::vector<float> gold = {1.31982, 1.19022, -1.96062}; std::vector<float> gold = {1.31982, 1.19022, -1.96062};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -912,8 +764,9 @@ TEST_CASE(matmul_mv) ...@@ -912,8 +764,9 @@ TEST_CASE(matmul_mv)
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}}; migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.3f; float alpha = 0.3f;
p.add_instruction(migraphx::op::dot{alpha}, al, bl); p.add_instruction(migraphx::op::dot{alpha}, al, ubl);
std::vector<float> gold = {0.395946, 0.357067, -0.588187}; std::vector<float> gold = {0.395946, 0.357067, -0.588187};
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -940,8 +793,10 @@ TEST_CASE(matmul_mv) ...@@ -940,8 +793,10 @@ TEST_CASE(matmul_mv)
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}}; migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
auto bubl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 1}}, ubl);
p.add_instruction(migraphx::op::dot{}, al, bubl);
std::vector<float> gold = {-0.792717, std::vector<float> gold = {-0.792717,
6.33595, 6.33595,
2.61466, 2.61466,
...@@ -996,8 +851,9 @@ TEST_CASE(matmul_mm1) ...@@ -996,8 +851,9 @@ TEST_CASE(matmul_mm1)
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938, std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
-0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232, -0.782212, 1.9459, 0.927426, -2.44907, 2.40531, 2.30232,
0.182745, -4.21937, 1.77551, 1.50775, -2.60888, -2.32484, 0.182745, -4.21937, 1.77551, 1.50775, -2.60888, -2.32484,
...@@ -1041,10 +897,11 @@ TEST_CASE(matmul_mm1) ...@@ -1041,10 +897,11 @@ TEST_CASE(matmul_mm1)
-0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956}; -0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, bal, bl);
std::vector<float> gold = { std::vector<float> gold = {
-1.61175, 3.11849, -0.703205, 0.331635, -0.00946922, 0.645626, 0.834069, 1.06409, -1.61175, 3.11849, -0.703205, 0.331635, -0.00946922, 0.645626, 0.834069, 1.06409,
0.881037, 0.227628, -0.200308, -1.71836, 0.156255, 0.477222, 0.571363, -1.04543, 0.881037, 0.227628, -0.200308, -1.71836, 0.156255, 0.477222, 0.571363, -1.04543,
...@@ -1086,6 +943,7 @@ TEST_CASE(matmul_mm2) ...@@ -1086,6 +943,7 @@ TEST_CASE(matmul_mm2)
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
std::vector<float> gold = { std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259, 0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934, -0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
...@@ -1093,7 +951,7 @@ TEST_CASE(matmul_mm2) ...@@ -1093,7 +951,7 @@ TEST_CASE(matmul_mm2)
4.81988916, -3.63687142, -0.19101717, -4.92522092, -1.76377022, -3.58095615, 4.81988916, -3.63687142, -0.19101717, -4.92522092, -1.76377022, -3.58095615,
1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025, 1.83096922, 2.5512663, -1.07926588, -2.12749134, 0.33014536, -0.80393025,
0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821}; 0.60740202, 0.95217761, -1.06087445, -4.75868152, -3.6687713, -1.26539821};
p.add_instruction(migraphx::op::dot{}, al, bl); p.add_instruction(migraphx::op::dot{}, al, bbl);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> m; std::vector<float> m;
...@@ -1117,10 +975,12 @@ TEST_CASE(matmul_mm2) ...@@ -1117,10 +975,12 @@ TEST_CASE(matmul_mm2)
1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427, 1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427,
1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485}; 1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3, 5}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, bal, bbl);
std::vector<float> gold = { std::vector<float> gold = {
1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00, 1.64924590e+00, 2.84575831e+00, 1.07340773e+00, 2.19817080e-01, -1.87873283e+00,
1.91883003e+00, -2.89962196e-01, 2.76404142e+00, 1.50048102e+00, -6.29650347e-01, 1.91883003e+00, -2.89962196e-01, 2.76404142e+00, 1.50048102e+00, -6.29650347e-01,
...@@ -1211,8 +1071,9 @@ TEST_CASE(matmul_mm2) ...@@ -1211,8 +1071,9 @@ TEST_CASE(matmul_mm2)
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 4, 5}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = { std::vector<float> gold = {
-1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156, -1.08585245, 0.39575611, 0.33947977, -0.86339678, 1.50710753, 0.05646156,
-0.43180359, 0.19639674, -0.33742881, 0.98443538, -0.9021272, 1.25043704, -0.43180359, 0.19639674, -0.33742881, 0.98443538, -0.9021272, 1.25043704,
...@@ -1230,49 +1091,6 @@ TEST_CASE(matmul_mm2) ...@@ -1230,49 +1091,6 @@ TEST_CASE(matmul_mm2)
result.visit([&](auto output) { m.assign(output.begin(), output.end()); }); result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold)); EXPECT(migraphx::verify_range(m, gold));
} }
{
migraphx::program p;
std::vector<float> a = {0.44434486, -0.4775394, 1.22403495, 1.3390557, -0.16682514,
-0.14706984, 2.03517409, -0.15236999, 1.31615472, -0.98724552,
0.87351608, 0.32548614, -3.41102373, -1.98384933, 0.50167115,
0.59746381, 0.52601494, 1.68033723, -1.69118135, -0.07171001,
-0.21904557, -0.1435285, -0.3086262, -0.55035202};
std::vector<float> b = {
-0.94363619, -0.77647765, -0.67011854, -2.09503007, 0.90123996, -0.46622586,
1.42071249, 0.03609514, 0.15959348, 1.39677643, 1.04978928, 1.00156894,
-0.27378851, 0.0874493, 1.34600448, 2.08173849, 0.46533488, 0.00631963,
-0.56208786, 0.02443816, 0.45989363, 0.62163606, -0.4031336, 0.46017999,
0.39662946, -0.47854661, 1.67630842, -0.21867977, 0.63853741, 0.45437104,
0.29735596, -0.71734146, 0.1237553, 0.0409191, 0.14675446, -0.28671886,
-0.10558661, 0.45182015, 0.52462527, 0.85523901, -0.99229207, 0.35318084,
-1.00044197, 1.79608682, -0.45742108, -0.70323029, -0.39590981, -0.46266041,
-0.69778675, 0.37064368, 0.47614881, -0.30574358, 0.51562266, 1.47646532,
0.81795032, 0.62790241, -1.17363991, -0.82171213, 0.43211813, -0.63605139,
1.18437641, 0.23012845, -0.37945211, 0.01256212};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 2, 4, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b});
p.add_instruction(migraphx::op::dot{}, al, bl);
std::vector<float> gold = {
-1.02094755, 1.70442001, 2.1111438, 3.06536646, 0.39139469, 3.0274623,
1.83426191, 2.06536787, -2.08142323, 0.68688487, -0.92945811, -1.2405549,
-1.91914741, -1.22339147, 3.73566635, -1.5345778, 0.30098761, 1.82460858,
-3.82933195, 1.20738012, -0.64176798, -0.19297878, -0.50001913, 0.39087862,
-1.72170067, -0.70693856, -1.94004086, 1.0431326, -1.95490676, 0.75266023,
-2.07738769, 3.64789696, -0.74854627, -0.31258412, -1.32754766, 0.1966239,
1.47609026, -4.46809498, -3.2567728, -0.51434837, 2.39927998, 4.04908547,
0.92131416, 1.96903951, -0.21076738, -0.16615248, -0.1462282, 0.16623842};
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(m, gold));
}
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/rnn.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
......
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/abnormal_ops.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/identity.hpp>
#include <test.hpp> #include <test.hpp>
struct dce_target struct dce_target
...@@ -129,4 +131,55 @@ TEST_CASE(undefined_test) ...@@ -129,4 +131,55 @@ TEST_CASE(undefined_test)
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(duplicate_args1)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
p.add_instruction(migraphx::op::add{}, l3, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(duplicate_args2)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3);
p.add_instruction(migraphx::op::add{}, sum1, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == migraphx::literal{0});
}
TEST_CASE(duplicate_args3)
{
migraphx::program p;
auto l0 = p.add_literal(0);
auto l3 = p.add_literal(3);
auto sum1 = p.add_instruction(migraphx::op::add{}, l0, l3);
auto sum2 = p.add_instruction(migraphx::op::add{}, l0, sum1);
p.add_instruction(migraphx::op::add{}, sum2, l3);
p.add_instruction(migraphx::op::identity{}, l0);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) != count);
EXPECT(std::distance(p.begin(), p.end()) == 2);
auto result = p.eval({});
EXPECT(result == migraphx::literal{0});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
......
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/concat.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/op/identity.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
......
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/contiguous.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/op/identity.hpp>
#include <test.hpp>
struct eliminate_identity_target
{
std::string name() const { return "eliminate_identity"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::eliminate_identity{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
auto two = p.add_literal(2);
auto two_identity = p.add_instruction(migraphx::op::identity{}, two);
p.add_instruction(sum_op{}, one_identity, two_identity);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end_dependency)
{
migraphx::program p;
auto one = p.add_literal(1.0);
auto two = p.add_literal(2.0);
auto three = p.add_literal(3.0);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::any_of(p.begin(), p.end(), [](const migraphx::instruction& ins) {
return ins.name() == "identity";
}));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3.0});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct eliminate_pad_target
{
std::string name() const { return "eliminate_pad"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::eliminate_pad{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
migraphx::instruction_ref
create_im2col(migraphx::instruction_ref& l_img, size_t channels, migraphx::program& p)
{
size_t f[2] = {1, 1};
std::vector<int32_t> weights(channels * f[0] * f[1]);
migraphx::shape s_weights{migraphx::shape::int32_type, {1, channels, f[0], f[1]}};
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
return p.add_instruction(migraphx::op::im2col{}, l_img, l_weights);
}
migraphx::instruction_ref
create_conv(migraphx::instruction_ref& l_img,
size_t channels,
migraphx::program& p,
migraphx::op::padding_mode_t padding_mode = migraphx::op::padding_mode_t::default_)
{
migraphx::shape s_weights{migraphx::shape::int32_type, {4, channels, 3, 3}};
std::vector<int32_t> weights(4 * channels * 3 * 3);
auto l_weights = p.add_literal(migraphx::literal{s_weights, weights});
migraphx::op::convolution op;
op.padding_mode = padding_mode;
return p.add_instruction(op, l_img, l_weights);
}
TEST_CASE(rewrite_test)
{
migraphx::program p;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
auto l0 = create_im2col(padded_img, channels, p);
auto l1 = create_conv(padded_img, channels, p);
auto l2 = p.add_instruction(migraphx::op::pooling{}, padded_img);
p.add_instruction(migraphx::op::identity{}, l0, l1, l2);
p.compile(eliminate_pad_target{});
EXPECT(std::none_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_test_asymmetric)
{
migraphx::program p;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 0, 0, 0, 0, 2, 2}}, l_img);
create_im2col(padded_img, channels, p);
p.compile(eliminate_pad_target{});
EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
TEST_CASE(rewrite_test_same_padding)
{
migraphx::program p;
size_t img_dim[2] = {2, 2};
size_t channels = 1;
std::vector<int32_t> input(channels * img_dim[0] * img_dim[1]);
std::iota(input.begin(), input.end(), 0);
migraphx::shape s_img{migraphx::shape::int32_type, {1, channels, img_dim[0], img_dim[1]}};
auto l_img = p.add_literal(migraphx::literal{s_img, input});
auto padded_img = p.add_instruction(migraphx::op::pad{{0, 0, 1, 1, 0, 0, 1, 1}}, l_img);
create_conv(padded_img, channels, p, migraphx::op::padding_mode_t::same);
p.compile(eliminate_pad_target{});
EXPECT(std::any_of(
p.begin(), p.end(), [](const migraphx::instruction& ins) { return ins.name() == "pad"; }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/operators.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
bool is_batch_norm(migraphx::instruction& ins) { return ins.name() == "batch_norm_inference"; }
TEST_CASE(fwd_conv_batchnorm_rewrite_test) TEST_CASE(fwd_conv_batchnorm_rewrite_test)
{ {
std::vector<float> xdata = { std::vector<float> xdata = {
...@@ -65,4 +71,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test) ...@@ -65,4 +71,105 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
EXPECT(migraphx::verify_range(results_vector1, results_vector2)); EXPECT(migraphx::verify_range(results_vector1, results_vector2));
} }
TEST_CASE(non_literal)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto x = p.add_parameter("x", xs);
auto w = p.add_parameter("w", ws);
auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(any_of(p2, &is_batch_norm));
}
TEST_CASE(as_literal)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
auto scale = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1)));
auto bias = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2)));
auto mean = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
TEST_CASE(literal_reshape)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 1, 1}};
migraphx::shape vars{migraphx::shape::float_type, {4}};
auto create_program = [&]() {
migraphx::program p;
auto reshape = [&](auto ins) {
return p.add_instruction(migraphx::op::reshape{{1, 4, 1, 1}}, ins);
};
auto x = p.add_literal(migraphx::generate_literal(xs, 1));
auto w = p.add_literal(migraphx::generate_literal(ws, 1));
auto conv = p.add_instruction(migraphx::op::convolution{}, x, w);
auto scale = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1))));
auto bias = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2))));
auto mean = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))));
auto variance = reshape(p.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))));
p.add_instruction(migraphx::op::batch_norm_inference{}, conv, scale, bias, mean, variance);
return p;
};
migraphx::program p1 = create_program();
migraphx::program p2 = create_program();
migraphx::fwd_conv_batchnorm_rewrite opt;
opt.apply(p2);
EXPECT(any_of(p1, &is_batch_norm));
EXPECT(none_of(p2, &is_batch_norm));
p1.compile(migraphx::cpu::target{});
p2.compile(migraphx::cpu::target{});
auto result1 = p1.eval({});
auto result2 = p2.eval({});
visit_all(result1, result2)([&](auto r1, auto r2) { EXPECT(migraphx::verify_range(r1, r2)); });
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -893,10 +893,11 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> ...@@ -893,10 +893,11 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p; return p;
} }
...@@ -909,10 +910,11 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2> ...@@ -909,10 +910,11 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, l1, bl2);
return p; return p;
} }
...@@ -925,10 +927,11 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3> ...@@ -925,10 +927,11 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p; return p;
} }
...@@ -941,10 +944,11 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4> ...@@ -941,10 +944,11 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p; return p;
} }
...@@ -957,10 +961,11 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5> ...@@ -957,10 +961,11 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p; return p;
} }
...@@ -973,10 +978,12 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6> ...@@ -973,10 +978,12 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, bl1, bl2);
return p; return p;
} }
...@@ -989,10 +996,11 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7> ...@@ -989,10 +996,11 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
return p; return p;
} }
...@@ -1019,12 +1027,17 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv> ...@@ -1019,12 +1027,17 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {8}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; migraphx::shape m2_shape{migraphx::shape::float_type, {8}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
float alpha = 0.23f;
p.add_instruction(migraphx::op::dot{}, l1, l2); auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sres);
return p; return p;
} }
...@@ -1037,10 +1050,11 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv> ...@@ -1037,10 +1050,11 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, l1, ul2);
return p; return p;
} }
...@@ -1053,10 +1067,12 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv> ...@@ -1053,10 +1067,12 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
p.add_instruction(migraphx::op::dot{}, l1, l2); p.add_instruction(migraphx::op::dot{}, l1, bul2);
return p; return p;
} }
...@@ -1069,10 +1085,12 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm> ...@@ -1069,10 +1085,12 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
return p; return p;
} }
...@@ -1085,10 +1103,14 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm> ...@@ -1085,10 +1103,14 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, l1, l2); auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
return p; return p;
} }
...@@ -1114,66 +1136,6 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args> ...@@ -1114,66 +1136,6 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
} }
}; };
struct gemm_multi_3args_c0 : verify_program<gemm_multi_3args_c0>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c5 : verify_program<gemm_multi_3args_c5>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c21 : verify_program<gemm_multi_3args_c21>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 1}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
};
struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25> struct gemm_multi_3args_c25 : verify_program<gemm_multi_3args_c25>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraphx/operators.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
......
matmul-example:{

1
2y"MatMul test_matmulZ
1



Z
2





b
y





B
\ No newline at end of file
matmul-example:_

1
2y"MatMul test_matmulZ
1



Z
2

b
y


B
\ No newline at end of file
matmul-example:W

1
2y"MatMul test_matmulZ
1


Z
2

b
y

B
\ No newline at end of file
matmul-example:_

1
2y"MatMul test_matmulZ
1

Z
2



b
y


B
\ No newline at end of file
matmul-example:W

1
2y"MatMul test_matmulZ
1

Z
2


b
y

B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment