Commit 007ea283 authored by Shucai Xiao
Browse files

Back up the second implementation of compute_shape for the dot operator

parent 359ec2f8
...@@ -830,16 +830,36 @@ struct dot ...@@ -830,16 +830,36 @@ struct dot
return pack(f(self.alpha, "alpha"), f(self.beta, "beta")); return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
} }
// if not a multi_broadcast, b should be broadcastable to a
std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a, std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
std::vector<std::size_t>& b) const std::vector<std::size_t>& b,
bool is_mutli_broadcast = true) const
{ {
if(a.empty()) if(b.empty())
return b;
else if(b.empty())
return a; return a;
if (a.empty())
{
if (is_mutli_broadcast)
{
return b;
}
else
{
MIGRAPHX_THROW("DOT: C is not broadcastable to A * B (scalar)");
}
}
auto a_size = a.size(); auto a_size = a.size();
auto b_size = b.size(); auto b_size = b.size();
if (is_mutli_broadcast && b_size > a_size)
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
"} is not broadcastable to A * b {" + to_string_range(a) +
"}");
}
auto n_dim = std::min(a_size, b_size); auto n_dim = std::min(a_size, b_size);
std::vector<std::size_t> out_lens(std::max(a_size, b_size)); std::vector<std::size_t> out_lens(std::max(a_size, b_size));
for(std::size_t i = 0; i < n_dim; ++i) for(std::size_t i = 0; i < n_dim; ++i)
...@@ -848,19 +868,31 @@ struct dot ...@@ -848,19 +868,31 @@ struct dot
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else if(a[a_size - 1 - i] == 1)
{
out_lens[i] = b[b_size - 1 - i];
}
else if(b[b_size - 1 - i] == 1) else if(b[b_size - 1 - i] == 1)
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else else
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) + if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
"}, and matrix B: {" + to_string_range(b) + {
"} are not broadcastable"); out_lens[i] = b[b_size - 1 - i];
}
else
{
if (is_mutli_broadcast)
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
"}, and matrix B: {" + to_string_range(b) +
"} are not broadcastable");
}
else
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
"} is not broadcastable to A * b {" + to_string_range(a) +
"}");
}
}
} }
} }
...@@ -894,120 +926,79 @@ struct dot ...@@ -894,120 +926,79 @@ struct dot
auto a_lens = a.lens(); auto a_lens = a.lens();
auto b_lens = b.lens(); auto b_lens = b.lens();
std::vector<std::size_t> out_lens; bool is_a_appended = false;
if(a_lens.size() == 1) bool is_b_appended = false;
if (a_lens.size() == 1)
{ {
// inner product, output is a scalar, following numpy.matmul() a_lens.insert(a_lens.begin(), 1);
if(b_lens.size() == 1) is_a_appended = true;
{ }
if(a_lens.front() != b_lens.front())
{
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
to_string_range(a_lens) + "}, cannot multiply vector B: {" +
to_string_range(b_lens) + "}");
}
}
else
{
std::size_t dim_0 = b_lens.size() - 2;
if(a_lens.front() != b_lens[dim_0])
{
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
to_string_range(b_lens) + "}");
}
out_lens = b_lens; if (b_lens.size() == 1)
out_lens.erase(out_lens.begin() + dim_0); {
} b_lens.push_back(1);
is_b_appended = true;
} }
else
std::size_t dim_0 = a_lens.size() - 1;
std::size_t dim_1 = b_lens.size() - 2;
if (a_lens[dim_0] != b_lens[dim_1])
{ {
std::size_t dim_0 = a_lens.size() - 1; MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" +
if(b_lens.size() == 1) to_string_range(a.lens()) + "}, cannot multiply operand B: {" +
{ to_string_range(b.lens()) + "}");
if(a_lens.back() != b_lens.back()) }
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a_lens) + "}, cannot multiply vector B: {" +
to_string_range(b_lens) + "}");
}
out_lens = a_lens; // remove the matrix dims, do multi_broadcast of the shape of the batch
out_lens.pop_back(); a_lens.pop_back();
} std::size_t out_m = a_lens.back();
else a_lens.pop_back();
{
std::size_t dim_0 = a_lens.size() - 1;
std::size_t dim_1 = b_lens.size() - 2;
if(a_lens[dim_0] != b_lens[dim_1])
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
to_string_range(b_lens) + "}");
}
a_lens.pop_back(); std::size_t out_n = b_lens.back();
std::size_t out_m = a_lens.back(); b_lens.pop_back();
a_lens.pop_back(); b_lens.pop_back();
std::size_t out_n = b_lens.back(); auto out_lens = shape_broadcast(a_lens, b_lens);
b_lens.pop_back(); out_lens.push_back(out_m);
b_lens.pop_back(); out_lens.push_back(out_n);
out_lens = shape_broadcast(a_lens, b_lens); // remove the prepended 1, if a is a vector
out_lens.push_back(out_m); if (is_a_appended)
out_lens.push_back(out_n); {
} out_lens.erase(out_lens.begin() + out_lens.size() - 2);
} }
// c is broadcast // remove the appended 1, if b is a vector
if(inputs.size() == 3) if (is_b_appended)
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
// as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
{
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B");
}
if(inputs.size() == 3)
{ {
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type(); out_lens.pop_back();
const shape& c = inputs.at(2);
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), c.lens().rbegin() + 2))
{
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and C");
}
} }
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1; // c is unibroadcastable to A * B
if(a.lens()[dim_1] != b.lens()[dim_0])
MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
const shape& c = inputs.at(2); // same type as A and B
if(a.lens()[dim_0] != c.lens()[dim_0]) check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
if (out_lens.empty() && (!inputs[2].scalar()))
{ {
MIGRAPHX_THROW("DOT : matrix size does not match: A: {" + MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)");
to_string_range(a.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
} }
if(b.lens()[dim_1] != c.lens()[dim_1]) //check c is broadcastable to A * B
{ auto c_lens = inputs[2].lens();
MIGRAPHX_THROW("DOT : matrix size does not match: B: {" + shape_broadcast(out_lens, c_lens, false);
to_string_range(b.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
}
} }
auto out_lens = a.lens(); if (out_lens.empty())
out_lens[dim_1] = b.lens()[dim_1]; {
return {t, out_lens}; return {t};
}
else
{
return {t, out_lens};
}
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment