Commit f3ea46e5 authored by Shucai Xiao
Browse files

clang format

parent 6ec90d65
...@@ -176,8 +176,8 @@ void miopen_gemm::fill_result(const shape& output_shape, ...@@ -176,8 +176,8 @@ void miopen_gemm::fill_result(const shape& output_shape,
const argument& result, const argument& result,
const argument& c) const const argument& c) const
{ {
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto c_lens = c.get_shape().lens(); auto c_lens = c.get_shape().lens();
auto type_size = output_shape.type_size(); auto type_size = output_shape.type_size();
if(output_shape == c.get_shape()) if(output_shape == c.get_shape())
{ {
...@@ -262,9 +262,9 @@ argument miopen_gemm::compute(context& ctx, ...@@ -262,9 +262,9 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int n = out_lens[1]; rocblas_int n = out_lens[1];
rocblas_int k = args[0].get_shape().lens()[1]; rocblas_int k = args[0].get_shape().lens()[1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
auto cpu_a = migraphx::gpu::from_gpu(args[0]); auto cpu_a = migraphx::gpu::from_gpu(args[0]);
auto cpu_b = migraphx::gpu::from_gpu(args[1]); auto cpu_b = migraphx::gpu::from_gpu(args[1]);
auto cpu_res = migraphx::gpu::from_gpu(args[3]); auto cpu_res = migraphx::gpu::from_gpu(args[3]);
std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl; std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl; std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl; std::cout << "gpu::gemm, cpu_res = " << cpu_res << std::endl;
...@@ -277,7 +277,6 @@ argument miopen_gemm::compute(context& ctx, ...@@ -277,7 +277,6 @@ argument miopen_gemm::compute(context& ctx,
std::cout << "gpu::gemm, ldb = " << ldb << std::endl; std::cout << "gpu::gemm, ldb = " << ldb << std::endl;
std::cout << "gpu::gemm, ldc = " << ldc << std::endl; std::cout << "gpu::gemm, ldc = " << ldc << std::endl;
generic_rocblas_gemm(as, generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -372,17 +371,17 @@ argument miopen_gemm::compute(context& ctx, ...@@ -372,17 +371,17 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int lda = a_lens[0]; rocblas_int lda = a_lens[0];
rocblas_int ldb = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)]; rocblas_int ldb = args[1].get_shape().strides()[(transb ? dim_1 : dim_0)];
rocblas_int ldc = b_lens[dim_1]; rocblas_int ldc = b_lens[dim_1];
rocblas_int m = 1; rocblas_int m = 1;
rocblas_int n = args[1].get_shape().lens()[dim_1]; rocblas_int n = args[1].get_shape().lens()[dim_1];
rocblas_int k = a_lens[0]; rocblas_int k = a_lens[0];
float beta = 0.0f; float beta = 0.0f;
assert(b_lens[dim_0] == args[0].get_shape().elements()); assert(b_lens[dim_0] == args[0].get_shape().elements());
std::size_t batch_num = std::accumulate( std::size_t batch_num = std::accumulate(
b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); b_lens.rbegin() + 2, b_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto cpu_a = migraphx::gpu::from_gpu(args[0]); auto cpu_a = migraphx::gpu::from_gpu(args[0]);
auto cpu_b = migraphx::gpu::from_gpu(args[1]); auto cpu_b = migraphx::gpu::from_gpu(args[1]);
auto cpu_res = migraphx::gpu::from_gpu(args[2]); auto cpu_res = migraphx::gpu::from_gpu(args[2]);
std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl; std::cout << "gpu::gemm, cpu_a = " << cpu_a << std::endl;
std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl; std::cout << "gpu::gemm, cpu_b = " << cpu_b << std::endl;
...@@ -404,30 +403,28 @@ argument miopen_gemm::compute(context& ctx, ...@@ -404,30 +403,28 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
return to_rocblas_type(as.from(arg.data()));
};
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm(
as, as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), to_pointer(args[1]),
ldb, ldb,
k * n, k * n,
to_pointer(args[0]), to_pointer(args[0]),
lda, lda,
0, 0,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
ldc, ldc,
m * n, m * n,
batch_num); batch_num);
}); });
return args[2]; return args[2];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment