Commit ca28e1e8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 02f359b2
...@@ -244,15 +244,15 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -244,15 +244,15 @@ argument miopen_gemm::batch_matmul(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto an_dim = a_lens.size(); auto an_dim = a_lens.size();
auto bn_dim = b_lens.size(); auto bn_dim = b_lens.size();
auto outn_dim = out_lens.size(); auto outn_dim = out_lens.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? an_dim - 1 : an_dim - 2]; rocblas_int lda = args[0].get_shape().strides()[transa ? an_dim - 1 : an_dim - 2];
...@@ -265,20 +265,24 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -265,20 +265,24 @@ argument miopen_gemm::batch_matmul(context& ctx,
std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2); std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2);
std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2); std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2);
if (a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty()) if(a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
{ {
std::size_t numa_matrices = std::size_t numa_matrices = std::accumulate(a_batch_lens.begin(),
std::accumulate(a_batch_lens.begin(), a_batch_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); a_batch_lens.end(),
std::size_t numb_matrices = std::size_t{1},
std::accumulate(b_batch_lens.begin(), b_batch_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); std::multiplies<std::size_t>());
std::size_t num_matrices = std::max(numa_matrices, numb_matrices); std::size_t numb_matrices = std::accumulate(b_batch_lens.begin(),
rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k; b_batch_lens.end(),
rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n; std::size_t{1},
rocblas_int stride_c = m * n; std::multiplies<std::size_t>());
std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k;
rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n;
rocblas_int stride_c = m * n;
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data()));}; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm(
as, as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
...@@ -313,19 +317,19 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -313,19 +317,19 @@ argument miopen_gemm::batch_matmul(context& ctx,
shape_for_each(out_batch_shape, [&](auto out_idx) { shape_for_each(out_batch_shape, [&](auto out_idx) {
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end()); std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
auto type_size = output_shape.type_size(); auto type_size = output_shape.type_size();
std::vector<std::size_t> a_idx(a_batch_lens.size()); std::vector<std::size_t> a_idx(a_batch_lens.size());
std::vector<std::size_t> b_idx(b_batch_lens.size()); std::vector<std::size_t> b_idx(b_batch_lens.size());
std::transform(out_idx.begin() + a_len_diff, std::transform(out_idx.begin() + a_len_diff,
out_idx.end(), out_idx.end(),
a_batch_lens.begin(), a_batch_lens.begin(),
a_idx.begin(), a_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; }); [&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::transform(out_idx.begin() + b_len_diff, std::transform(out_idx.begin() + b_len_diff,
out_idx.end(), out_idx.end(),
b_batch_lens.begin(), b_batch_lens.begin(),
b_idx.begin(), b_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; }); [&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end()); std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end());
std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end()); std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
...@@ -336,22 +340,21 @@ argument miopen_gemm::batch_matmul(context& ctx, ...@@ -336,22 +340,21 @@ argument miopen_gemm::batch_matmul(context& ctx,
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) { auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset)); return to_rocblas_type(as.from(arg.data() + offset));
}; };
generic_rocblas_gemm( generic_rocblas_gemm(as,
as, ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, n,
n, m,
m, k,
k, &alpha_r,
&alpha_r, to_pointer(args[1], k * n * b_ind * type_size),
to_pointer(args[1], k * n * b_ind * type_size), ldb,
ldb, to_pointer(args[0], m * k * a_ind * type_size),
to_pointer(args[0], m * k * a_ind * type_size), lda,
lda, &beta_r,
&beta_r, to_pointer(args[2], m * n * out_ind * type_size),
to_pointer(args[2], m * n * out_ind * type_size), ldc);
ldc);
}); });
}); });
} }
...@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
return to_rocblas_type(as.from(arg.data()));
};
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm(
as, as,
......
...@@ -22,7 +22,8 @@ struct miopen_gemm ...@@ -22,7 +22,8 @@ struct miopen_gemm
private: private:
void fill_result(const shape& output_shape, const argument& result, const argument& c) const; void fill_result(const shape& output_shape, const argument& result, const argument& c) const;
argument batch_matmul(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; argument
batch_matmul(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
}; };
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment