#pragma once
#include "host_tensor.hpp"
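
// Explicit specialization for bf16 values stored as raw ushort bit patterns: each operand is
// widened with ck::bf16_to_f32, the dot product is accumulated in double, and the result is
// rounded back to bf16 with ck::f32_to_bf16 before being written to c.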
template <>
void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
                                       const Tensor<ushort>& b,
                                       Tensor<ushort>& c,
                                       const GemmMatrixLayout layout)
{
    if(layout == GemmMatrixLayout::MK_KN_MN)
    {
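        // c(m, n) = sum_k a(m, k) * b(k, n)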
        auto f_mk_kn_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[1];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
            }

            c(m, n) = ck::f32_to_bf16(v);
        };

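        // evaluate each (m, n) element of c in parallel across all available hardware threads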
        make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::MK_NK_MN)
    {
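        // c(m, n) = sum_k a(m, k) * b(n, k)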
        auto f_mk_nk_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[1];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
            }

            c(m, n) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::KM_KN_MN)
    {
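        // c(m, n) = sum_k a(k, m) * b(k, n)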
        auto f_km_kn_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[0];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
            }

            c(m, n) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::KM_NK_MN)
    {
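        // c(m, n) = sum_k a(k, m) * b(n, k)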
        auto f_km_nk_mn = [&](auto m, auto n) {
            const int K = a.mDesc.GetLengths()[0];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
            }

            c(m, n) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::MK_KN_NM)
    {
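        // c(n, m) = sum_k a(m, k) * b(k, n)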
        auto f_mk_kn_nm = [&](auto n, auto m) {
            const int K = a.mDesc.GetLengths()[1];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(k, n));
            }

            c(n, m) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::MK_NK_NM)
    {
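        // c(n, m) = sum_k a(m, k) * b(n, k)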
        auto f_mk_nk_nm = [&](auto n, auto m) {
            const int K = a.mDesc.GetLengths()[1];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(m, k)) * ck::bf16_to_f32(b(n, k));
            }

            c(n, m) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::KM_KN_NM)
    {
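        // c(n, m) = sum_k a(k, m) * b(k, n)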
        auto f_km_kn_nm = [&](auto n, auto m) {
            const int K = a.mDesc.GetLengths()[0];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(k, n));
            }

            c(n, m) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else if(layout == GemmMatrixLayout::KM_NK_NM)
    {
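        // c(n, m) = sum_k a(k, m) * b(n, k)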
        auto f_km_nk_nm = [&](auto n, auto m) {
            const int K = a.mDesc.GetLengths()[0];

            double v = 0;

            for(int k = 0; k < K; ++k)
            {
                v += ck::bf16_to_f32(a(k, m)) * ck::bf16_to_f32(b(n, k));
            }

            c(n, m) = ck::f32_to_bf16(v);
        };

        make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
            std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("wrong! not supported layout");
    }
}

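// Generic reference GEMM for a(m, k) * b(k, n) -> c(m, n); products are accumulated in double
// before the result is narrowed to CType.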
template <typename AType, typename BType, typename CType>
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
                        const Tensor<BType>& b_k_n,
                        Tensor<CType>& c_m_n)
{
    auto f_mk_kn_mn = [&](auto m, auto n) {
        const int K = a_m_k.mDesc.GetLengths()[1];

        double v = 0;

        for(int k = 0; k < K; ++k)
        {
            v += static_cast<double>(a_m_k(m, k)) * static_cast<double>(b_k_n(k, n));
        }

        c_m_n(m, n) = v;
    };

    make_ParallelTensorFunctor(f_mk_kn_mn,
                               c_m_n.mDesc.GetLengths()[0],
                               c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency());
}
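
// Example usage (a sketch; assumes Tensor is constructible from a list of dimension lengths,
// which is an assumption about host_tensor.hpp):
//
//   Tensor<float> a({M, K});
//   Tensor<float> b({K, N});
//   Tensor<float> c({M, N});
//   host_gemm_mk_kn_mn(a, b, c); // c(m, n) = sum_k a(m, k) * b(k, n)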