"...video_migraphx_netint.git" did not exist on "c57fcc1d27ef5690829517426e030c3c9f5399dc"
host_gemm.hpp 1.14 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
#pragma once
#include <cassert>
#include <cstddef>
#include <thread>

#include "host_tensor.hpp"
3

Chao Liu's avatar
Chao Liu committed
4
5
6
7
8
9
template <typename AType,
          typename BType,
          typename CType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
10
11
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
Chao Liu's avatar
Chao Liu committed
12
13
14
15
                        Tensor<CType>& c_m_n,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const CElementwiseOperation& c_element_op)
16
17
18
19
20
21
22
23
{
    auto f_mk_kn_mn = [&](auto m, auto n) {
        const int K = a_m_k.mDesc.GetLengths()[1];

        double v = 0;

        for(int k = 0; k < K; ++k)
        {
Chao Liu's avatar
Chao Liu committed
24
25
            v += static_cast<const double>(a_element_op(a_m_k(m, k))) *
                 static_cast<const double>(b_element_op(b_k_n(k, n)));
26
27
        }

Chao Liu's avatar
Chao Liu committed
28
        c_m_n(m, n) = c_element_op(v);
29
30
31
32
33
34
    };

    make_ParallelTensorFunctor(f_mk_kn_mn,
                               c_m_n.mDesc.GetLengths()[0],
                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}