Commit 007ea283 authored by Shucai Xiao
Browse files

Back up the second implementation of compute_shape for the dot operator

parent 359ec2f8
...@@ -830,16 +830,36 @@ struct dot ...@@ -830,16 +830,36 @@ struct dot
return pack(f(self.alpha, "alpha"), f(self.beta, "beta")); return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
} }
// if not a multi_broadcast, b should be broadcastable to a
std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a, std::vector<std::size_t> shape_broadcast(std::vector<std::size_t>& a,
std::vector<std::size_t>& b) const std::vector<std::size_t>& b,
bool is_mutli_broadcast = true) const
{ {
if(a.empty()) if(b.empty())
return b;
else if(b.empty())
return a; return a;
if (a.empty())
{
if (is_mutli_broadcast)
{
return b;
}
else
{
MIGRAPHX_THROW("DOT: C is not broadcastable to A * B (scalar)");
}
}
auto a_size = a.size(); auto a_size = a.size();
auto b_size = b.size(); auto b_size = b.size();
if (is_mutli_broadcast && b_size > a_size)
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
"} is not broadcastable to A * b {" + to_string_range(a) +
"}");
}
auto n_dim = std::min(a_size, b_size); auto n_dim = std::min(a_size, b_size);
std::vector<std::size_t> out_lens(std::max(a_size, b_size)); std::vector<std::size_t> out_lens(std::max(a_size, b_size));
for(std::size_t i = 0; i < n_dim; ++i) for(std::size_t i = 0; i < n_dim; ++i)
...@@ -848,20 +868,32 @@ struct dot ...@@ -848,20 +868,32 @@ struct dot
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else if(a[a_size - 1 - i] == 1)
{
out_lens[i] = b[b_size - 1 - i];
}
else if(b[b_size - 1 - i] == 1) else if(b[b_size - 1 - i] == 1)
{ {
out_lens[i] = a[a_size - 1 - i]; out_lens[i] = a[a_size - 1 - i];
} }
else else
{
if(a[a_size - 1 - i] == 1 && is_mutli_broadcast)
{
out_lens[i] = b[b_size - 1 - i];
}
else
{
if (is_mutli_broadcast)
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) + MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
"}, and matrix B: {" + to_string_range(b) + "}, and matrix B: {" + to_string_range(b) +
"} are not broadcastable"); "} are not broadcastable");
} }
else
{
MIGRAPHX_THROW("DOT: C {" + to_string_range(b) +
"} is not broadcastable to A * b {" + to_string_range(a) +
"}");
}
}
}
} }
if(a_size > n_dim) if(a_size > n_dim)
...@@ -894,59 +926,31 @@ struct dot ...@@ -894,59 +926,31 @@ struct dot
auto a_lens = a.lens(); auto a_lens = a.lens();
auto b_lens = b.lens(); auto b_lens = b.lens();
std::vector<std::size_t> out_lens; bool is_a_appended = false;
if(a_lens.size() == 1) bool is_b_appended = false;
{
// inner product, output is a scalar, following numpy.matmul() if (a_lens.size() == 1)
if(b_lens.size() == 1)
{
if(a_lens.front() != b_lens.front())
{
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
to_string_range(a_lens) + "}, cannot multiply vector B: {" +
to_string_range(b_lens) + "}");
}
}
else
{
std::size_t dim_0 = b_lens.size() - 2;
if(a_lens.front() != b_lens[dim_0])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" + a_lens.insert(a_lens.begin(), 1);
to_string_range(a_lens) + "}, cannot multiply matrix B: {" + is_a_appended = true;
to_string_range(b_lens) + "}");
} }
out_lens = b_lens; if (b_lens.size() == 1)
out_lens.erase(out_lens.begin() + dim_0);
}
}
else
{ {
std::size_t dim_0 = a_lens.size() - 1; b_lens.push_back(1);
if(b_lens.size() == 1) is_b_appended = true;
{
if(a_lens.back() != b_lens.back())
{
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
to_string_range(a_lens) + "}, cannot multiply vector B: {" +
to_string_range(b_lens) + "}");
} }
out_lens = a_lens;
out_lens.pop_back();
}
else
{
std::size_t dim_0 = a_lens.size() - 1; std::size_t dim_0 = a_lens.size() - 1;
std::size_t dim_1 = b_lens.size() - 2; std::size_t dim_1 = b_lens.size() - 2;
if(a_lens[dim_0] != b_lens[dim_1]) if (a_lens[dim_0] != b_lens[dim_1])
{ {
MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + MIGRAPHX_THROW("DOT : dimension mismatch, operand A: {" +
to_string_range(a_lens) + "}, cannot multiply matrix B: {" + to_string_range(a.lens()) + "}, cannot multiply operand B: {" +
to_string_range(b_lens) + "}"); to_string_range(b.lens()) + "}");
} }
// remove the matrix dims, do multi_broadcast of the shape of the batch
a_lens.pop_back(); a_lens.pop_back();
std::size_t out_m = a_lens.back(); std::size_t out_m = a_lens.back();
a_lens.pop_back(); a_lens.pop_back();
...@@ -955,60 +959,47 @@ struct dot ...@@ -955,60 +959,47 @@ struct dot
b_lens.pop_back(); b_lens.pop_back();
b_lens.pop_back(); b_lens.pop_back();
out_lens = shape_broadcast(a_lens, b_lens); auto out_lens = shape_broadcast(a_lens, b_lens);
out_lens.push_back(out_m); out_lens.push_back(out_m);
out_lens.push_back(out_n); out_lens.push_back(out_n);
}
}
// c is broadcast // remove the prepended 1, if a is a vector
if(inputs.size() == 3) if (is_a_appended)
{
out_lens.erase(out_lens.begin() + out_lens.size() - 2);
}
// according to the specification of the numpy.matmul() // remove the appended 1, if b is a vector
// inputs with the shape dims more than 2 are acceptable if (is_b_appended)
// as long as dim values are the same in the two inputs
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
{ {
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and B"); out_lens.pop_back();
} }
// c is unibroadcastable to A * B
if(inputs.size() == 3) if(inputs.size() == 3)
{ {
// same type as A and B
check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type(); check_shapes{{inputs[0], inputs[2]}, *this}.has(2).same_type();
const shape& c = inputs.at(2); if (out_lens.empty() && (!inputs[2].scalar()))
if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), c.lens().rbegin() + 2))
{ {
MIGRAPHX_THROW("DOT: number of matrices in stack are different in A and C"); MIGRAPHX_THROW("DOT: C is not broadcastable to A*B (scalar)");
}
} }
std::size_t dim_0 = a.lens().size() - 2; //check c is broadcastable to A * B
std::size_t dim_1 = a.lens().size() - 1; auto c_lens = inputs[2].lens();
if(a.lens()[dim_1] != b.lens()[dim_0]) shape_broadcast(out_lens, c_lens, false);
MIGRAPHX_THROW("DOT : inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
if(inputs.size() == 3)
{
const shape& c = inputs.at(2);
if(a.lens()[dim_0] != c.lens()[dim_0])
{
MIGRAPHX_THROW("DOT : matrix size does not match: A: {" +
to_string_range(a.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
} }
if(b.lens()[dim_1] != c.lens()[dim_1]) if (out_lens.empty())
{ {
MIGRAPHX_THROW("DOT : matrix size does not match: B: {" + return {t};
to_string_range(b.lens()) + "}, C: {" + to_string_range(c.lens()) +
"}");
}
} }
else
auto out_lens = a.lens(); {
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens}; return {t, out_lens};
} }
}
}; };
struct unary struct unary
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment