// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/library/utility/host_tensor.hpp"

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp,
          typename BElementOp,
          typename ACCElementOp>
Chao Liu's avatar
Chao Liu committed
16
17
void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
                            const Tensor<BDataType>& b_b_n_k,
18
19
20
21
                            Tensor<CDataType>& c_b_m_n,
                            const AElementOp& a_element_op,
                            const BElementOp& b_element_op,
                            const ACCElementOp& acc_element_op)
Chao Liu's avatar
Chao Liu committed
22
23
24
25
26
27
28
29
30
31
32
{
    const int N = b_b_n_k.mDesc.GetLengths()[1];
    const int K = b_b_n_k.mDesc.GetLengths()[2];

    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;

            for(int k = 0; k < K; ++k)
            {
33
34
                ADataType v_a = a_element_op(a_b_m_k(batch, m, k));
                BDataType v_b = b_element_op(b_b_n_k(batch, n, k));
Chao Liu's avatar
Chao Liu committed
35
36
37
38

                v_acc += ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
            }

39
            c_b_m_n(batch, m, n) = ck::type_convert<CDataType>(acc_element_op(v_acc));
Chao Liu's avatar
Chao Liu committed
40
41
42
43
44
45
        }
    };

    make_ParallelTensorFunctor(f, c_b_m_n.mDesc.GetLengths()[0], c_b_m_n.mDesc.GetLengths()[1])(
        std::thread::hardware_concurrency());
}

/// @brief Convenience overload of reference_batched_gemm with no element-wise transformation.
///
/// Forwards to the full overload, passing pass-through operations for A, B and the
/// accumulator. Each operation returns a copy of its argument unchanged, so the result is a
/// plain batched GEMM accumulated in AccDataType and converted to CDataType.
///
/// @param a_b_m_k A operand, indexed (batch, m, k).
/// @param b_b_n_k B operand, indexed (batch, n, k).
/// @param c_b_m_n Output, indexed (batch, m, n).
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
void reference_batched_gemm(const Tensor<ADataType>& a_b_m_k,
                            const Tensor<BDataType>& b_b_n_k,
                            Tensor<CDataType>& c_b_m_n)
{
    // Pass-through operations: each returns its input by value, untouched.
    const auto pass_a   = [](const ADataType& value) { return value; };
    const auto pass_b   = [](const BDataType& value) { return value; };
    const auto pass_acc = [](const AccDataType& value) { return value; };

    reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
        a_b_m_k, b_b_n_k, c_b_m_n, pass_a, pass_b, pass_acc);
}