Commit ca28e1e8 authored by Shucai Xiao
Browse files

clang format

parent 02f359b2
...@@ -265,12 +265,16 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -265,12 +265,16 @@ argument miopen_gemm::batch_matmul(context& ctx,
std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2); std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2);
std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2); std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2);
if (a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty()) if(a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
{ {
std::size_t numa_matrices = std::size_t numa_matrices = std::accumulate(a_batch_lens.begin(),
std::accumulate(a_batch_lens.begin(), a_batch_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); a_batch_lens.end(),
std::size_t numb_matrices = std::size_t{1},
std::accumulate(b_batch_lens.begin(), b_batch_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
std::size_t numb_matrices = std::accumulate(b_batch_lens.begin(),
b_batch_lens.end(),
std::size_t{1},
std::multiplies<std::size_t>());
std::size_t num_matrices = std::max(numa_matrices, numb_matrices); std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k; rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k;
rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n; rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n;
...@@ -278,7 +282,7 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -278,7 +282,7 @@ argument miopen_gemm::batch_matmul(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data()));}; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm(
as, as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
...@@ -336,8 +340,7 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -336,8 +340,7 @@ argument miopen_gemm::batch_matmul(context& ctx,
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
generic_rocblas_gemm( generic_rocblas_gemm(as,
as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
return to_rocblas_type(as.from(arg.data()));
};
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm(
as, as,
......
...@@ -22,7 +22,8 @@ struct miopen_gemm ...@@ -22,7 +22,8 @@ struct miopen_gemm
private: private:
void fill_result(const shape& output_shape, const argument& result, const argument& c) const; void fill_result(const shape& output_shape, const argument& result, const argument& c) const;
argument batch_matmul(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; argument
batch_matmul(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment