host_gemm.hpp 1.37 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
2
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
#pragma once
Chao Liu's avatar
Chao Liu committed
5

Chao Liu's avatar
Chao Liu committed
6
#include "host_tensor.hpp"
7

Chao Liu's avatar
Chao Liu committed
8
9
10
11
12
13
template <typename AType,
          typename BType,
          typename CType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
14
15
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
Chao Liu's avatar
Chao Liu committed
16
17
18
19
                        Tensor<CType>& c_m_n,
                        const AElementwiseOperation& a_element_op,
                        const BElementwiseOperation& b_element_op,
                        const CElementwiseOperation& c_element_op)
20
21
22
23
{
    auto f_mk_kn_mn = [&](auto m, auto n) {
        const int K = a_m_k.mDesc.GetLengths()[1];

Chao Liu's avatar
Chao Liu committed
24
        float v_acc = 0;
25
26
27

        for(int k = 0; k < K; ++k)
        {
Chao Liu's avatar
Chao Liu committed
28
29
30
31
32
33
34
            float v_a;
            float v_b;

            a_element_op(v_a, static_cast<const float>(a_m_k(m, k)));
            b_element_op(v_b, static_cast<const float>(b_k_n(k, n)));

            v_acc += v_a * v_b;
35
36
        }

Chao Liu's avatar
Chao Liu committed
37
38
39
40
41
        float v_c;

        c_element_op(v_c, v_acc);

        c_m_n(m, n) = v_c;
42
43
44
45
46
47
    };

    make_ParallelTensorFunctor(f_mk_kn_mn,
                               c_m_n.mDesc.GetLengths()[0],
                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}