amd_xdlops.hpp 10.5 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

zjing14's avatar
zjing14 committed
4
5
6
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

7
#include "data_type.hpp"
zjing14's avatar
zjing14 committed
8
9
10

namespace ck {

11
// fp32
12
// Wrapper family around __builtin_amdgcn_mfma_f32_32x32x1f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; each supported
// combination is an explicit specialization.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;

// <64, 64>: issues the builtin twice, once per float32_t view (Number<0> and
// Number<1>) of the accumulator.
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; each float32_t view is read with operator[] as the
    //        builtin's third operand and overwritten with the result via operator().
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // The three trailing integer literals are the builtin's immediate modifier
        // operands (cbsz, abid, blgp in the LLVM AMDGPU docs -- NOTE(review): confirm).
        // The two calls differ only in the accumulator index and the second modifier.
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

// <32, 64>: issues the builtin once on the float32_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_32x32x1f32<32, 64>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing literals (1, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};

39
// Wrapper family around __builtin_amdgcn_mfma_f32_32x32x2f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;

// <32, 32>: single builtin call on the float16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

53
// Wrapper family around __builtin_amdgcn_mfma_f32_16x16x4f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32;

// <16, 16>: single builtin call on the float4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

67
// Wrapper family around __builtin_amdgcn_mfma_f32_16x16x1f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32;

// <16, 64>: single builtin call on the float16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing literals (2, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};

81
// Wrapper family around __builtin_amdgcn_mfma_f32_4x4x1f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; each supported
// combination is an explicit specialization.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

// <4, 64>: single builtin call on the float4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Trailing literals (4, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

// <8, 64>: issues the builtin twice, once per float4_t view (Number<0> and
// Number<1>) of the accumulator.
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
    // reg_a, reg_b: float operands forwarded unchanged to the builtin.
    // reg_c: accumulator; each float4_t view is read as the builtin's third operand
    //        and overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // The two calls differ only in the accumulator index and the second of the
        // three trailing immediate modifier operands (0 vs 1).
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};

108
// fp16
109
// Wrapper family around __builtin_amdgcn_mfma_f32_32x32x4f16, selected by
// <MPerWave, NPerWave>. The primary template is declared only; each supported
// combination is an explicit specialization.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;

// <64, 64>: issues the builtin twice, once per float32_t view (Number<0> and
// Number<1>) of the accumulator.
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; each float32_t view is read as the builtin's third operand
    //        and overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // The two calls differ only in the accumulator index and the second of the
        // three trailing immediate modifier operands (0 vs 1).
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

// <32, 64>: issues the builtin once on the float32_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_32x32x4f16<32, 64>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (1, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};

136
// Wrapper family around __builtin_amdgcn_mfma_f32_32x32x8f16, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;

// <32, 32>: single builtin call on the float16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

150
// Wrapper family around __builtin_amdgcn_mfma_f32_16x16x16f16, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16;

// <16, 16>: single builtin call on the float4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

164
// Wrapper family around __builtin_amdgcn_mfma_f32_16x16x4f16, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16;

// <16, 64>: single builtin call on the float16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (2, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};

178
// Wrapper family around __builtin_amdgcn_mfma_f32_4x4x4f16, selected by
// <MPerWave, NPerWave>. The primary template is declared only; each supported
// combination is an explicit specialization.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

// <4, 64>: single builtin call on the float4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (4, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

// <8, 64>: issues the builtin twice, once per float4_t view (Number<0> and
// Number<1>) of the accumulator.
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
    // reg_a, reg_b: half4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; each float4_t view is read as the builtin's third operand
    //        and overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // The two calls differ only in the accumulator index and the second of the
        // three trailing immediate modifier operands (0 vs 1).
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};

205
206
207
// bf16
// Wrapper family around __builtin_amdgcn_mfma_f32_32x32x8bf16_1k, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;

// <32, 32>: single builtin call on the float16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
    // reg_a, reg_b: bhalf4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

220
221
// Wrapper family around __builtin_amdgcn_mfma_f32_16x16x16bf16_1k, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;

// <16, 16>: single builtin call on the float4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
    // reg_a, reg_b: bhalf4_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};

234
235
// Wrapper family around __builtin_amdgcn_mfma_f32_32x32x4bf16, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;

// <32, 32>: single builtin call on the float16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
    // reg_a, reg_b: bhalf2_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};

// Wrapper family around __builtin_amdgcn_mfma_f32_16x16x8bf16, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8bf16;

// <16, 16>: single builtin call on the float4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
    // reg_a, reg_b: bhalf2_t operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
zjing14's avatar
zjing14 committed
261
262

// Wrapper family around __builtin_amdgcn_mfma_i32_32x32x8i8, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

// <32, 32>: single builtin call on the int32x16_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    // reg_a, reg_b: int8x4_t operands; each is bit_cast to int32_t before being
    //               handed to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
                                                bit_cast<int32_t>(reg_b),
                                                reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                0,
                                                0,
                                                0);
    }
};

281
282
283
// Wrapper family around __builtin_amdgcn_mfma_i32_16x16x16i8, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

// <16, 16>: single builtin call on the int32x4_t view Number<0> of the accumulator.
template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    // reg_a, reg_b: int8x4_t operands; each is bit_cast to int32_t before being
    //               handed to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
                                                 bit_cast<int32_t>(reg_b),
                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                 0,
                                                 0,
                                                 0);
    }
};

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
// Wrapper family around __builtin_amdgcn_mfma_f64_16x16x4f64, selected by
// <MPerWave, NPerWave>. The primary template is declared only.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

// <16, 16>: single builtin call on the double4_t view Number<0> of the accumulator.
// The builtin is only emitted when __gfx90a__ is defined; otherwise Run() merely
// marks its arguments as used and leaves reg_c untouched.
template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    // reg_a, reg_b: double operands forwarded unchanged to the builtin.
    // reg_c: accumulator; view Number<0> is read as the builtin's third operand and
    //        overwritten with the result.
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#if defined(__gfx90a__)
        // Bind the double4_t view once; auto&& keeps whatever reference/value
        // category AsType() yields.
        auto&& acc = reg_c.template AsType<double4_t>();

        // Trailing literals (0, 0, 0) are the builtin's immediate modifier operands.
        acc(Number<0>{}) =
            __builtin_amdgcn_mfma_f64_16x16x4f64(reg_a, reg_b, acc[Number<0>{}], 0, 0, 0);
#else
        // No builtin call on this path: consume the arguments so they do not
        // trigger unused-parameter diagnostics.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
zjing14's avatar
zjing14 committed
319
320
} // namespace ck
#endif