"vscode:/vscode.git/clone" did not exist on "4811a3d1df15f24b009e5ce168e1e5589e743d4b"
host_gemm.hpp 1.26 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
#pragma once
#include "host_tensor.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
9
template <typename AType,
          typename BType,
          typename CType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
10
11
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
Chao Liu's avatar
Chao Liu committed
12
13
14
15
                        Tensor<CType>& c_m_n,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const CElementwiseOperation& c_element_op)
16
17
18
19
{
    auto f_mk_kn_mn = [&](auto m, auto n) {
        const int K = a_m_k.mDesc.GetLengths()[1];

Chao Liu's avatar
Chao Liu committed
20
        float v_acc = 0;
21
22
23

        for(int k = 0; k < K; ++k)
        {
Chao Liu's avatar
Chao Liu committed
24
25
26
27
28
29
30
            float v_a;
            float v_b;

            a_element_op(v_a, static_cast<const float>(a_m_k(m, k)));
            b_element_op(v_b, static_cast<const float>(b_k_n(k, n)));

            v_acc += v_a * v_b;
31
32
        }

Chao Liu's avatar
Chao Liu committed
33
34
35
36
37
        float v_c;

        c_element_op(v_c, v_acc);

        c_m_n(m, n) = v_c;
38
39
40
41
42
43
    };

    make_ParallelTensorFunctor(f_mk_kn_mn,
                               c_m_n.mDesc.GetLengths()[0],
                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}