Commit 32751f4a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent d5122475
...@@ -845,7 +845,7 @@ struct dot ...@@ -845,7 +845,7 @@ struct dot
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if (!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; })) if(!std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{ {
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands"); MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
} }
......
...@@ -255,15 +255,12 @@ argument miopen_gemm::compute(context& ctx, ...@@ -255,15 +255,12 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(out_lens.rbegin() + 2, auto num_matrices = std::accumulate(
out_lens.rend(), out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
std::size_t{1},
std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if (num_matrices == 1) if(num_matrices == 1)
{ {
generic_rocblas_gemm( generic_rocblas_gemm(as,
as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
......
...@@ -1054,7 +1054,6 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv> ...@@ -1054,7 +1054,6 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
auto l2 = p.add_parameter("2", m2_shape); auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2); auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, ul2); p.add_instruction(migraphx::op::dot{}, l1, ul2);
return p; return p;
...@@ -1113,7 +1112,6 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm> ...@@ -1113,7 +1112,6 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2); auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2);
p.add_instruction(migraphx::op::squeeze{{2}}, res); p.add_instruction(migraphx::op::squeeze{{2}}, res);
return p; return p;
} }
}; };
......
...@@ -414,15 +414,19 @@ TEST_CASE(matmul) ...@@ -414,15 +414,19 @@ TEST_CASE(matmul)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}}; migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::shape{migraphx::shape::float_type, {6, 1, 4}}, migraphx::op::dot{}, s_m1, s_m2); migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}}, migraphx::op::dot{}, s_m1, s_m2); migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment