Commit 359ec2f8 authored by Shucai Xiao

clang format

parent 9d52515a
@@ -810,7 +810,7 @@ struct gather
 // The dot operation is a combination of the onnx GEMM and MatMul operators.
 // For GEMM, it supports the C matrix in the formula alpha * AB + beta * C,
 // in which C is broadcastable to the shape of AB. For the transpose of A
 // and B, we add a transpose operator beforehand if the onnx gemm operator
 // indicates a transpose.
 // For MatMul, it has the same definition as numpy.matmul, which means
 // A and B could be 1 to N dims. For a 1-dim input A, it is a vector * matrix,
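For reference, the Gemm contract the comment describes is the standard ONNX formula, with the transA/transB attributes folded into op(.); in LaTeX:

Y = \alpha \cdot \mathrm{op}(A)\,\mathrm{op}(B) + \beta \cdot C,
\qquad
\mathrm{op}(X) =
\begin{cases}
X     & \text{if } \mathit{transX} = 0 \\
X^{T} & \text{if } \mathit{transX} = 1
\end{cases}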
@@ -830,45 +830,46 @@ struct dot
         return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
     }
-    std::vector<std::size_t> shape_broadcast(std::vector<std::size_t> &a, std::vector<std::size_t> &b) const
+    std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
+                                             std::vector<std::size_t>& b) const
     {
-        if (a.empty())
+        if(a.empty())
             return b;
-        else if (b.empty())
+        else if(b.empty())
             return a;
         auto a_size = a.size();
         auto b_size = b.size();
         auto n_dim = std::min(a_size, b_size);
         std::vector<std::size_t> out_lens(std::max(a_size, b_size));
-        for (std::size_t i = 0; i < n_dim; ++i)
+        for(std::size_t i = 0; i < n_dim; ++i)
         {
-            if (a[a_size - 1 - i] == b[b_size - 1 - i])
+            if(a[a_size - 1 - i] == b[b_size - 1 - i])
             {
                 out_lens[i] = a[a_size - 1 - i];
             }
-            else if (a[a_size - 1 - i] == 1)
+            else if(a[a_size - 1 - i] == 1)
             {
                 out_lens[i] = b[b_size - 1 - i];
             }
-            else if (b[b_size - 1 - i] == 1)
+            else if(b[b_size - 1 - i] == 1)
             {
                 out_lens[i] = a[a_size - 1 - i];
             }
             else
             {
                 MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
-                               "}, and matrix B: {" + to_string_range(b)
-                               + "} are not broadcastable");
+                               "}, and matrix B: {" + to_string_range(b) +
+                               "} are not broadcastable");
             }
         }
-        if (a_size > n_dim)
+        if(a_size > n_dim)
         {
             std::copy(a.rbegin() + n_dim, a.rend(), out_lens.begin() + n_dim);
         }
-        if (b_size > n_dim)
+        if(b_size > n_dim)
         {
             std::copy(b.rbegin() + n_dim, b.rend(), out_lens.rbegin() + n_dim);
         }
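The shape_broadcast helper above implements the numpy-style multidirectional broadcasting rule: the two shapes are aligned at their trailing dimension, each aligned pair of dimensions must either match or contain a 1 (which stretches to the other value), and leftover leading dimensions pass through. A minimal standalone sketch of that rule, assuming nothing from the codebase (broadcast_shapes and the example shapes are illustrative only):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Numpy-style broadcast of two shapes, aligned at the trailing dimension.
std::vector<std::size_t> broadcast_shapes(std::vector<std::size_t> a,
                                          std::vector<std::size_t> b)
{
    if(a.size() < b.size())
        std::swap(a, b); // make a the longer shape
    std::vector<std::size_t> out(a.size());
    auto offset = a.size() - b.size();
    // leading dimensions of the longer shape pass through unchanged
    std::copy(a.begin(), a.begin() + offset, out.begin());
    for(std::size_t i = 0; i < b.size(); ++i)
    {
        auto da = a[offset + i];
        auto db = b[i];
        if(da == db || da == 1 || db == 1)
            out[offset + i] = std::max(da, db); // a 1 stretches to the other dim
        else
            throw std::runtime_error("shapes are not broadcastable");
    }
    return out;
}

int main()
{
    auto out = broadcast_shapes({5, 1, 3}, {4, 1}); // -> {5, 4, 3}
    for(auto d : out)
        std::cout << d << ' ';
    std::cout << '\n';
    return 0;
}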
@@ -886,7 +887,7 @@ struct dot
         const shape& b = inputs.at(1);
         auto t = a.type();
-        if (a.scalar() || b.scalar())
+        if(a.scalar() || b.scalar())
         {
             MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead");
         }
@@ -894,26 +895,26 @@ struct dot
         auto a_lens = a.lens();
         auto b_lens = b.lens();
         std::vector<std::size_t> out_lens;
-        if (a_lens.size() == 1)
+        if(a_lens.size() == 1)
         {
             // inner product, output is a scalar, following numpy.matmul()
-            if (b_lens.size() == 1)
+            if(b_lens.size() == 1)
             {
-                if (a_lens.front() != b_lens.front())
+                if(a_lens.front() != b_lens.front())
                 {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + to_string_range(a_lens) +
-                                   "}, cannot multiply vector B: {" + to_string_range(b_lens)
-                                   + "}");
+                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
+                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
+                                   to_string_range(b_lens) + "}");
                 }
             }
             else
             {
                 std::size_t dim_0 = b_lens.size() - 2;
-                if (a_lens.front() != b_lens[dim_0])
+                if(a_lens.front() != b_lens[dim_0])
                 {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + to_string_range(a_lens) +
-                                   "}, cannot multiply matrix B: {" + to_string_range(b_lens)
-                                   + "}");
+                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
+                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
+                                   to_string_range(b_lens) + "}");
                 }
                 out_lens = b_lens;
@@ -923,15 +924,15 @@ struct dot
         else
         {
             std::size_t dim_0 = a_lens.size() - 1;
-            if (b_lens.size() == 1)
+            if(b_lens.size() == 1)
             {
-                if (a_lens.back() != b_lens.back())
+                if(a_lens.back() != b_lens.back())
                 {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a_lens) +
-                                   "}, cannot multiply vector B: {" + to_string_range(b_lens)
-                                   + "}");
+                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
+                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
+                                   to_string_range(b_lens) + "}");
                 }
                 out_lens = a_lens;
                 out_lens.pop_back();
             }
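The branch above follows the numpy.matmul rule for a 1-D B: a vector of length k acts as a column vector against an A of shape {..., m, k}, and the appended dimension is dropped again, so the output is {..., m}. A small sketch of just that shape rule (mat_vec_shape is a hypothetical helper, not MIGraphX API):

#include <cassert>
#include <cstddef>
#include <vector>

// Shape rule for matrix * vector under numpy.matmul semantics:
// {..., m, k} x {k} -> {..., m}.
std::vector<std::size_t> mat_vec_shape(std::vector<std::size_t> a_lens,
                                       const std::vector<std::size_t>& b_lens)
{
    assert(b_lens.size() == 1);
    assert(a_lens.back() == b_lens.front()); // inner dimensions must agree
    a_lens.pop_back();                       // drop the contracted dimension
    return a_lens;
}

int main()
{
    auto out = mat_vec_shape({2, 3, 4}, {4}); // -> {2, 3}
    assert(out == std::vector<std::size_t>({2, 3}));
    return 0;
}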
@@ -939,11 +940,11 @@ struct dot
             {
                 std::size_t dim_0 = a_lens.size() - 1;
                 std::size_t dim_1 = b_lens.size() - 2;
-                if (a_lens[dim_0] != b_lens[dim_1])
+                if(a_lens[dim_0] != b_lens[dim_1])
                 {
-                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a_lens) +
-                                   "}, cannot multiply matrix B: {" + to_string_range(b_lens)
-                                   + "}");
+                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
+                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
+                                   to_string_range(b_lens) + "}");
                 }
                 a_lens.pop_back();
@@ -959,17 +960,17 @@ struct dot
                 out_lens.push_back(out_n);
             }
         }
         // c is broadcast
-        if (inputs.size() == 3)
+        if(inputs.size() == 3)
             // according to the specification of the numpy.matmul()
             // inputs with the shape dims more than 2 are acceptable
             // as long as dim values are the same in the two inputs
             if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
             {
                 MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
             }
         if(inputs.size() == 3)
         {
...
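Note that for stacked (batched) inputs the std::equal check above requires the leading batch dimensions of A and B to be identical, not merely broadcastable: every dimension before the trailing two must match. A minimal sketch of the same test (same_batch_dims is illustrative only; the size check is added here to keep the iterators in range):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// True when every dimension before the trailing two matches in a and b.
bool same_batch_dims(const std::vector<std::size_t>& a,
                     const std::vector<std::size_t>& b)
{
    return a.size() == b.size() && std::equal(a.rbegin() + 2, a.rend(), b.rbegin() + 2);
}

int main()
{
    assert(same_batch_dims({2, 3, 4, 5}, {2, 3, 5, 6}));  // same 2x3 stack of matrices
    assert(!same_batch_dims({2, 3, 4, 5}, {2, 4, 5, 6})); // stacks differ
    return 0;
}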
@@ -127,26 +127,25 @@ argument miopen_gemm::compute(context& ctx,
         auto alpha_r = to_rocblas_type(as(op.alpha));
         auto beta_r = to_rocblas_type(as(op.beta));
         auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
-        generic_rocblas_batched_gemm(
-            as,
-            ctx.get_stream().get_rocblas(),
-            transb ? rocblas_operation_transpose : rocblas_operation_none,
-            transa ? rocblas_operation_transpose : rocblas_operation_none,
-            n,
-            m,
-            k,
-            &alpha_r,
-            to_pointer(args[1]),
-            ldb,
-            k * n,
-            to_pointer(args[0]),
-            lda,
-            m * k,
-            &beta_r,
-            to_pointer(args[2]),
-            ldc,
-            m * n,
-            batch_num);
+        generic_rocblas_batched_gemm(as,
+                                     ctx.get_stream().get_rocblas(),
+                                     transb ? rocblas_operation_transpose : rocblas_operation_none,
+                                     transa ? rocblas_operation_transpose : rocblas_operation_none,
+                                     n,
+                                     m,
+                                     k,
+                                     &alpha_r,
+                                     to_pointer(args[1]),
+                                     ldb,
+                                     k * n,
+                                     to_pointer(args[0]),
+                                     lda,
+                                     m * k,
+                                     &beta_r,
+                                     to_pointer(args[2]),
+                                     ldc,
+                                     m * n,
+                                     batch_num);
     });
     return (is_3inputs ? args[3] : args[2]);
...
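One detail worth noting in this call: args[1] (B) is passed before args[0] (A), and n and m are swapped. rocBLAS, like BLAS, assumes column-major storage, and a row-major buffer reinterpreted as column-major is the transpose of the matrix, so a row-major C = A * B is obtained by computing C^T = B^T * A^T. A minimal CPU sketch of that identity, independent of rocBLAS (gemm_colmajor and the test shapes are illustrative only):

#include <cassert>
#include <cstddef>
#include <vector>

// Reference column-major GEMM: C(m x n) = A(m x k) * B(k x n).
void gemm_colmajor(std::size_t m, std::size_t n, std::size_t k,
                   const std::vector<float>& a, const std::vector<float>& b,
                   std::vector<float>& c)
{
    for(std::size_t j = 0; j < n; ++j)
        for(std::size_t i = 0; i < m; ++i)
        {
            float acc = 0;
            for(std::size_t p = 0; p < k; ++p)
                acc += a[p * m + i] * b[j * k + p];
            c[j * m + i] = acc;
        }
}

int main()
{
    // Row-major A (2x3) and B (3x4); we want row-major C = A * B (2x4).
    std::vector<float> a = {1, 2, 3, 4, 5, 6};
    std::vector<float> b = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    std::vector<float> c(8);
    // Reinterpreted as column-major, the buffers hold A^T and B^T, so pass
    // B first and swap the output dims: C^T (4x2) = B^T (4x3) * A^T (3x2).
    gemm_colmajor(4, 2, 3, b, a, c);
    // c now holds row-major A * B = {38, 44, 50, 56, 83, 98, 113, 128}.
    assert(c[0] == 38 && c[1] == 44 && c[7] == 128);
    return 0;
}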