Commit b2106be7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 007ea283
...@@ -838,9 +838,9 @@ struct dot ...@@ -838,9 +838,9 @@ struct dot
if(b.empty()) if(b.empty())
return a; return a;
if (a.empty()) if(a.empty())
{ {
if (is_mutli_broadcast) if(is_mutli_broadcast)
{ {
return b; return b;
} }
...@@ -853,14 +853,13 @@ struct dot ...@@ -853,14 +853,13 @@ struct dot
auto a_size = a.size(); auto a_size = a.size();
auto b_size = b.size(); auto b_size = b.size();
if (is_mutli_broadcast && b_size > a_size) if(is_mutli_broadcast && b_size > a_size)
{ {
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) + MIGRAPHX_THROW("DOT: C {" + to_string_range(b) + "} is not broadcastable to A * b {" +
"} is not broadcastable to A * b {" + to_string_range(a) + to_string_range(a) + "}");
"}");
} }
auto n_dim = std::min(a_size, b_size); auto n_dim = std::min(a_size, b_size);
std::vector<std::size_t> out_lens(std::max(a_size, b_size)); std::vector<std::size_t> out_lens(std::max(a_size, b_size));
for(std::size_t i = 0; i < n_dim; ++i) for(std::size_t i = 0; i < n_dim; ++i)
{ {
...@@ -872,25 +871,25 @@ struct dot ...@@ -872,25 +871,25 @@ struct dot
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else else
{ {
if(a[a_size - 1 - i] == 1 && is_mutli_broadcast) if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
{ {
out_lens[i] = b[b_size - 1 - i]; out_lens[i] = b[b_size - 1 - i];
} }
else else
{ {
if (is_mutli_broadcast) if(is_mutli_broadcast)
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) + MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
"}, and matrix B: {" + to_string_range(b) + to_string_range(a) + "}, and matrix B: {" +
"} are not broadcastable"); to_string_range(b) + "} are not broadcastable");
} }
else else
{ {
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) + MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
"} is not broadcastable to A * b {" + to_string_range(a) + "} is not broadcastable to A * b {" + to_string_range(a) +
"}"); "}");
} }
} }
} }
...@@ -924,18 +923,18 @@ struct dot ...@@ -924,18 +923,18 @@ struct dot
MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead"); MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead");
} }
auto a_lens = a.lens(); auto a_lens = a.lens();
auto b_lens = b.lens(); auto b_lens = b.lens();
bool is_a_appended = false; bool is_a_appended = false;
bool is_b_appended = false; bool is_b_appended = false;
if (a_lens.size() == 1) if(a_lens.size() == 1)
{ {
a_lens.insert(a_lens.begin(), 1); a_lens.insert(a_lens.begin(), 1);
is_a_appended = true; is_a_appended = true;
} }
if (b_lens.size() == 1) if(b_lens.size() == 1)
{ {
b_lens.push_back(1); b_lens.push_back(1);
is_b_appended = true; is_b_appended = true;
...@@ -943,11 +942,10 @@ struct dot ...@@ -943,11 +942,10 @@ struct dot
std::size_t dim_0 = a_lens.size() - 1; std::size_t dim_0 = a_lens.size() - 1;
std::size_t dim_1 = b_lens.size() - 2; std::size_t dim_1 = b_lens.size() - 2;
if (a_lens[dim_0] != b_lens[dim_1]) if(a_lens[dim_0] != b_lens[dim_1])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" + to_string_range(a.lens()) +
to_string_range(a.lens()) + "}, cannot multiply operand B: {" + "}, cannot multiply operand B: {" + to_string_range(b.lens()) + "}");
to_string_range(b.lens()) + "}");
} }
// remove the matrix dims, do multi_broadcast of the shape of the batch // remove the matrix dims, do multi_broadcast of the shape of the batch
...@@ -964,34 +962,33 @@ struct dot ...@@ -964,34 +962,33 @@ struct dot
out_lens.push_back(out_n); out_lens.push_back(out_n);
// remove the prepended 1, if a is a vector // remove the prepended 1, if a is a vector
if (is_a_appended) if(is_a_appended)
{ {
out_lens.erase(out_lens.begin() + out_lens.size() - 2); out_lens.erase(out_lens.begin() + out_lens.size() - 2);
} }
// remove the appended 1, if b is a vector // remove the appended 1, if b is a vector
if (is_b_appended) if(is_b_appended)
{ {
out_lens.pop_back(); out_lens.pop_back();
} }
// c is unibroadcastable to A * B // c is unibroadcastable to A * B
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
// same type as A and B // same type as A and B
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type(); check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
if (out_lens.empty() && (!inputs[2].scalar())) if(out_lens.empty() && (!inputs[2].scalar()))
{ {
MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)"); MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)");
} }
//check c is broadcastable to A * B // check c is broadcastable to A * B
auto c_lens = inputs[2].lens(); auto c_lens = inputs[2].lens();
shape_broadcast(out_lens, c_lens, false); shape_broadcast(out_lens, c_lens, false);
} }
if (out_lens.empty()) if(out_lens.empty())
{ {
return {t}; return {t};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment