Commit 102d2eb6 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1a456eb4
......@@ -111,8 +111,8 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto batch_num = std::accumulate(out_lens.rbegin() + 2, out_lens.rend(),
std::size_t{1}, std::multiplies<std::size_t>());
auto batch_num = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
......
......@@ -977,15 +977,24 @@ TEST_CASE_REGISTER(gemm_test_ex<double>)
TEST_CASE(gemm_mutli_dim_2)
{
migraphx::program p;
std::vector<float> m1 = {
-0.76234141, 0.01368910, -0.86343423, -0.99465282, 0.76133268, 0.96507140,
-0.55893585, 0.02625652, 0.75171776, 0.23112578, 0.25624787, -1.50442161};
std::vector<float> m1 = {-0.76234141,
0.01368910,
-0.86343423,
-0.99465282,
0.76133268,
0.96507140,
-0.55893585,
0.02625652,
0.75171776,
0.23112578,
0.25624787,
-1.50442161};
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
std::vector<float> m2 = {
-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704, -1.06971832,
-0.12272917, 1.07094116, -0.08346820, 1.16820693, -0.95700874, 0.24059691,
0.43326023, 0.78305235, -0.53506601, -0.69359678, -0.26334436, 1.56292796,
-0.33629175, -1.72693469, 0.41435494, 1.52136843, -0.40699791, -1.59839430};
std::vector<float> m2 = {-0.15933632, -0.69594712, -0.06198966, -1.23905184, -0.83672704,
-1.06971832, -0.12272917, 1.07094116, -0.08346820, 1.16820693,
-0.95700874, 0.24059691, 0.43326023, 0.78305235, -0.53506601,
-0.69359678, -0.26334436, 1.56292796, -0.33629175, -1.72693469,
0.41435494, 1.52136843, -0.40699791, -1.59839430};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 4}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
auto l2 = p.add_literal(migraphx::literal{m2_shape, m2});
......@@ -996,10 +1005,22 @@ TEST_CASE(gemm_mutli_dim_2)
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {
0.18208394, -0.49276402, 0.87189133, 0.75150114, -0.55909610, 1.00521735,
-0.95536130, 2.27996211, 0.06239879, 0.74700068, -0.01570983, -0.85920856,
-0.59070835, -1.70729902, 0.40245487, 1.80182751};
std::vector<float> m_res = {0.18208394,
-0.49276402,
0.87189133,
0.75150114,
-0.55909610,
1.00521735,
-0.95536130,
2.27996211,
0.06239879,
0.74700068,
-0.01570983,
-0.85920856,
-0.59070835,
-1.70729902,
0.40245487,
1.80182751};
EXPECT(migraphx::verify_range(m, m_res));
}
......@@ -1032,11 +1053,11 @@ TEST_CASE(gemm_mutli_dim_2_3)
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {
0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170, -0.18729756,
1.09137941, -1.09298312, 3.42956915, -0.41681939, 0.17833257, 0.26040336,
0.15351280, 1.87632715, -0.63545406, -0.95467340, -1.74728628, -2.42477030,
0.76262372, 0.15539164, 3.32281958, 0.96769613, 0.43727545, 2.43019906};
std::vector<float> m_res = {0.26735861, -4.30770895, 1.05257728, -1.19954265, 0.50493170,
-0.18729756, 1.09137941, -1.09298312, 3.42956915, -0.41681939,
0.17833257, 0.26040336, 0.15351280, 1.87632715, -0.63545406,
-0.95467340, -1.74728628, -2.42477030, 0.76262372, 0.15539164,
3.32281958, 0.96769613, 0.43727545, 2.43019906};
EXPECT(migraphx::verify_range(m, m_res));
}
......@@ -1060,11 +1081,11 @@ TEST_CASE(gemm_mutli_dim1_2_3)
-0.12470397, 0.70404393, -0.15244797, 0.74288871, 0.07339926, -1.45811623,
0.27185845, 0.08804596, 0.99061977, -1.61752428, 0.29191159, 0.87271953};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
std::vector<float> m3 = {
-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017, 1.13229428,
-0.52769242, 0.27307182, -0.47779843, -0.08023168, -0.22862823, 0.81489871,
1.13139581, 1.13860467, 0.24309065, 0.26533729, 0.49106772, -1.18860493,
0.27842449, 1.03568141, 0.49759611, 0.10021662, 0.00592602, 0.90862000};
std::vector<float> m3 = {-1.07692443, 0.85223457, -0.37266530, 2.31511577, 0.04227017,
1.13229428, -0.52769242, 0.27307182, -0.47779843, -0.08023168,
-0.22862823, 0.81489871, 1.13139581, 1.13860467, 0.24309065,
0.26533729, 0.49106772, -1.18860493, 0.27842449, 1.03568141,
0.49759611, 0.10021662, 0.00592602, 0.90862000};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_literal(migraphx::literal{m1_shape, m1});
......@@ -1082,11 +1103,11 @@ TEST_CASE(gemm_mutli_dim1_2_3)
std::vector<float> m;
result.visit([&](auto output) { m.assign(output.begin(), output.end()); });
std::vector<float> m_res = {
-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586, 0.50928632,
0.06919868, -0.80382802, -0.05125718, -0.06685650, -0.06972163, 0.32407764,
0.45677396, 0.25909489, 0.56911252, -0.17183724, 0.10858734, 0.39406289,
0.04662959, 1.07979824, 0.40355016, 0.52410648, -0.31728447, 1.09550845};
std::vector<float> m_res = {-0.91147203, 0.47540785, -0.30313587, 0.43325099, -0.43711586,
0.50928632, 0.06919868, -0.80382802, -0.05125718, -0.06685650,
-0.06972163, 0.32407764, 0.45677396, 0.25909489, 0.56911252,
-0.17183724, 0.10858734, 0.39406289, 0.04662959, 1.07979824,
0.40355016, 0.52410648, -0.31728447, 1.09550845};
EXPECT(migraphx::verify_range(m, m_res));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment