Commit 0ea7b7a3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 2d98b64e
......@@ -92,14 +92,18 @@ void migemm_impl(tensor_view<T> cmat,
std::vector<std::size_t> b_idx(b_lens.size());
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
std::transform(c_lens.begin() + a_len_diff, c_lens.end(), a_lens.begin(), a_idx.begin(), [&](auto i, auto j) {
return (j == 1) ? 0 : i;
});
std::transform(c_lens.begin() + b_len_diff, c_lens.end(), b_lens.begin(), b_idx.begin(), [&](auto i, auto j) {
return (j == 1) ? 0 : i;
});
double s = 0.0;
std::transform(c_lens.begin() + a_len_diff,
c_lens.end(),
a_lens.begin(),
a_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
std::transform(c_lens.begin() + b_len_diff,
c_lens.end(),
b_lens.begin(),
b_idx.begin(),
[&](auto i, auto j) { return (j == 1) ? 0 : i; });
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
......@@ -112,9 +116,9 @@ template <class T>
void migemm_impl(
tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, float alpha, float beta)
{
auto lens = cmat.get_shape().lens();
std::size_t num_matrices =
std::accumulate(lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto lens = cmat.get_shape().lens();
std::size_t num_matrices = std::accumulate(
lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
{
migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
......
......@@ -375,22 +375,22 @@ struct cpu_gemm
{
argument result{output_shape};
// all args are scalar
if (output_shape.scalar())
if(output_shape.scalar())
{
visit_all(result, args[0], args[1], args[2])([&](auto ret, auto a, auto b, auto c) {
ret[0] = op.alpha * a[0] * b[0] + op.beta * c[0];
ret[0] = op.alpha * a[0] * b[0] + op.beta * c[0];
});
return result;
}
// first argument is 1-dim, pre-pend 1 at beginning
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
auto out_lens = output_shape.lens();
bool is_a_prepended = false;
shape::type_t t = output_shape.type();
if (a_lens.size() == 1)
shape::type_t t = output_shape.type();
if(a_lens.size() == 1)
{
is_a_prepended = true;
a_lens.insert(a_lens.begin(), 1);
......@@ -399,7 +399,7 @@ struct cpu_gemm
}
bool is_b_appended = false;
if (b_lens.size() == 1)
if(b_lens.size() == 1)
{
is_b_appended = true;
b_lens.push_back(1);
......@@ -410,17 +410,20 @@ struct cpu_gemm
if(args.size() == 2)
{
result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
migemm({{t, out_lens}, result.data()}, {{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()}, op.alpha, op.beta);
migemm({{t, out_lens}, result.data()},
{{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()},
op.alpha,
op.beta);
return result;
}
// 3 input arguments
auto c_shape = args[2].get_shape();
// In GEMM, C is broadcastable to A * B, so we should consider C
// In GEMM, C is broadcastable to A * B, so we should consider C
// is not the same shape as A * B. If the same shape, copy C to
// the memory of the output
if (c_shape == output_shape)
if(c_shape == output_shape)
{
// memory copy is more efficient than doing element by element
result.visit([&](auto output) {
......@@ -430,23 +433,28 @@ struct cpu_gemm
}
else
{
auto out_len = output_shape.lens();
auto c_lens = c_shape.lens();
auto out_len = output_shape.lens();
auto c_lens = c_shape.lens();
std::size_t len_diff = out_len.size() - c_lens.size();
visit_all(result, args[2]) ([&](auto output, auto c) {
visit_all(result, args[2])([&](auto output, auto c) {
shape_for_each(output_shape, [&](auto out_idx) {
// compute the input index
std::vector<std::size_t> in_idx(c_lens.size());
std::transform(c_lens.begin(), c_lens.end(), out_len.begin() + len_diff, in_idx.begin(), [&](auto i, auto j) {
return (i == 1) ? 0 : j;
});
std::transform(c_lens.begin(),
c_lens.end(),
out_len.begin() + len_diff,
in_idx.begin(),
[&](auto i, auto j) { return (i == 1) ? 0 : j; });
output(out_idx.begin(), out_idx.end()) = c(in_idx.begin(), in_idx.end());
});
});
}
migemm({{t, out_lens}, result.data()}, {{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()}, op.alpha, op.beta);
migemm({{t, out_lens}, result.data()},
{{t, a_lens}, args[0].data()},
{{t, b_lens}, args[1].data()},
op.alpha,
op.beta);
return result;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment