Commit 359ec2f8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 9d52515a
......@@ -830,45 +830,46 @@ struct dot
return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
}
std::vector<std::size_t> shape_broadcast(std::vector<std::size_t> &a, std::vector<std::size_t> &b) const
std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
std::vector<std::size_t>& b) const
{
if (a.empty())
if(a.empty())
return b;
else if (b.empty())
else if(b.empty())
return a;
auto a_size = a.size();
auto b_size = b.size();
auto n_dim = std::min(a_size, b_size);
std::vector<std::size_t> out_lens(std::max(a_size, b_size));
for (std::size_t i = 0; i < n_dim; ++i)
for(std::size_t i = 0; i < n_dim; ++i)
{
if (a[a_size - 1 - i] == b[b_size - 1 - i])
if(a[a_size - 1 - i] == b[b_size - 1 - i])
{
out_lens[i] = a[a_size - 1 - i];
}
else if (a[a_size - 1 - i] == 1)
else if(a[a_size - 1 - i] == 1)
{
out_lens[i] = b[b_size - 1 - i];
}
else if (b[b_size - 1 - i] == 1)
else if(b[b_size - 1 - i] == 1)
{
out_lens[i] = a[a_size - 1 - i];
}
else
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
"}, and matrix B: {" + to_string_range(b)
+ "} are not broadcastable");
"}, and matrix B: {" + to_string_range(b) +
"} are not broadcastable");
}
}
if (a_size > n_dim)
if(a_size > n_dim)
{
std::copy(a.rbegin() + n_dim, a.rend(), out_lens.begin() + n_dim);
}
if (b_size > n_dim)
if(b_size > n_dim)
{
std::copy(b.rbegin() + n_dim, b.rend(), out_lens.rbegin() + n_dim);
}
......@@ -886,7 +887,7 @@ struct dot
const shape& b = inputs.at(1);
auto t = a.type();
if (a.scalar() || b.scalar())
if(a.scalar() || b.scalar())
{
MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead");
}
......@@ -894,26 +895,26 @@ struct dot
auto a_lens = a.lens();
auto b_lens = b.lens();
std::vector<std::size_t> out_lens;
if (a_lens.size() == 1)
if(a_lens.size() == 1)
{
// inner product, output is a scalar, following numpy.matmul()
if (b_lens.size() == 1)
if(b_lens.size() == 1)
{
if (a_lens.front() != b_lens.front())
if(a_lens.front() != b_lens.front())
{
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + to_string_range(a_lens) +
"}, cannot multiply vector B: {" + to_string_range(b_lens)
+ "}");
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
to_string_range(a_lens) + "}, cannot multiply vector B: {" +
to_string_range(b_lens) + "}");
}
}
else
{
std::size_t dim_0 = b_lens.size() - 2;
if (a_lens.front() != b_lens[dim_0])
if(a_lens.front() != b_lens[dim_0])
{
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + to_string_range(a_lens) +
"}, cannot multiply matrix B: {" + to_string_range(b_lens)
+ "}");
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
to_string_range(b_lens) + "}");
}
out_lens = b_lens;
......@@ -923,13 +924,13 @@ struct dot
else
{
std::size_t dim_0 = a_lens.size() - 1;
if (b_lens.size() == 1)
if(b_lens.size() == 1)
{
if (a_lens.back() != b_lens.back())
if(a_lens.back() != b_lens.back())
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a_lens) +
"}, cannot multiply vector B: {" + to_string_range(b_lens)
+ "}");
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a_lens) + "}, cannot multiply vector B: {" +
to_string_range(b_lens) + "}");
}
out_lens = a_lens;
......@@ -939,11 +940,11 @@ struct dot
{
std::size_t dim_0 = a_lens.size() - 1;
std::size_t dim_1 = b_lens.size() - 2;
if (a_lens[dim_0] != b_lens[dim_1])
if(a_lens[dim_0] != b_lens[dim_1])
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a_lens) +
"}, cannot multiply matrix B: {" + to_string_range(b_lens)
+ "}");
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
to_string_range(b_lens) + "}");
}
a_lens.pop_back();
......@@ -961,7 +962,7 @@ struct dot
}
// c is broadcast
if (inputs.size() == 3)
if(inputs.size() == 3)
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
......
......@@ -127,8 +127,7 @@ argument miopen_gemm::compute(context& ctx,
auto alpha_r = to_rocblas_type(as(op.alpha));
auto beta_r = to_rocblas_type(as(op.beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_batched_gemm(
as,
generic_rocblas_batched_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment