Commit ca28e1e8 authored by Shucai Xiao

clang format

parent 02f359b2
@@ -244,15 +244,15 @@ argument miopen_gemm::batch_matmul(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
-bool transa = args[0].get_shape().transposed();
-bool transb = args[1].get_shape().transposed();
+bool transa = args[0].get_shape().transposed();
+bool transb = args[1].get_shape().transposed();
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
-auto an_dim = a_lens.size();
-auto bn_dim = b_lens.size();
+auto an_dim = a_lens.size();
+auto bn_dim = b_lens.size();
auto outn_dim = out_lens.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? an_dim - 1 : an_dim - 2];
@@ -265,20 +265,24 @@ argument miopen_gemm::batch_matmul(context& ctx,
std::vector<std::size_t> a_batch_lens(a_lens.begin(), a_lens.begin() + an_dim - 2);
std::vector<std::size_t> b_batch_lens(b_lens.begin(), b_lens.begin() + bn_dim - 2);
-if (a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
+if(a_batch_lens == b_batch_lens || a_batch_lens.empty() || b_batch_lens.empty())
{
-std::size_t numa_matrices =
-    std::accumulate(a_batch_lens.begin(), a_batch_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
-std::size_t numb_matrices =
-    std::accumulate(b_batch_lens.begin(), b_batch_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
-std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
-rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k;
-rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n;
-rocblas_int stride_c = m * n;
+std::size_t numa_matrices = std::accumulate(a_batch_lens.begin(),
+                                            a_batch_lens.end(),
+                                            std::size_t{1},
+                                            std::multiplies<std::size_t>());
+std::size_t numb_matrices = std::accumulate(b_batch_lens.begin(),
+                                            b_batch_lens.end(),
+                                            std::size_t{1},
+                                            std::multiplies<std::size_t>());
+std::size_t num_matrices = std::max(numa_matrices, numb_matrices);
+rocblas_int stride_a = (numa_matrices == 1) ? 0 : m * k;
+rocblas_int stride_b = (numb_matrices == 1) ? 0 : k * n;
+rocblas_int stride_c = m * n;
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
-auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data()));};
+auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(
as,
ctx.get_stream().get_rocblas(),
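The reformatted hunk above is behavior-preserving; the stride logic it reflows picks a batch stride of 0 for whichever operand holds a single broadcast matrix, so the strided-batched GEMM re-reads that one matrix for every batch entry while the output always advances by m * n. A minimal self-contained sketch of that choice (hypothetical helper name, and std::ptrdiff_t standing in for rocblas_int; not part of this commit):

```cpp
#include <cstddef>

// Hypothetical helper (not in the patch) illustrating the stride selection:
// an operand that contributes only one matrix gets batch stride 0, so the
// strided-batched GEMM reuses it across the whole batch.
struct gemm_strides
{
    std::ptrdiff_t a; // step between consecutive A matrices
    std::ptrdiff_t b; // step between consecutive B matrices
    std::ptrdiff_t c; // step between consecutive output matrices
};

gemm_strides make_strides(std::ptrdiff_t m,
                          std::ptrdiff_t n,
                          std::ptrdiff_t k,
                          std::size_t numa_matrices,
                          std::size_t numb_matrices)
{
    return {(numa_matrices == 1) ? 0 : m * k, // broadcast A if it is a single matrix
            (numb_matrices == 1) ? 0 : k * n, // broadcast B likewise
            m * n};                           // output advances every batch entry
}
```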
@@ -313,19 +317,19 @@ argument miopen_gemm::batch_matmul(context& ctx,
shape_for_each(out_batch_shape, [&](auto out_idx) {
std::size_t out_ind = out_batch_shape.index(out_idx.begin(), out_idx.end());
-auto type_size = output_shape.type_size();
+auto type_size = output_shape.type_size();
std::vector<std::size_t> a_idx(a_batch_lens.size());
std::vector<std::size_t> b_idx(b_batch_lens.size());
std::transform(out_idx.begin() + a_len_diff,
-    out_idx.end(),
-    a_batch_lens.begin(),
-    a_idx.begin(),
-    [&](auto i, auto j) { return (j == 1) ? 0 : i; });
+               out_idx.end(),
+               a_batch_lens.begin(),
+               a_idx.begin(),
+               [&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::transform(out_idx.begin() + b_len_diff,
-    out_idx.end(),
-    b_batch_lens.begin(),
-    b_idx.begin(),
-    [&](auto i, auto j) { return (j == 1) ? 0 : i; });
+               out_idx.end(),
+               b_batch_lens.begin(),
+               b_idx.begin(),
+               [&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::size_t a_ind = a_batch_shape.index(a_idx.begin(), a_idx.end());
std::size_t b_ind = b_batch_shape.index(b_idx.begin(), b_idx.end());
@@ -336,22 +340,21 @@ argument miopen_gemm::batch_matmul(context& ctx,
auto to_pointer = [&](auto&& arg, std::size_t offset = 0) {
return to_rocblas_type(as.from(arg.data() + offset));
};
-generic_rocblas_gemm(
-    as,
-    ctx.get_stream().get_rocblas(),
-    transb ? rocblas_operation_transpose : rocblas_operation_none,
-    transa ? rocblas_operation_transpose : rocblas_operation_none,
-    n,
-    m,
-    k,
-    &alpha_r,
-    to_pointer(args[1], k * n * b_ind * type_size),
-    ldb,
-    to_pointer(args[0], m * k * a_ind * type_size),
-    lda,
-    &beta_r,
-    to_pointer(args[2], m * n * out_ind * type_size),
-    ldc);
+generic_rocblas_gemm(as,
+                     ctx.get_stream().get_rocblas(),
+                     transb ? rocblas_operation_transpose : rocblas_operation_none,
+                     transa ? rocblas_operation_transpose : rocblas_operation_none,
+                     n,
+                     m,
+                     k,
+                     &alpha_r,
+                     to_pointer(args[1], k * n * b_ind * type_size),
+                     ldb,
+                     to_pointer(args[0], m * k * a_ind * type_size),
+                     lda,
+                     &beta_r,
+                     to_pointer(args[2], m * n * out_ind * type_size),
+                     ldc);
});
});
}
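The hunk above only re-indents the call; the surrounding loop resolves numpy-style broadcasting by mapping each output batch index back to an operand batch index, sending any operand dimension of size 1 to index 0. A self-contained sketch of that mapping (hypothetical helper name broadcast_index; assumes out_idx.size() - len_diff equals operand_batch_lens.size(); not part of this commit):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Map an output batch index to an operand batch index under broadcasting:
// a dimension of size 1 in the operand always maps to index 0, so the single
// matrix along that axis is reused. len_diff is the number of leading output
// dimensions the operand lacks.
std::vector<std::size_t> broadcast_index(const std::vector<std::size_t>& out_idx,
                                         const std::vector<std::size_t>& operand_batch_lens,
                                         std::size_t len_diff)
{
    std::vector<std::size_t> idx(operand_batch_lens.size());
    std::transform(out_idx.begin() + len_diff,
                   out_idx.end(),
                   operand_batch_lens.begin(),
                   idx.begin(),
                   [](auto i, auto j) { return (j == 1) ? 0 : i; });
    return idx;
}
```

The resulting index is then linearized against the operand's batch shape to compute the byte offset of the matrix passed to the plain (non-batched) GEMM call, which is why the offsets above multiply by type_size.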
@@ -447,9 +450,7 @@ argument miopen_gemm::compute(context& ctx,
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(beta));
-auto to_pointer = [&](auto&& arg) {
-    return to_rocblas_type(as.from(arg.data()));
-};
+auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(
as,
@@ -22,7 +22,8 @@ struct miopen_gemm
private:
void fill_result(const shape& output_shape, const argument& result, const argument& c) const;
-argument batch_matmul(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
+argument
+batch_matmul(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
};
} // namespace gpu