amd_xdlops.hpp 14.3 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
Chao Liu's avatar
Chao Liu committed
4
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
5
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
6

zjing14's avatar
zjing14 committed
7
8
9
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

10
#include "data_type.hpp"
zjing14's avatar
zjing14 committed
11
12
13

namespace ck {

14
// fp32
15
// 32x32x1 fp32 MFMA wrapper. The primary template is only declared; the supported
// (MPerWave, NPerWave) shapes are the explicit specializations that follow.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;

// <64, 64>: two MFMA issues that differ only in the abid immediate (0 and 1). Each one
// updates its own float32_t slot of reg_c (slot 0 and slot 1).
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // View of the accumulator as an indexable array of float32_t vectors.
        auto&& acc = reg_c.template AsType<float32_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc[Number<0>{}], 1, 0, 0);
        acc(Number<1>{}) =
            __builtin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc[Number<1>{}], 1, 1, 0);
    }
};

// <32, 64>: a single MFMA issue (cbsz = 1, abid = 0) into float32_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_32x32x1f32<32, 64>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float32_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc[Number<0>{}], 1, 0, 0);
    }
};

42
// 32x32x2 fp32 MFMA wrapper. Only the <32, 32> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;

// <32, 32>: one MFMA issue with all modifier immediates zero. It accumulates into
// float16_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

56
// 16x16x4 fp32 MFMA wrapper. Only the <16, 16> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32;

// <16, 16>: one MFMA issue with all modifier immediates zero. It accumulates into
// float4_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x4f32(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

70
// 16x16x1 fp32 MFMA wrapper. Only the <16, 64> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32;

// <16, 64>: one MFMA issue with cbsz = 2, abid = 0, blgp = 0. It accumulates into
// float16_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x1f32(reg_a, reg_b, acc[Number<0>{}], 2, 0, 0);
    }
};

84
// 4x4x1 fp32 MFMA wrapper. The <4, 64> and <8, 64> shapes are specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

// <4, 64>: one MFMA issue (cbsz = 4, abid = 0) into float4_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc[Number<0>{}], 4, 0, 0);
    }
};

// <8, 64>: two MFMA issues (cbsz = 4) with abid 0 and 1. Each writes its own
// float4_t slot of reg_c.
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
    template <typename FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc[Number<0>{}], 4, 0, 0);
        acc(Number<1>{}) =
            __builtin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc[Number<1>{}], 4, 1, 0);
    }
};

111
// fp16
112
// 32x32x4 fp16-input / fp32-accumulate MFMA wrapper. The <64, 64> and <32, 64> shapes
// are specialized. Operands are half4_t; reg_c is accessed as float32_t vectors.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;

// <64, 64>: two MFMA issues (cbsz = 1) with abid 0 and 1, into float32_t slots 0 and 1.
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float32_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc[Number<0>{}], 1, 0, 0);
        acc(Number<1>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc[Number<1>{}], 1, 1, 0);
    }
};

// <32, 64>: a single MFMA issue (cbsz = 1, abid = 0) into float32_t slot 0.
template <>
struct intrin_mfma_f32_32x32x4f16<32, 64>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float32_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc[Number<0>{}], 1, 0, 0);
    }
};

139
// 32x32x8 fp16-input / fp32-accumulate MFMA wrapper. Only the <32, 32> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;

// <32, 32>: one MFMA issue (all modifiers zero) into float16_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x8f16(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

153
// 16x16x16 fp16-input / fp32-accumulate MFMA wrapper. Only the <16, 16> shape is
// specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16;

// <16, 16>: one MFMA issue (all modifiers zero) into float4_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x16f16(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

167
// 16x16x4 fp16-input / fp32-accumulate MFMA wrapper. Only the <16, 64> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16;

// <16, 64>: one MFMA issue with cbsz = 2, abid = 0, blgp = 0, into float16_t slot 0.
template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x4f16(reg_a, reg_b, acc[Number<0>{}], 2, 0, 0);
    }
};

181
// 4x4x4 fp16-input / fp32-accumulate MFMA wrapper. The <4, 64> and <8, 64> shapes are
// specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

// <4, 64>: one MFMA issue (cbsz = 4, abid = 0) into float4_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc[Number<0>{}], 4, 0, 0);
    }
};

// <8, 64>: two MFMA issues (cbsz = 4) with abid 0 and 1, into float4_t slots 0 and 1.
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
    template <typename FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc[Number<0>{}], 4, 0, 0);
        acc(Number<1>{}) =
            __builtin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc[Number<1>{}], 4, 1, 0);
    }
};

208
209
210
// bf16
// 32x32x8 bf16-input ("_1k" variant) / fp32-accumulate MFMA wrapper. Only the <32, 32>
// shape is specialized. Operands are bhalf4_t.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;

// <32, 32>: one MFMA issue (all modifiers zero) into float16_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

223
224
// 16x16x16 bf16-input ("_1k" variant) / fp32-accumulate MFMA wrapper. Only the <16, 16>
// shape is specialized. Operands are bhalf4_t.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;

// <16, 16>: one MFMA issue (all modifiers zero) into float4_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

237
238
// 32x32x4 bf16-input / fp32-accumulate MFMA wrapper. Only the <32, 32> shape is
// specialized. Operands are bhalf2_t.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;

// <32, 32>: one MFMA issue (all modifiers zero) into float16_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x4bf16(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};

// 16x16x8 bf16-input / fp32-accumulate MFMA wrapper. Only the <16, 16> shape is
// specialized. Operands are bhalf2_t.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8bf16;

// <16, 16>: one MFMA issue (all modifiers zero) into float4_t slot 0 of reg_c.
template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x8bf16(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
    }
};
zjing14's avatar
zjing14 committed
264
265

// 32x32x8 int8-input / int32-accumulate MFMA wrapper. Only the <32, 32> shape is
// specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

// <32, 32>: each int8x4_t operand is bit-cast to a single int32_t for the builtin.
// One MFMA issue (all modifiers zero) accumulates into int32x16_t slot 0 of reg_c.
template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Reinterpret the four packed int8 values of each operand as one 32-bit word.
        const int32_t a_packed = bit_cast<int32_t>(reg_a);
        const int32_t b_packed = bit_cast<int32_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x8i8(a_packed, b_packed, acc[Number<0>{}], 0, 0, 0);
    }
};

284
285
286
// 16x16x16 int8-input / int32-accumulate MFMA wrapper. Only the <16, 16> shape is
// specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

// <16, 16>: each int8x4_t operand is bit-cast to a single int32_t for the builtin.
// One MFMA issue (all modifiers zero) accumulates into int32x4_t slot 0 of reg_c.
template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Reinterpret the four packed int8 values of each operand as one 32-bit word.
        const int32_t a_packed = bit_cast<int32_t>(reg_a);
        const int32_t b_packed = bit_cast<int32_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x16i8(a_packed, b_packed, acc[Number<0>{}], 0, 0, 0);
    }
};

303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
// 32x32x16 int8-input / int32-accumulate MFMA wrapper. Only the <32, 32> shape is
// specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;

// <32, 32>: each int8x8_t operand is bit-cast to a single int64_t for the builtin.
// One MFMA issue (all modifiers zero) accumulates into int32x16_t slot 0 of reg_c.
// NOTE(review): unlike the f64/fp8 wrappers in this file, there is no per-target #if guard
// here; confirm that callers only instantiate this on targets providing the builtin.
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // Reinterpret the eight packed int8 values of each operand as one 64-bit word.
        const int64_t a_packed = bit_cast<int64_t>(reg_a);
        const int64_t b_packed = bit_cast<int64_t>(reg_b);

        auto&& acc = reg_c.template AsType<int32x16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x16_i8(a_packed, b_packed, acc[Number<0>{}], 0, 0, 0);
    }
};

// 16x16x32 int8-input / int32-accumulate MFMA wrapper. Only the <16, 16> shape is
// specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x32i8;

// <16, 16>: each int8x8_t operand is bit-cast to a single int64_t for the builtin.
// One MFMA issue (all modifiers zero) accumulates into int32x4_t slot 0 of reg_c.
//
// The clang builtin for this instruction is spelled __builtin_amdgcn_mfma_i32_16x16x32_i8,
// with an underscore before "i8". This is the same convention as the
// __builtin_amdgcn_mfma_i32_32x32x16_i8 call in intrin_mfma_i32_32x32x16i8.
// "__builtin_amdgcn_mfma_i32_16x16x32i8" is not a declared builtin, so instantiating Run
// with that spelling fails to compile.
// NOTE(review): as with the 32x32x16 int8 wrapper, there is no per-target #if guard here.
template <>
struct intrin_mfma_i32_16x16x32i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // Trailing immediates are (cbsz, abid, blgp).
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};

341
342
343
344
345
346
347
348
349
// 16x16x4 fp64 MFMA wrapper. Only the <16, 16> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

// <16, 16>: on gfx90a / gfx940 / gfx941 / gfx942 this issues one MFMA (all modifiers zero)
// into double4_t slot 0 of reg_c. On every other target no MFMA is issued and the
// arguments are only assigned to `ignore`.
// NOTE(review): on those other targets reg_c is presumably left unchanged, i.e. the call
// silently produces no result; confirm callers never reach this path there.
template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        auto&& acc = reg_c.template AsType<double4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f64_16x16x4f64(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
#else
        // Presumably here to silence unused-parameter diagnostics on targets without
        // this builtin.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
360
361
362
363
364
365
366
367
368
369

// 32x32x16 fp8 x fp8 -> fp32 MFMA wrapper. Only the <32, 32> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;

// <32, 32>:
//  - gfx940/941/942: each f8x8_t operand is bit-cast to `long`, and one native fp8 MFMA
//    (all modifiers zero) accumulates into float16_t slot 0 of reg_c.
//  - other targets: each of the eight f8_t elements of both operands is converted to float
//    and fed through intrin_mfma_f32_32x32x2f32<32, 32>. This gives eight fp32 MFMA issues
//    accumulating into reg_c.
template <>
struct intrin_mfma_f32_32x32x16f8f8<32, 32>
{
    template <typename FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Reinterpret the eight packed fp8 values of each operand as one 64-bit word.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        auto&& acc = reg_c.template AsType<float16_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(a_bits, b_bits, acc[Number<0>{}], 0, 0, 0);
#else
        // Unpack both operands so each fp8 element can be read individually.
        vector_type<f8_t, 8> a_vec(reg_a);
        vector_type<f8_t, 8> b_vec(reg_b);

        // Compile-time loop over the eight packed elements: upconvert to float, then run
        // one fp32 MFMA step per element.
        static_for<0, 8, 1>{}([&](auto k) {
            const float a_elem = type_convert<float>(a_vec.template AsType<f8_t>()[Number<k>{}]);
            const float b_elem = type_convert<float>(b_vec.template AsType<f8_t>()[Number<k>{}]);

            intrin_mfma_f32_32x32x2f32<32, 32>::Run(a_elem, b_elem, reg_c);
        });
#endif
    }
};

// 16x16x32 fp8 x fp8 -> fp32 MFMA wrapper. Only the <16, 16> shape is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8f8;

// <16, 16>:
//  - gfx940/941/942: each f8x8_t operand is bit-cast to `long`, and one native fp8 MFMA
//    (all modifiers zero) accumulates into float4_t slot 0 of reg_c.
//  - other targets: each of the eight f8_t elements of both operands is converted to float
//    and fed through intrin_mfma_f32_16x16x4f32<16, 16>. This gives eight fp32 MFMA issues
//    accumulating into reg_c.
template <>
struct intrin_mfma_f32_16x16x32f8f8<16, 16>
{
    template <typename FloatC>
    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // Reinterpret the eight packed fp8 values of each operand as one 64-bit word.
        const long a_bits = bit_cast<long>(reg_a);
        const long b_bits = bit_cast<long>(reg_b);

        auto&& acc = reg_c.template AsType<float4_t>();

        // Trailing immediates are (cbsz, abid, blgp).
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(a_bits, b_bits, acc[Number<0>{}], 0, 0, 0);
#else
        // Unpack both operands so each fp8 element can be read individually.
        vector_type<f8_t, 8> a_vec(reg_a);
        vector_type<f8_t, 8> b_vec(reg_b);

        // Compile-time loop over the eight packed elements: upconvert to float, then run
        // one fp32 MFMA step per element.
        static_for<0, 8, 1>{}([&](auto k) {
            const float a_elem = type_convert<float>(a_vec.template AsType<f8_t>()[Number<k>{}]);
            const float b_elem = type_convert<float>(b_vec.template AsType<f8_t>()[Number<k>{}]);

            intrin_mfma_f32_16x16x4f32<16, 16>::Run(a_elem, b_elem, reg_c);
        });
#endif
    }
};
zjing14's avatar
zjing14 committed
423
424
} // namespace ck
#endif
Umang Yadav's avatar
Umang Yadav committed
425
426

#pragma clang diagnostic pop