Commit 45e82cad authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 154b9287
...@@ -857,7 +857,8 @@ struct dot ...@@ -857,7 +857,8 @@ struct dot
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(inputs.size() == 3 && out_lens != inputs.at(2).lens()) if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{ {
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + to_string_range(inputs.at(2).lens()) + MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
to_string_range(inputs.at(2).lens()) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}"); "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
} }
......
...@@ -274,60 +274,58 @@ argument miopen_gemm::compute(context& ctx, ...@@ -274,60 +274,58 @@ argument miopen_gemm::compute(context& ctx,
// matrix * vector (b is a vector) // matrix * vector (b is a vector)
else if(b_lens.size() == 2 && b_lens.at(1) == 1) else if(b_lens.size() == 2 && b_lens.at(1) == 1)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
rocblas_int m = static_cast<rocblas_int>(a_lens[0]); rocblas_int m = static_cast<rocblas_int>(a_lens[0]);
rocblas_int n = static_cast<rocblas_int>(a_lens[1]); rocblas_int n = static_cast<rocblas_int>(a_lens[1]);
rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0]; rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
float beta = 0.0f; float beta = 0.0f;
assert(a_lens.back() == args[1].get_shape().elements()); assert(a_lens.back() == args[1].get_shape().elements());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemv( generic_rocblas_gemv(as,
as, ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(), transa ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, m,
m, n,
n, &alpha_r,
&alpha_r, to_pointer(args[0]),
to_pointer(args[0]), lda,
lda, to_pointer(args[1]),
to_pointer(args[1]), 1,
1, &beta_r,
&beta_r, to_pointer(args[2]),
to_pointer(args[2]), 1);
1);
}); });
} }
// vector * matrix (a is a vector) // vector * matrix (a is a vector)
else if(a_lens.size() == 2 && a_lens.at(0) == 1) else if(a_lens.size() == 2 && a_lens.at(0) == 1)
{ {
bool transb = !args[1].get_shape().transposed(); bool transb = !args[1].get_shape().transposed();
rocblas_int ldb = args[1].get_shape().strides()[(transb ? 1 : 0)]; rocblas_int ldb = args[1].get_shape().strides()[(transb ? 1 : 0)];
rocblas_int m = b_lens[0]; rocblas_int m = b_lens[0];
rocblas_int n = b_lens[1]; rocblas_int n = b_lens[1];
float beta = 0.0f; float beta = 0.0f;
assert(b_lens[0] == args[0].get_shape().elements()); assert(b_lens[0] == args[0].get_shape().elements());
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemv( generic_rocblas_gemv(as,
as, ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none, m,
m, n,
n, &alpha_r,
&alpha_r, to_pointer(args[1]),
to_pointer(args[1]), ldb,
ldb, to_pointer(args[0]),
to_pointer(args[0]), 1,
1, &beta_r,
&beta_r, to_pointer(args[2]),
to_pointer(args[2]), 1);
1);
}); });
} }
// batch matrix multiplication // batch matrix multiplication
...@@ -337,7 +335,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -337,7 +335,7 @@ argument miopen_gemm::compute(context& ctx,
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1; auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2; auto dim_0 = n_dim - 2;
float beta = 0.0f; float beta = 0.0f;
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
...@@ -374,9 +372,9 @@ argument miopen_gemm::compute(context& ctx, ...@@ -374,9 +372,9 @@ argument miopen_gemm::compute(context& ctx,
ldc, ldc,
m * n, m * n,
num_matrices); num_matrices);
}); });
} }
return args[2]; return args[2];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment