// reference_gemm.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"

// Host-side reference GEMM (see reference_gemm below).
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
Chao Liu's avatar
Chao Liu committed
10
11
12
13
void reference_gemm(const Tensor<ADataType>& a_m_k,
                    const Tensor<BDataType>& b_n_k,
                    Tensor<CDataType>& c_m_n)
{
Chao Liu's avatar
Chao Liu committed
14
15
    const int N = b_n_k.mDesc.GetLengths()[0];
    const int K = b_n_k.mDesc.GetLengths()[1];
Chao Liu's avatar
Chao Liu committed
16

Chao Liu's avatar
Chao Liu committed
17
18
    auto f = [&](auto m) {
        for(int n = 0; n < N; ++n)
Chao Liu's avatar
Chao Liu committed
19
        {
Chao Liu's avatar
Chao Liu committed
20
            AccDataType v_acc = 0;
Chao Liu's avatar
Chao Liu committed
21

Chao Liu's avatar
Chao Liu committed
22
23
24
25
            for(int k = 0; k < K; ++k)
            {
                ADataType v_a = a_m_k(m, k);
                BDataType v_b = b_n_k(n, k);
Chao Liu's avatar
Chao Liu committed
26

Chao Liu's avatar
Chao Liu committed
27
28
29
30
31
                v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
            }

            c_m_n(m, n) = ck::type_convert<CDataType>(v_acc);
        }
Chao Liu's avatar
Chao Liu committed
32
33
    };

Chao Liu's avatar
Chao Liu committed
34
    make_ParallelTensorFunctor(f, c_m_n.mDesc.GetLengths()[0])(std::thread::hardware_concurrency());
Chao Liu's avatar
Chao Liu committed
35
}