"...resnet50_tensorflow.git" did not exist on "92bad0d216cc46140c52da8d75d4685eb364736a"
Commit 8a45e79f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 77212cc1
...@@ -839,7 +839,8 @@ struct dot ...@@ -839,7 +839,8 @@ struct dot
// according to the specification of the numpy.matmul() // according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable // inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs // as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend())) if(!std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{ {
MIGRAPHX_THROW("DOT: dim values mismatch"); MIGRAPHX_THROW("DOT: dim values mismatch");
} }
...@@ -854,13 +855,12 @@ struct dot ...@@ -854,13 +855,12 @@ struct dot
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if (inputs.size() == 3 && out_lens != inputs.at(2).lens()) if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
{ {
MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + to_string_range(c_lens) + MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" + to_string_range(c_lens) +
"}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
"}");
} }
return {t, out_lens}; return {t, out_lens};
} }
}; };
......
...@@ -61,7 +61,7 @@ void migemm_impl(tensor_view<T> cmat, ...@@ -61,7 +61,7 @@ void migemm_impl(tensor_view<T> cmat,
if(alpha != 0.0) if(alpha != 0.0)
{ {
c = c + alpha * a * b; c = c + alpha * a * b;
} }
}); });
}); });
} }
...@@ -101,7 +101,8 @@ void migemm_impl( ...@@ -101,7 +101,8 @@ void migemm_impl(
{ {
auto lens = amat.get_shape().lens(); auto lens = amat.get_shape().lens();
bool batch_mul = bool batch_mul =
std::accumulate(lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1; std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()) == 1;
if(batch_mul) if(batch_mul)
{ {
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{}); migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
...@@ -369,14 +369,14 @@ struct cpu_gemm ...@@ -369,14 +369,14 @@ struct cpu_gemm
{ {
op::dot op; op::dot op;
std::string name() const { return "cpu::dot"; } std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
auto c_shape = inputs.at(2); auto c_shape = inputs.at(2);
check_shapes{{c_shape}}.not_broadcasted(); check_shapes{{c_shape}}.not_broadcasted();
} }
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment