Commit 8a45e79f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 77212cc1
......@@ -839,7 +839,8 @@ struct dot
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("DOT: dim values mismatch");
}
......@@ -854,11 +855,10 @@ struct dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if (inputs.size() == 3 && out_lens != inputs.at(2).lens())
if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + to_string_range(c_lens) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) +
"}");
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
}
return {t, out_lens};
......
......@@ -101,7 +101,8 @@ void migemm_impl(
{
auto lens = amat.get_shape().lens();
bool batch_mul =
std::accumulate(lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
if(batch_mul)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment