"example/01_gemm/gemm_xdl_fp16.cpp" did not exist on "6260ced2f3a4d9a2a832563905135c01ba72b56b"
Commit 32751f4a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent d5122475
...@@ -845,7 +845,7 @@ struct dot ...@@ -845,7 +845,7 @@ struct dot
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if (!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; })) if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{ {
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands"); MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
} }
......
...@@ -231,7 +231,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -231,7 +231,7 @@ argument miopen_gemm::compute(context& ctx,
m * n, m * n,
num_matrices); num_matrices);
}); });
return args[3]; return args[3];
} }
...@@ -255,29 +255,26 @@ argument miopen_gemm::compute(context& ctx, ...@@ -255,29 +255,26 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(out_lens.rbegin() + 2, auto num_matrices = std::accumulate(
out_lens.rend(), out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
std::size_t{1}, auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
std::multiplies<std::size_t>()); if(num_matrices == 1)
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if (num_matrices == 1)
{ {
generic_rocblas_gemm( generic_rocblas_gemm(as,
as, ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, n,
n, m,
m, k,
k, &alpha_r,
&alpha_r, to_pointer(args[1]),
to_pointer(args[1]), ldb,
ldb, to_pointer(args[0]),
to_pointer(args[0]), lda,
lda, &beta_r,
&beta_r, to_pointer(args[2]),
to_pointer(args[2]), ldc);
ldc);
} }
else else
{ {
......
...@@ -498,8 +498,8 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -498,8 +498,8 @@ TEST_CASE(matmul_vv_inner_product)
-0.2342857}; -0.2342857};
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {8}}; migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, ual, ubl); p.add_instruction(migraphx::op::dot{}, ual, ubl);
...@@ -533,8 +533,8 @@ TEST_CASE(matmul_vv_inner_product) ...@@ -533,8 +533,8 @@ TEST_CASE(matmul_vv_inner_product)
migraphx::shape b_shape{migraphx::shape::float_type, {8}}; migraphx::shape b_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.32f; float alpha = 0.32f;
p.add_instruction(migraphx::op::dot{alpha}, ual, ubl); p.add_instruction(migraphx::op::dot{alpha}, ual, ubl);
std::vector<float> gold = {-0.4590752}; std::vector<float> gold = {-0.4590752};
...@@ -567,7 +567,7 @@ TEST_CASE(matmul_vm) ...@@ -567,7 +567,7 @@ TEST_CASE(matmul_vm)
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002, 1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929}; -0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
...@@ -600,7 +600,7 @@ TEST_CASE(matmul_vm) ...@@ -600,7 +600,7 @@ TEST_CASE(matmul_vm)
1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002, 1.21119765, 1.23869861, 1.42169414, 0.86412382, 1.05898002,
-0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929}; -0.31918307, 1.08546695, 1.50682711, -0.66083538, -0.32683929};
migraphx::shape a_shape{migraphx::shape::float_type, {8}}; migraphx::shape a_shape{migraphx::shape::float_type, {8}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {8, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
...@@ -634,8 +634,8 @@ TEST_CASE(matmul_vm) ...@@ -634,8 +634,8 @@ TEST_CASE(matmul_vm)
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321}; -0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}}; migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual); auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
...@@ -678,8 +678,8 @@ TEST_CASE(matmul_vm) ...@@ -678,8 +678,8 @@ TEST_CASE(matmul_vm)
-0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321}; -0.18205627, 0.29446203, -1.91360924, 0.46102174, 0.44977568, -0.48113321};
migraphx::shape a_shape{migraphx::shape::float_type, {6}}; migraphx::shape a_shape{migraphx::shape::float_type, {6}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al); auto ual = p.add_instruction(migraphx::op::unsqueeze{{0}}, al);
auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual); auto bual = p.add_instruction(migraphx::op::multibroadcast{{3, 1, 6}}, ual);
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}}; migraphx::shape b_shape{migraphx::shape::float_type, {3, 6, 4}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
...@@ -729,7 +729,7 @@ TEST_CASE(matmul_mv) ...@@ -729,7 +729,7 @@ TEST_CASE(matmul_mv)
migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}}; migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
p.add_instruction(migraphx::op::dot{}, al, ubl); p.add_instruction(migraphx::op::dot{}, al, ubl);
std::vector<float> gold = {1.31982, 1.19022, -1.96062}; std::vector<float> gold = {1.31982, 1.19022, -1.96062};
...@@ -764,7 +764,7 @@ TEST_CASE(matmul_mv) ...@@ -764,7 +764,7 @@ TEST_CASE(matmul_mv)
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}}; migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
float alpha = 0.3f; float alpha = 0.3f;
p.add_instruction(migraphx::op::dot{alpha}, al, ubl); p.add_instruction(migraphx::op::dot{alpha}, al, ubl);
std::vector<float> gold = {0.395946, 0.357067, -0.588187}; std::vector<float> gold = {0.395946, 0.357067, -0.588187};
...@@ -793,8 +793,8 @@ TEST_CASE(matmul_mv) ...@@ -793,8 +793,8 @@ TEST_CASE(matmul_mv)
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5}}; migraphx::shape b_shape{migraphx::shape::float_type, {5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl); auto ubl = p.add_instruction(migraphx::op::unsqueeze{{1}}, bl);
auto bubl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 1}}, ubl); auto bubl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 1}}, ubl);
p.add_instruction(migraphx::op::dot{}, al, bubl); p.add_instruction(migraphx::op::dot{}, al, bubl);
std::vector<float> gold = {-0.792717, std::vector<float> gold = {-0.792717,
...@@ -851,7 +851,7 @@ TEST_CASE(matmul_mm1) ...@@ -851,7 +851,7 @@ TEST_CASE(matmul_mm1)
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl); p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938, std::vector<float> gold = {-0.386828, 0.187735, -0.22822, -0.148057, 2.015, -2.56938,
...@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1) ...@@ -897,7 +897,7 @@ TEST_CASE(matmul_mm1)
-0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956}; -0.14231862, -1.90915568, -0.06895489, 0.20160375, 0.01945916, 0.03586956};
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, al); auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 3, 4, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
...@@ -943,7 +943,7 @@ TEST_CASE(matmul_mm2) ...@@ -943,7 +943,7 @@ TEST_CASE(matmul_mm2)
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
std::vector<float> gold = { std::vector<float> gold = {
0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259, 0.70574512, -2.80915314, -1.57644969, 1.75415381, -3.13303087, -1.00150259,
-0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934, -0.18675123, -0.23349122, -0.12357225, 0.82911538, 1.37473744, -1.11709934,
...@@ -975,10 +975,10 @@ TEST_CASE(matmul_mm2) ...@@ -975,10 +975,10 @@ TEST_CASE(matmul_mm2)
1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427, 1.7746011, 0.24935804, 0.42830791, -0.13593643, 0.38749427,
1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485}; 1.39776254, -0.42911717, -1.3537624, -0.81999648, -0.1754485};
migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}}; migraphx::shape a_shape{migraphx::shape::float_type, {1, 2, 3, 5}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3, 5}}, al); auto bal = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 3, 5}}, al);
migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 1, 5, 3}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 5, 3}}, bl);
p.add_instruction(migraphx::op::dot{}, bal, bbl); p.add_instruction(migraphx::op::dot{}, bal, bbl);
std::vector<float> gold = { std::vector<float> gold = {
...@@ -1071,7 +1071,7 @@ TEST_CASE(matmul_mm2) ...@@ -1071,7 +1071,7 @@ TEST_CASE(matmul_mm2)
migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}}; migraphx::shape a_shape{migraphx::shape::float_type, {2, 2, 3, 4}};
auto al = p.add_literal(migraphx::literal{a_shape, a}); auto al = p.add_literal(migraphx::literal{a_shape, a});
migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}}; migraphx::shape b_shape{migraphx::shape::float_type, {2, 4, 5}};
auto bl = p.add_literal(migraphx::literal{b_shape, b}); auto bl = p.add_literal(migraphx::literal{b_shape, b});
auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 4, 5}}, bl); auto bbl = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 4, 5}}, bl);
p.add_instruction(migraphx::op::dot{}, al, bbl); p.add_instruction(migraphx::op::dot{}, al, bbl);
std::vector<float> gold = { std::vector<float> gold = {
......
...@@ -893,8 +893,8 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1> ...@@ -893,8 +893,8 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2); auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2); p.add_instruction(migraphx::op::dot{}, l1, bl2);
...@@ -910,8 +910,8 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2> ...@@ -910,8 +910,8 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2); auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2); p.add_instruction(migraphx::op::dot{}, l1, bl2);
...@@ -927,9 +927,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3> ...@@ -927,9 +927,9 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
...@@ -944,9 +944,9 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4> ...@@ -944,9 +944,9 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
...@@ -961,9 +961,9 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5> ...@@ -961,9 +961,9 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
...@@ -978,9 +978,9 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6> ...@@ -978,9 +978,9 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2); auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, bl1, bl2); p.add_instruction(migraphx::op::dot{}, bl1, bl2);
...@@ -996,9 +996,9 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7> ...@@ -996,9 +996,9 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2); p.add_instruction(migraphx::op::dot{}, bl1, l2);
...@@ -1030,12 +1030,12 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv> ...@@ -1030,12 +1030,12 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
migraphx::shape m1_shape{migraphx::shape::float_type, {8}}; migraphx::shape m1_shape{migraphx::shape::float_type, {8}};
migraphx::shape m2_shape{migraphx::shape::float_type, {8}}; migraphx::shape m2_shape{migraphx::shape::float_type, {8}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1); auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
float alpha = 0.23f; float alpha = 0.23f;
auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2); auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res); auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sres); p.add_instruction(migraphx::op::squeeze{{0}}, sres);
...@@ -1050,11 +1050,10 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv> ...@@ -1050,11 +1050,10 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, ul2); p.add_instruction(migraphx::op::dot{}, l1, ul2);
return p; return p;
...@@ -1068,9 +1067,9 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv> ...@@ -1068,9 +1067,9 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}}; migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2); auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
p.add_instruction(migraphx::op::dot{}, l1, bul2); p.add_instruction(migraphx::op::dot{}, l1, bul2);
...@@ -1086,9 +1085,9 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm> ...@@ -1086,9 +1085,9 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1); auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2); auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2);
p.add_instruction(migraphx::op::squeeze{{0}}, res); p.add_instruction(migraphx::op::squeeze{{0}}, res);
...@@ -1104,8 +1103,8 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm> ...@@ -1104,8 +1103,8 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
migraphx::program p; migraphx::program p;
migraphx::shape m1_shape{migraphx::shape::float_type, {5}}; migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}}; migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = p.add_parameter("1", m1_shape); auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1); auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1); auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
...@@ -1113,7 +1112,6 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm> ...@@ -1113,7 +1112,6 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2); auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2);
p.add_instruction(migraphx::op::squeeze{{2}}, res); p.add_instruction(migraphx::op::squeeze{{2}}, res);
return p; return p;
} }
}; };
......
...@@ -608,8 +608,8 @@ TEST_CASE(gemm_ex_brcst) ...@@ -608,8 +608,8 @@ TEST_CASE(gemm_ex_brcst)
TEST_CASE(matmul_vv) TEST_CASE(matmul_vv)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1); auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, sl1);
...@@ -624,8 +624,8 @@ TEST_CASE(matmul_vv) ...@@ -624,8 +624,8 @@ TEST_CASE(matmul_vv)
TEST_CASE(matmul_vm) TEST_CASE(matmul_vm)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, sl0, l1);
p.add_instruction(migraphx::op::squeeze{{0}}, res); p.add_instruction(migraphx::op::squeeze{{0}}, res);
...@@ -638,9 +638,9 @@ TEST_CASE(matmul_vm) ...@@ -638,9 +638,9 @@ TEST_CASE(matmul_vm)
TEST_CASE(matmul_vbm) TEST_CASE(matmul_vbm)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 7, 8}});
auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0); auto sl0 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l0);
auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0); auto bsl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 1, 7}}, sl0);
std::cout << "ONNX_TEST" << std::endl; std::cout << "ONNX_TEST" << std::endl;
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bsl0, l1);
...@@ -655,8 +655,8 @@ TEST_CASE(matmul_vbm) ...@@ -655,8 +655,8 @@ TEST_CASE(matmul_vbm)
TEST_CASE(matmul_mv) TEST_CASE(matmul_mv)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1); auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, sl1);
p.add_instruction(migraphx::op::squeeze{{1}}, res); p.add_instruction(migraphx::op::squeeze{{1}}, res);
...@@ -669,11 +669,11 @@ TEST_CASE(matmul_mv) ...@@ -669,11 +669,11 @@ TEST_CASE(matmul_mv)
TEST_CASE(matmul_bmv) TEST_CASE(matmul_bmv)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1); auto sl1 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l1);
auto bsl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1); auto bsl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 7, 1}}, sl1);
auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1); auto res = p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, l0, bsl1);
p.add_instruction(migraphx::op::squeeze{{2}}, res); p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto prog = migraphx::parse_onnx("matmul_bmv.onnx"); auto prog = migraphx::parse_onnx("matmul_bmv.onnx");
...@@ -684,8 +684,8 @@ TEST_CASE(matmul_bmv) ...@@ -684,8 +684,8 @@ TEST_CASE(matmul_bmv)
TEST_CASE(matmul_bmbm) TEST_CASE(matmul_bmbm)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}}); auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 6, 7}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}}); auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {5, 2, 1, 7, 8}});
auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0); auto bl0 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 6, 7}}, l0);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1); auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{5, 2, 3, 7, 8}}, l1);
p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1); p.add_instruction(migraphx::op::dot{1.0f, 0.0f}, bl0, bl1);
......
...@@ -414,15 +414,19 @@ TEST_CASE(matmul) ...@@ -414,15 +414,19 @@ TEST_CASE(matmul)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}}; migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::shape{migraphx::shape::float_type, {6, 1, 4}}, migraphx::op::dot{}, s_m1, s_m2); migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}}, migraphx::op::dot{}, s_m1, s_m2); migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment