amd_xdlops.hpp 11.2 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

zjing14's avatar
zjing14 committed
4
5
6
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

7
#include "data_type.hpp"
zjing14's avatar
zjing14 committed
8
9
10

namespace ck {

11
// fp32
12
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
13
14
struct intrin_mfma_f32_32x32x1f32;

15
16
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
zjing14's avatar
zjing14 committed
17
18
19
20
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
21
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
22
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
23
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
24
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
zjing14's avatar
zjing14 committed
25
26
27
    }
};

28
29
template <>
struct intrin_mfma_f32_32x32x1f32<32, 64>
zjing14's avatar
zjing14 committed
30
31
32
33
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
34
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
35
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
zjing14's avatar
zjing14 committed
36
37
38
    }
};

39
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
40
41
struct intrin_mfma_f32_32x32x2f32;

42
43
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
zjing14's avatar
zjing14 committed
44
45
46
47
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
48
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
49
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
50
51
52
    }
};

53
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
54
55
struct intrin_mfma_f32_16x16x4f32;

56
57
template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
zjing14's avatar
zjing14 committed
58
59
60
61
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
62
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
63
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
64
65
66
    }
};

67
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
68
69
struct intrin_mfma_f32_16x16x1f32;

70
71
template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
zjing14's avatar
zjing14 committed
72
73
74
75
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
76
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
77
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
zjing14's avatar
zjing14 committed
78
79
80
    }
};

81
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
82
83
struct intrin_mfma_f32_4x4x1f32;

84
85
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
zjing14's avatar
zjing14 committed
86
87
88
89
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
90
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
91
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
zjing14's avatar
zjing14 committed
92
93
94
    }
};

95
96
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
zjing14's avatar
zjing14 committed
97
98
99
100
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
101
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
102
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
103
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
104
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
zjing14's avatar
zjing14 committed
105
106
107
    }
};

108
// fp16
109
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
110
111
struct intrin_mfma_f32_32x32x4f16;

112
113
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
zjing14's avatar
zjing14 committed
114
115
116
117
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
118
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
119
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
120
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
121
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
zjing14's avatar
zjing14 committed
122
123
124
    }
};

125
126
template <>
struct intrin_mfma_f32_32x32x4f16<32, 64>
zjing14's avatar
zjing14 committed
127
128
129
130
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
131
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
132
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
zjing14's avatar
zjing14 committed
133
134
135
    }
};

136
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
137
138
struct intrin_mfma_f32_32x32x8f16;

139
140
template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
zjing14's avatar
zjing14 committed
141
142
143
144
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
145
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
146
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
147
148
149
    }
};

150
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
151
152
struct intrin_mfma_f32_16x16x16f16;

153
154
template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
zjing14's avatar
zjing14 committed
155
156
157
158
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
159
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
160
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
161
162
163
    }
};

164
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
165
166
struct intrin_mfma_f32_16x16x4f16;

167
168
template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
zjing14's avatar
zjing14 committed
169
170
171
172
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
173
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
174
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
zjing14's avatar
zjing14 committed
175
176
177
    }
};

178
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
179
180
struct intrin_mfma_f32_4x4x4f16;

181
182
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
zjing14's avatar
zjing14 committed
183
184
185
186
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
187
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
188
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
zjing14's avatar
zjing14 committed
189
190
191
    }
};

192
193
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
zjing14's avatar
zjing14 committed
194
195
196
197
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
198
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
199
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
200
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
201
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
zjing14's avatar
zjing14 committed
202
203
204
    }
};

205
206
207
// bfp16
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;
zjing14's avatar
zjing14 committed
208

209
210
template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
zjing14's avatar
zjing14 committed
211
{
212
    template <class FloatC>
213
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
214
    {
215
216
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
217
    }
ltqin's avatar
ltqin committed
218
219
220
221
222
223
224

    template <class FloatC>
    __device__ static void Run(const bfloat16x4_t& reg_a, const bfloat16x4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
zjing14's avatar
zjing14 committed
225
226
};

227
228
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;
zjing14's avatar
zjing14 committed
229

230
231
template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
zjing14's avatar
zjing14 committed
232
{
233
    template <class FloatC>
234
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
235
    {
236
237
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
238
239
240
    }
};

241
242
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;
zjing14's avatar
zjing14 committed
243

244
245
template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
zjing14's avatar
zjing14 committed
246
{
247
    template <class FloatC>
248
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
249
    {
250
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
251
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
252
    }
ltqin's avatar
ltqin committed
253
254
255
256
257
258
    template <class FloatC>
    __device__ static void Run(const bfloat16x2_t& reg_a, const bfloat16x2_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
zjing14's avatar
zjing14 committed
259
260
261
};

// Wrapper around __builtin_amdgcn_mfma_f32_16x16x8bf16 (bhalf2_t operands), keyed
// on the per-wave tile shape; only the specialization below is defined.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8bf16;

// 16x16 per wave: a single issue accumulating into float4_t sub-vector 0 of reg_c.
template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates: cbsz, abid, blgp (all zero).
        acc(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
            reg_a, reg_b, acc[Number<0>{}], /*cbsz*/ 0, /*abid*/ 0, /*blgp*/ 0);
    }
};

// Wrapper around __builtin_amdgcn_mfma_i32_32x32x8i8 (int8 operands, int32
// accumulators), keyed on the per-wave tile shape; only the specialization below is
// defined.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

// 32x32 per wave: a single issue accumulating into int32x16_t sub-vector 0 of reg_c.
template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // The four packed int8 values of each operand are handed to the builtin as a
        // single int32_t bit pattern.
        const auto packed_a = bit_cast<int32_t>(reg_a);
        const auto packed_b = bit_cast<int32_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x16_t>();

        // Trailing immediates: cbsz, abid, blgp (all zero).
        acc(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x8i8(
            packed_a, packed_b, acc[Number<0>{}], /*cbsz*/ 0, /*abid*/ 0, /*blgp*/ 0);
    }
};

294
295
296
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

zjing14's avatar
zjing14 committed
297
template <>
298
struct intrin_mfma_i32_16x16x16i8<16, 16>
zjing14's avatar
zjing14 committed
299
{
300
301
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
302
    {
303
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
Chao Liu's avatar
Chao Liu committed
304
305
            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
                                                 bit_cast<int32_t>(reg_b),
306
307
308
309
                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                 0,
                                                 0,
                                                 0);
zjing14's avatar
zjing14 committed
310
311
312
    }
};

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#ifdef __gfx90a__
        reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
            reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
zjing14's avatar
zjing14 committed
332
333
} // namespace ck
#endif