Commit 360db15f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent f9c38c09
...@@ -907,27 +907,30 @@ struct dot ...@@ -907,27 +907,30 @@ struct dot
auto b_lens = inputs[1].lens(); auto b_lens = inputs[1].lens();
auto out_lens = a_lens; auto out_lens = a_lens;
auto t = inputs[0].type(); auto t = inputs[0].type();
if (inputs[1].lens().size() > 2) if(inputs[1].lens().size() > 2)
{ {
if(!std::equal(a_lens.rbegin() + 2, a_lens.rend(), b_lens.rbegin() + 2)) if(!std::equal(a_lens.rbegin() + 2, a_lens.rend(), b_lens.rbegin() + 2))
{ {
MIGRAPHX_THROW("DOT: dimension mismatch, operand A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT: dimension mismatch, operand A: {" +
"}, cannot multiply operand B: {" + to_string_range(b_lens) + "}"); to_string_range(a_lens) + "}, cannot multiply operand B: {" +
to_string_range(b_lens) + "}");
} }
std::size_t dim_0 = a_lens.size() - 2; std::size_t dim_0 = a_lens.size() - 2;
std::size_t dim_1 = a_lens.size() - 1; std::size_t dim_1 = a_lens.size() - 1;
if(a_lens[dim_1] != b_lens[dim_0]) if(a_lens[dim_1] != b_lens[dim_0])
MIGRAPHX_THROW("Inner dimensions do not match, operand A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("Inner dimensions do not match, operand A: {" +
"}, operand B: {" + to_string_range(b_lens) + "}"); to_string_range(a_lens) + "}, operand B: {" +
to_string_range(b_lens) + "}");
out_lens[dim_1] = b_lens[dim_1]; out_lens[dim_1] = b_lens[dim_1];
// C should be the same shape as A * B // C should be the same shape as A * B
auto c_lens = inputs[2].lens(); auto c_lens = inputs[2].lens();
if(!std::equal(c_lens.begin(), c_lens.end(), out_lens.begin())) if(!std::equal(c_lens.begin(), c_lens.end(), out_lens.begin()))
{ {
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + to_string_range(c_lens) + MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}"); to_string_range(c_lens) + "}, cannot add to operand A * B: {" +
to_string_range(out_lens) + "}");
} }
} }
else else
...@@ -938,8 +941,9 @@ struct dot ...@@ -938,8 +941,9 @@ struct dot
if(a_lens[1] != b_lens[0]) if(a_lens[1] != b_lens[0])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" +
"}, cannot multiply operand B: {" + to_string_range(b_lens) + "}"); to_string_range(a_lens) + "}, cannot multiply operand B: {" +
to_string_range(b_lens) + "}");
} }
out_lens[1] = b_lens[1]; out_lens[1] = b_lens[1];
......
...@@ -385,9 +385,13 @@ argument miopen_gemm::compute(context& ctx, ...@@ -385,9 +385,13 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto num_matrices = std::accumulate(out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); auto num_matrices = std::accumulate(out_lens.rbegin() + 2,
out_lens.rend(),
std::size_t{1},
std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(as, generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
......
...@@ -424,14 +424,18 @@ TEST_CASE(dot) ...@@ -424,14 +424,18 @@ TEST_CASE(dot)
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 7}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 7}},
migraphx::op::dot{}, s_m1, s_m2); migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}}; migraphx::shape s_m1{migraphx::shape::float_type, {1, 2, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}}; migraphx::shape s_m2{migraphx::shape::float_type, {2, 1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 4, 7}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 2, 4, 7}},
migraphx::op::dot{}, s_m1, s_m2); migraphx::op::dot{},
s_m1,
s_m2);
} }
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment