"vscode:/vscode.git/clone" did not exist on "eb06cc6bd5ea01a3bc3ef535bf463b9289f84c94"
Commit 359ec2f8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 9d52515a
...@@ -830,45 +830,46 @@ struct dot ...@@ -830,45 +830,46 @@ struct dot
return pack(f(self.alpha, "alpha"), f(self.beta, "beta")); return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
} }
std::vector<std::size_t> shape_broadcast(std::vector<std::size_t> &a, std::vector<std::size_t> &b) const std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
std::vector<std::size_t>& b) const
{ {
if (a.empty()) if(a.empty())
return b; return b;
else if (b.empty()) else if(b.empty())
return a; return a;
auto a_size = a.size(); auto a_size = a.size();
auto b_size = b.size(); auto b_size = b.size();
auto n_dim = std::min(a_size, b_size); auto n_dim = std::min(a_size, b_size);
std::vector<std::size_t> out_lens(std::max(a_size, b_size)); std::vector<std::size_t> out_lens(std::max(a_size, b_size));
for (std::size_t i = 0; i < n_dim; ++i) for(std::size_t i = 0; i < n_dim; ++i)
{ {
if (a[a_size - 1 - i] == b[b_size - 1 - i]) if(a[a_size - 1 - i] == b[b_size - 1 - i])
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else if (a[a_size - 1 - i] == 1) else if(a[a_size - 1 - i] == 1)
{ {
out_lens[i] = b[b_size - 1 - i]; out_lens[i] = b[b_size - 1 - i];
} }
else if (b[b_size - 1 - i] == 1) else if(b[b_size - 1 - i] == 1)
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else else
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) + MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
"}, and matrix B: {" + to_string_range(b) "}, and matrix B: {" + to_string_range(b) +
+ "} are not broadcastable"); "} are not broadcastable");
} }
} }
if (a_size > n_dim) if(a_size > n_dim)
{ {
std::copy(a.rbegin() + n_dim, a.rend(), out_lens.begin() + n_dim); std::copy(a.rbegin() + n_dim, a.rend(), out_lens.begin() + n_dim);
} }
if (b_size > n_dim) if(b_size > n_dim)
{ {
std::copy(b.rbegin() + n_dim, b.rend(), out_lens.rbegin() + n_dim); std::copy(b.rbegin() + n_dim, b.rend(), out_lens.rbegin() + n_dim);
} }
...@@ -886,7 +887,7 @@ struct dot ...@@ -886,7 +887,7 @@ struct dot
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if (a.scalar() || b.scalar()) if(a.scalar() || b.scalar())
{ {
MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead"); MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead");
} }
...@@ -894,26 +895,26 @@ struct dot ...@@ -894,26 +895,26 @@ struct dot
auto a_lens = a.lens(); auto a_lens = a.lens();
auto b_lens = b.lens(); auto b_lens = b.lens();
std::vector<std::size_t> out_lens; std::vector<std::size_t> out_lens;
if (a_lens.size() == 1) if(a_lens.size() == 1)
{ {
// inner product, output is a scalar, following numpy.matmul() // inner product, output is a scalar, following numpy.matmul()
if (b_lens.size() == 1) if(b_lens.size() == 1)
{ {
if (a_lens.front() != b_lens.front()) if(a_lens.front() != b_lens.front())
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
"}, cannot multiply vector B: {" + to_string_range(b_lens) to_string_range(a_lens) + "}, cannot multiply vector B: {" +
+ "}"); to_string_range(b_lens) + "}");
} }
} }
else else
{ {
std::size_t dim_0 = b_lens.size() - 2; std::size_t dim_0 = b_lens.size() - 2;
if (a_lens.front() != b_lens[dim_0]) if(a_lens.front() != b_lens[dim_0])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
"}, cannot multiply matrix B: {" + to_string_range(b_lens) to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
+ "}"); to_string_range(b_lens) + "}");
} }
out_lens = b_lens; out_lens = b_lens;
...@@ -923,13 +924,13 @@ struct dot ...@@ -923,13 +924,13 @@ struct dot
else else
{ {
std::size_t dim_0 = a_lens.size() - 1; std::size_t dim_0 = a_lens.size() - 1;
if (b_lens.size() == 1) if(b_lens.size() == 1)
{ {
if (a_lens.back() != b_lens.back()) if(a_lens.back() != b_lens.back())
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
"}, cannot multiply vector B: {" + to_string_range(b_lens) to_string_range(a_lens) + "}, cannot multiply vector B: {" +
+ "}"); to_string_range(b_lens) + "}");
} }
out_lens = a_lens; out_lens = a_lens;
...@@ -939,11 +940,11 @@ struct dot ...@@ -939,11 +940,11 @@ struct dot
{ {
std::size_t dim_0 = a_lens.size() - 1; std::size_t dim_0 = a_lens.size() - 1;
std::size_t dim_1 = b_lens.size() - 2; std::size_t dim_1 = b_lens.size() - 2;
if (a_lens[dim_0] != b_lens[dim_1]) if(a_lens[dim_0] != b_lens[dim_1])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a_lens) + MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
"}, cannot multiply matrix B: {" + to_string_range(b_lens) to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
+ "}"); to_string_range(b_lens) + "}");
} }
a_lens.pop_back(); a_lens.pop_back();
...@@ -961,7 +962,7 @@ struct dot ...@@ -961,7 +962,7 @@ struct dot
} }
// c is broadcast // c is broadcast
if (inputs.size() == 3) if(inputs.size() == 3)
// according to the specification of the numpy.matmul() // according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable // inputs with the shape dims more than 2 are acceptable
......
...@@ -127,8 +127,7 @@ argument miopen_gemm::compute(context& ctx, ...@@ -127,8 +127,7 @@ argument miopen_gemm::compute(context& ctx,
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta)); auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm(as,
as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment