host_gemm.hpp 873 Bytes
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
#pragma once
#include "host_tensor.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
template <typename AType, typename BType, typename CType, typename CElementwiseOperation>
5
6
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
Chao Liu's avatar
Chao Liu committed
7
8
                        Tensor<CType>& c_m_n,
                        const CElementwiseOperation& c_element_op)
9
10
11
12
13
14
15
16
17
18
19
{
    auto f_mk_kn_mn = [&](auto m, auto n) {
        const int K = a_m_k.mDesc.GetLengths()[1];

        double v = 0;

        for(int k = 0; k < K; ++k)
        {
            v += static_cast<const double>(a_m_k(m, k)) * static_cast<const double>(b_k_n(k, n));
        }

Chao Liu's avatar
Chao Liu committed
20
        c_m_n(m, n) = c_element_op(v);
21
22
23
24
25
26
    };

    make_ParallelTensorFunctor(f_mk_kn_mn,
                               c_m_n.mDesc.GetLengths()[0],
                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}