Commit 97f96f78 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 59e4f91c
...@@ -72,20 +72,20 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -72,20 +72,20 @@ argument miopen_quant_gemm::compute(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1; auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2; auto dim_0 = n_dim - 2;
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
assert(k % 4 == 0); assert(k % 4 == 0);
assert(transa && (lda % 4 == 0)); assert(transa && (lda % 4 == 0));
...@@ -96,28 +96,31 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -96,28 +96,31 @@ argument miopen_quant_gemm::compute(context& ctx,
if(num_matrices == 1) if(num_matrices == 1)
{ {
generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(), generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), to_pointer(args[1]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
to_pointer(args[0]), to_pointer(args[0]),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])), (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0, nullptr, nullptr); 0,
0,
nullptr,
nullptr);
} }
else else
{ {
...@@ -149,7 +152,10 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -149,7 +152,10 @@ argument miopen_quant_gemm::compute(context& ctx,
num_matrices, num_matrices,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0, nullptr, nullptr); 0,
0,
nullptr,
nullptr);
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment