Commit 9d52515a authored by Shucai Xiao

code backup for an initial implementation of the function compute_shape.

parent c45be227
@@ -810,7 +810,8 @@ struct gather
// The dot operation is a combination of the onnx GEMM and MatMul operators.
// For GEMM, it supports the C matrix in the formula alpha * AB + beta * C,
// in which C is broadcastable to the shape of AB. For the transpose of A
// and B, we add a transpose operator beforehand if the onnx gemm operator
// indicates a transpose.
// For MatMul, it has the same definition as numpy.matmul, which means
// A and B can be 1 to N dims. For a 1-dim A, it is a vector * matrix;
// for a 1-dim B, it is a matrix * vector. Note that there is no support
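To see what the transpose handling above means in practice: rather than giving dot transA/transB flags, the parser rewrites the graph before the dot. A hypothetical sketch of that lowering for an onnx Gemm with transA = 1 (the prog.add_instruction and op::transpose spellings mirror the surrounding codebase but are illustrative here, not part of this commit):

    // A is 2-dim for Gemm, so transposing it swaps its two axes
    auto at = prog.add_instruction(op::transpose{{1, 0}}, a);
    // dot then multiplies the already-transposed A with B and adds beta * C
    auto r  = prog.add_instruction(op::dot{alpha, beta}, at, b, c);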
@@ -829,6 +830,54 @@ struct dot
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
    }
    // Compute the numpy-style broadcast of two dim vectors: dims are aligned
    // from the right, and a dim of 1 broadcasts to the other input's dim.
    std::vector<std::size_t> shape_broadcast(const std::vector<std::size_t>& a,
                                             const std::vector<std::size_t>& b) const
    {
        if (a.empty())
            return b;
        if (b.empty())
            return a;
        auto a_size = a.size();
        auto b_size = b.size();
        auto n_dim  = std::min(a_size, b_size);
        // out_lens is filled from the trailing dims and reversed at the end
        std::vector<std::size_t> out_lens(std::max(a_size, b_size));
        for (std::size_t i = 0; i < n_dim; ++i)
        {
            if (a[a_size - 1 - i] == b[b_size - 1 - i])
            {
                out_lens[i] = a[a_size - 1 - i];
            }
            else if (a[a_size - 1 - i] == 1)
            {
                out_lens[i] = b[b_size - 1 - i];
            }
            else if (b[b_size - 1 - i] == 1)
            {
                out_lens[i] = a[a_size - 1 - i];
            }
            else
            {
                MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" + to_string_range(a) +
                               "}, and matrix B: {" + to_string_range(b) +
                               "} are not broadcastable");
            }
        }
        // copy the leading dims of the longer input, still in reverse order
        if (a_size > n_dim)
        {
            std::copy(a.rbegin() + n_dim, a.rend(), out_lens.begin() + n_dim);
        }
        if (b_size > n_dim)
        {
            std::copy(b.rbegin() + n_dim, b.rend(), out_lens.begin() + n_dim);
        }
        std::reverse(out_lens.begin(), out_lens.end());
        return out_lens;
    }
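Since shape_broadcast is central to the N-dim matmul case below, here is a minimal standalone sketch (hypothetical test code, not part of this commit) that mirrors the same trailing-dim rule by left-padding with 1s instead of reversing:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Mirror of the broadcast rule: align dims from the right; a dim of 1
    // stretches to the other input's dim, anything else must match exactly.
    std::vector<std::size_t> broadcast(std::vector<std::size_t> a, std::vector<std::size_t> b)
    {
        while (a.size() < b.size())
            a.insert(a.begin(), 1);
        while (b.size() < a.size())
            b.insert(b.begin(), 1);
        std::vector<std::size_t> out(a.size());
        for (std::size_t i = 0; i < a.size(); ++i)
        {
            assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
            out[i] = std::max(a[i], b[i]);
        }
        return out;
    }

    int main()
    {
        // trailing dims match; the leading dim of the longer input is kept
        assert((broadcast({2, 3, 4}, {3, 4}) == std::vector<std::size_t>{2, 3, 4}));
        // dims of 1 stretch: {5, 1, 4} x {3, 1} -> {5, 3, 4}
        assert((broadcast({5, 1, 4}, {3, 1}) == std::vector<std::size_t>{5, 3, 4}));
    }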
    std::string name() const { return "dot"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
@@ -837,6 +886,83 @@ struct dot
        const shape& b = inputs.at(1);
        auto t         = a.type();
        if (a.scalar() || b.scalar())
        {
            MIGRAPHX_THROW("DOT: scalar operands are not allowed, use op::mul{} instead");
        }

        auto a_lens = a.lens();
        auto b_lens = b.lens();
        std::vector<std::size_t> out_lens;
        if (a_lens.size() == 1)
        {
            // vector * vector is an inner product, so the output is a
            // scalar, following numpy.matmul()
            if (b_lens.size() == 1)
            {
                if (a_lens.front() != b_lens.front())
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
                                   to_string_range(b_lens) + "}");
                }
            }
            else
            {
                // vector * matrix: A acts as a 1 x K row vector, and the
                // prepended dim is removed from the output
                std::size_t dim_0 = b_lens.size() - 2;
                if (a_lens.front() != b_lens[dim_0])
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, vector A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
                                   to_string_range(b_lens) + "}");
                }
                out_lens = b_lens;
                out_lens.erase(out_lens.begin() + dim_0);
            }
        }
        else
        {
            if (b_lens.size() == 1)
            {
                // matrix * vector: B acts as a K x 1 column vector, and the
                // appended dim is removed from the output
                if (a_lens.back() != b_lens.back())
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply vector B: {" +
                                   to_string_range(b_lens) + "}");
                }
                out_lens = a_lens;
                out_lens.pop_back();
            }
            else
            {
                // matrix * matrix: the last two dims multiply as M x K times
                // K x N, and the leading (batch) dims are broadcast
                std::size_t dim_0 = a_lens.size() - 1;
                std::size_t dim_1 = b_lens.size() - 2;
                if (a_lens[dim_0] != b_lens[dim_1])
                {
                    MIGRAPHX_THROW("DOT : dimension mismatch, matrix A: {" +
                                   to_string_range(a_lens) + "}, cannot multiply matrix B: {" +
                                   to_string_range(b_lens) + "}");
                }
                a_lens.pop_back();
                std::size_t out_m = a_lens.back();
                a_lens.pop_back();
                std::size_t out_n = b_lens.back();
                b_lens.pop_back();
                b_lens.pop_back();
                out_lens = shape_broadcast(a_lens, b_lens);
                out_lens.push_back(out_m);
                out_lens.push_back(out_n);
            }
        }
        // the C input, if present, must be broadcastable to the shape of A * B
        if (inputs.size() == 3)
// according to the specification of numpy.matmul(),
// inputs with more than 2 shape dims are acceptable
// as long as the dim values are the same in the two inputs
...
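As a hedged, hand-worked summary (not program output) of what compute_shape above produces under these numpy.matmul rules:

    // A: {4},          B: {4}          -> inner product, scalar output
    // A: {4},          B: {2, 4, 5}    -> {2, 5}        (prepended dim of A removed)
    // A: {3, 4},       B: {4}          -> {3}           (appended dim of B removed)
    // A: {2, 3, 4, 5}, B: {5, 6}       -> {2, 3, 4, 6}  (leading dims broadcast)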
@@ -127,9 +127,6 @@ argument miopen_gemm::compute(context& ctx,
        auto alpha_r    = to_rocblas_type(as(op.alpha));
        auto beta_r     = to_rocblas_type(as(op.beta));
        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
        // call the strided implementation only if there are multiple matrices
        if (batch_num > 1)
        {
        generic_rocblas_batched_gemm(
            as,
            ctx.get_stream().get_rocblas(),
@@ -150,25 +147,6 @@ argument miopen_gemm::compute(context& ctx,
            ldc,
            m * n,
            batch_num);
        }
        else
        {
            generic_rocblas_gemm(as,
                                 ctx.get_stream().get_rocblas(),
                                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                                 transa ? rocblas_operation_transpose : rocblas_operation_none,
                                 n,
                                 m,
                                 k,
                                 &alpha_r,
                                 to_pointer(args[1]),
                                 ldb,
                                 to_pointer(args[0]),
                                 lda,
                                 &beta_r,
                                 to_pointer(args[2]),
                                 ldc);
        }
    });
    return (is_3inputs ? args[3] : args[2]);
...
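The branch removed above called the non-batched gemm when batch_num == 1; after this change the strided-batched call handles that case as well, since a strided batched GEMM with batch count 1 performs exactly one GEMM. A minimal CPU sketch of the strided-batched semantics (illustrative reference code, not the rocBLAS implementation):

    #include <cstddef>

    // Reference semantics of a strided batched GEMM (row-major, no transpose):
    // for each batch i, C_i = alpha * A_i * B_i + beta * C_i, where X_i is the
    // matrix starting at X + i * stride_x. With batch_num == 1 this reduces to
    // a single plain GEMM, which is why the non-batched branch is redundant.
    void gemm_strided_batched(std::size_t m, std::size_t n, std::size_t k,
                              float alpha, const float* a, std::size_t stride_a,
                              const float* b, std::size_t stride_b,
                              float beta, float* c, std::size_t stride_c,
                              std::size_t batch_num)
    {
        for (std::size_t i = 0; i < batch_num; ++i)
        {
            const float* ai = a + i * stride_a;
            const float* bi = b + i * stride_b;
            float* ci       = c + i * stride_c;
            for (std::size_t r = 0; r < m; ++r)
            {
                for (std::size_t col = 0; col < n; ++col)
                {
                    float acc = 0.0f;
                    for (std::size_t j = 0; j < k; ++j)
                        acc += ai[r * k + j] * bi[j * n + col];
                    ci[r * n + col] = alpha * acc + beta * ci[r * n + col];
                }
            }
        }
    }

Note how the stride_c argument corresponds to the m * n passed in the call above: each output matrix in the batch occupies a contiguous m x n block.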