// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

#include "data_type.hpp"

// Thin device-side wrappers around the AMDGPU MFMA ("xdlops") compiler builtins
// (__builtin_amdgcn_mfma_*).
//
// Pattern shared by every wrapper below:
//   * Each wrapper is a class template keyed on <MPerWave, NPerWave>. The primary template is
//     only declared, never defined, so only the explicit specializations in this file are
//     usable; requesting any other wave tile is a compile-time error (incomplete type).
//   * Run(reg_a, reg_b, reg_c) reinterprets reg_c through AsType<VectorT>() and, for each
//     accumulator slot it touches, passes the slot's current value ([] accessor) as the
//     builtin's accumulator operand and writes the builtin's result back into the same slot
//     (() accessor).
//   * The three trailing integer arguments of every builtin call are compile-time immediates.
//     NOTE(review): these are presumably the CBSZ / ABID / BLGP modifiers described in the AMD
//     Instinct ISA documentation - confirm before changing any of them.

namespace ck {

// fp32
// Operands: one float per lane for A and B; accumulator is reg_c viewed as float vectors.

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;

// 64x64 wave tile: two MFMAs sharing reg_a/reg_b, writing accumulator slots 0 and 1; they differ
// only in the fifth builtin argument (0 vs 1).
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

// 32x64 wave tile: a single MFMA into accumulator slot 0.
template <>
struct intrin_mfma_f32_32x32x1f32<32, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;

template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32;

template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32;

template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

// 4x64 wave tile: a single MFMA into accumulator slot 0.
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

// 8x64 wave tile: two MFMAs into accumulator slots 0 and 1, differing only in the fifth
// builtin argument (0 vs 1).
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};

// fp16
// Operands: half4_t per lane for A and B; accumulator is reg_c viewed as float vectors.

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;

// 64x64 wave tile: two MFMAs into accumulator slots 0 and 1, differing only in the fifth
// builtin argument (0 vs 1).
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

// 32x64 wave tile: a single MFMA into accumulator slot 0.
template <>
struct intrin_mfma_f32_32x32x4f16<32, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;

template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16;

template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16;

template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

// 4x64 wave tile: a single MFMA into accumulator slot 0.
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

// 8x64 wave tile: two MFMAs into accumulator slots 0 and 1, differing only in the fifth
// builtin argument (0 vs 1).
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};

// bf16
// The *_1k variants take bhalf4_t operands; the others take bhalf2_t. All accumulate into
// reg_c viewed as float vectors.

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;

template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;

template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;

template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8bf16;

template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

// int8
// The int8 operand vectors are bit_cast into a single scalar integer of the same size before
// being handed to the builtin (int8x4_t -> int32_t, int8x8_t -> int64_t). The accumulator is
// reg_c viewed as int32 vectors.

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
                                                bit_cast<int32_t>(reg_b),
                                                reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                0,
                                                0,
                                                0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
                                                 bit_cast<int32_t>(reg_b),
                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                 0,
                                                 0,
                                                 0);
    }
};

// NOTE(review): unlike the fp64 wrapper at the end of this file, the two 8-byte-operand int8
// wrappers below are not guarded by a target-architecture #if. As far as I know these
// instructions only exist on gfx940-class targets - confirm, and consider the same
// #if / ignore pattern used for fp64.

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;

template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x32i8;

template <>
struct intrin_mfma_i32_16x16x32i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // The builtin is spelled with an underscore before "i8", the same way as
        // __builtin_amdgcn_mfma_i32_32x32x16_i8 in the wrapper above.
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};

// fp64

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

// The MFMA is only issued when compiling device code for gfx90a or gfx940. On any other target
// Run issues no MFMA; the parameters are merely passed to `ignore` (presumably to silence
// unused-parameter warnings).
template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#if defined(__gfx90a__) || defined(__gfx940__)
        reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
            reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

} // namespace ck
#endif