Commit 2d98b64e authored by Shucai Xiao's avatar Shucai Xiao

Added the CPU implementation of the dot operator

parent b2106be7
......@@ -819,6 +819,10 @@ struct gather
// vector input; if A or B is 2-dim, it is a matrix (no case of a batch of
// vectors as input). If A or B is 3 or more dims, it is considered as a
// stack(batch) of matrices.
// Note that we optimize the scenarios of the MatMul and Gemm operators,
// but for extended scenarios such as GEMM with three inputs, where each arg
// is a batch of matrices, the implementation may need further optimization
// later.
struct dot
{
float alpha = 1.0;
......
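For context, a minimal reference sketch (illustration only, not part of this patch; the name gemm_reference is hypothetical) of what the dot/Gemm operator computes for plain 2-D, row-major inputs: C = alpha * A * B + beta * C.

#include <cstddef>
#include <vector>

// naive row-major reference: C(m x n) = alpha * A(m x k) * B(k x n) + beta * C
void gemm_reference(std::vector<float>& c, const std::vector<float>& a,
                    const std::vector<float>& b, std::size_t m, std::size_t k,
                    std::size_t n, float alpha, float beta)
{
    for(std::size_t i = 0; i < m; ++i)
    {
        for(std::size_t j = 0; j < n; ++j)
        {
            double s = 0.0;
            for(std::size_t kk = 0; kk < k; ++kk)
                s += a[i * k + kk] * b[kk * n + j];
            c[i * n + j] = static_cast<float>(alpha * s + beta * c[i * n + j]);
        }
    }
}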
......@@ -73,18 +73,32 @@ void migemm_impl(tensor_view<T> cmat,
float beta,
std::false_type)
{
std::size_t n_dims = cmat.get_shape().lens().size();
auto a_lens = amat.get_shape().lens();
auto b_lens = bmat.get_shape().lens();
auto c_lens = cmat.get_shape().lens();
std::size_t n_dims = c_lens.size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];
auto k = a_lens[dim_1];
assert(a_lens[dim_1] == b_lens[dim_0]);
assert(c_lens[dim_0] == a_lens[dim_0]);
assert(c_lens[dim_1] == b_lens[dim_1]);
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
std::size_t a_len_diff = c_lens.size() - a_lens.size();
std::size_t b_len_diff = c_lens.size() - b_lens.size();
std::vector<std::size_t> a_idx(a_lens.size());
std::vector<std::size_t> b_idx(b_lens.size());
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
auto a_idx = c_idx;
auto b_idx = c_idx;
std::transform(c_idx.begin() + a_len_diff, c_idx.end(), a_lens.begin(), a_idx.begin(),
               [&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::transform(c_idx.begin() + b_len_diff, c_idx.end(), b_lens.begin(), b_idx.begin(),
               [&](auto i, auto j) { return (j == 1) ? 0 : i; });
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
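For reference, a minimal sketch (illustration only, not part of the patch; broadcast_index is a hypothetical helper) of the index mapping the two transforms above perform: any size-1 dimension of the lower-rank operand maps to index 0, which is how the batch dimensions are broadcast.

#include <algorithm>
#include <cstddef>
#include <vector>

// map an output index to an index into a lower-rank, broadcastable input
std::vector<std::size_t> broadcast_index(const std::vector<std::size_t>& out_idx,
                                         const std::vector<std::size_t>& in_lens)
{
    std::size_t diff = out_idx.size() - in_lens.size(); // leading dims missing in the input
    std::vector<std::size_t> in_idx(in_lens.size());
    std::transform(out_idx.begin() + diff, out_idx.end(), in_lens.begin(), in_idx.begin(),
                   [](std::size_t i, std::size_t j) { return (j == 1) ? 0 : i; });
    return in_idx;
}
// e.g. out_idx = {3, 1, 2}, in_lens = {1, 4, 5}  ->  in_idx = {0, 1, 2}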
......@@ -98,11 +112,10 @@ template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
auto lens = amat.get_shape().lens();
bool batch_mul =
std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>()) ==
(*lens.rbegin()) * (*(lens.rbegin() + 1));
if(batch_mul)
auto lens = cmat.get_shape().lens();
std::size_t num_matrices =
std::accumulate(lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
}
......
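A short sketch (illustration only, not part of the patch; count_matrices is a hypothetical helper) of how the new num_matrices check collapses the batch dimensions; when the product of all dims except the last two is 1, the single-matrix fast path is taken.

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// product of all dims except the last two; requires lens.size() >= 2
std::size_t count_matrices(const std::vector<std::size_t>& lens)
{
    return std::accumulate(lens.rbegin() + 2, lens.rend(), std::size_t{1},
                           std::multiplies<std::size_t>());
}
// count_matrices({4, 5})       == 1  -> single gemm
// count_matrices({2, 3, 4, 5}) == 6  -> batched path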
......@@ -374,20 +374,80 @@ struct cpu_gemm
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// if the output is a scalar, all inputs are scalars
if (output_shape.scalar())
{
visit_all(result, args[0], args[1], args[2])([&](auto ret, auto a, auto b, auto c) {
ret[0] = op.alpha * a[0] * b[0] + op.beta * c[0];
});
return result;
}
// if the first argument is 1-dim, prepend a 1 to its dimensions
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
bool is_a_prepended = false;
shape::type_t t = output_shape.type();
if (a_lens.size() == 1)
{
is_a_prepended = true;
a_lens.insert(a_lens.begin(), 1);
out_lens.push_back(1);
std::swap(*out_lens.rbegin(), *(out_lens.rbegin() + 1));
}
bool is_b_appended = false;
if (b_lens.size() == 1)
{
is_b_appended = true;
b_lens.push_back(1);
out_lens.push_back(1);
}
// check whether a C input is provided
if(args.size() == 3)
if(args.size() == 2)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
migemm({{t, out_lens}, result.data()}, {{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()}, op.alpha, op.beta);
return result;
}
// 3 input arguments
auto c_shape = args[2].get_shape();
// In GEMM, C is broadcastable to the shape of A * B, so C may not have
// the same shape as the output. If the shapes match, copy C directly
// into the output memory
if (c_shape == output_shape)
{
// a bulk memory copy is more efficient than copying element by element
result.visit([&](auto output) {
args[2].visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
[&](auto input) { std::memcpy(output.data(), input.data(), c_shape.bytes()); });
});
}
else
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
auto out_len = output_shape.lens();
auto c_lens = c_shape.lens();
std::size_t len_diff = out_len.size() - c_lens.size();
visit_all(result, args[2]) ([&](auto output, auto c) {
shape_for_each(output_shape, [&](auto out_idx) {
// compute the input index
std::vector<std::size_t> in_idx(c_lens.size());
std::transform(c_lens.begin(), c_lens.end(), out_idx.begin() + len_diff, in_idx.begin(),
               [&](auto i, auto j) { return (i == 1) ? 0 : j; });
output(out_idx.begin(), out_idx.end()) = c(in_idx.begin(), in_idx.end());
});
});
}
migemm(result, args[0], args[1], op.alpha, op.beta);
migemm({{t, out_lens}, result.data()}, {{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()}, op.alpha, op.beta);
return result;
}
};
......
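To summarize the shape bookkeeping in compute() above, a minimal sketch (illustration only, not part of the patch; promote_vector_args is a hypothetical helper) of how a 1-dim A is promoted to {1, k}, a 1-dim B to {k, 1}, and how the output lens are adjusted to stay consistent.

#include <cstddef>
#include <utility>
#include <vector>

// A {k} -> {1, k}, B {k} -> {k, 1}, and patch the output lens to match
void promote_vector_args(std::vector<std::size_t>& a_lens,
                         std::vector<std::size_t>& b_lens,
                         std::vector<std::size_t>& out_lens)
{
    if(a_lens.size() == 1)
    {
        a_lens.insert(a_lens.begin(), 1);
        out_lens.push_back(1);
        std::swap(*out_lens.rbegin(), *(out_lens.rbegin() + 1)); // {n, 1} -> {1, n}
    }
    if(b_lens.size() == 1)
    {
        b_lens.push_back(1);
        out_lens.push_back(1);
    }
}
// e.g. A {3}, B {3, 4}, out {4}  ->  A {1, 3}, B {3, 4}, out {1, 4}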