amd_xdlops.hpp 10.5 KB
Newer Older
zjing14's avatar
zjing14 committed
1
2
3
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

4
#ifndef CK_NOGPU
5
#include "data_type.hpp"
zjing14's avatar
zjing14 committed
6
7
8

namespace ck {

9
// fp32
10
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
11
12
struct intrin_mfma_f32_32x32x1f32;

13
14
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
zjing14's avatar
zjing14 committed
15
16
17
18
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
19
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
20
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
21
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
22
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
zjing14's avatar
zjing14 committed
23
24
25
    }
};

26
27
template <>
struct intrin_mfma_f32_32x32x1f32<32, 64>
zjing14's avatar
zjing14 committed
28
29
30
31
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
32
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
33
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
zjing14's avatar
zjing14 committed
34
35
36
    }
};

37
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
38
39
struct intrin_mfma_f32_32x32x2f32;

40
41
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
zjing14's avatar
zjing14 committed
42
43
44
45
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
46
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
47
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
48
49
50
    }
};

51
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
52
53
struct intrin_mfma_f32_16x16x4f32;

54
55
template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
zjing14's avatar
zjing14 committed
56
57
58
59
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
60
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
61
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
62
63
64
    }
};

65
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
66
67
struct intrin_mfma_f32_16x16x1f32;

68
69
template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
zjing14's avatar
zjing14 committed
70
71
72
73
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
74
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
75
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
zjing14's avatar
zjing14 committed
76
77
78
    }
};

79
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
80
81
struct intrin_mfma_f32_4x4x1f32;

82
83
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
zjing14's avatar
zjing14 committed
84
85
86
87
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
88
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
89
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
zjing14's avatar
zjing14 committed
90
91
92
    }
};

93
94
template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
zjing14's avatar
zjing14 committed
95
96
97
98
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
99
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
100
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
101
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
102
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
zjing14's avatar
zjing14 committed
103
104
105
    }
};

106
// fp16
107
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
108
109
struct intrin_mfma_f32_32x32x4f16;

110
111
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
zjing14's avatar
zjing14 committed
112
113
114
115
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
116
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
117
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
118
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
119
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
zjing14's avatar
zjing14 committed
120
121
122
    }
};

123
124
template <>
struct intrin_mfma_f32_32x32x4f16<32, 64>
zjing14's avatar
zjing14 committed
125
126
127
128
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
129
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
130
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
zjing14's avatar
zjing14 committed
131
132
133
    }
};

134
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
135
136
struct intrin_mfma_f32_32x32x8f16;

137
138
template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
zjing14's avatar
zjing14 committed
139
140
141
142
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
143
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
144
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
145
146
147
    }
};

148
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
149
150
struct intrin_mfma_f32_16x16x16f16;

151
152
template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
zjing14's avatar
zjing14 committed
153
154
155
156
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
157
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
158
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
159
160
161
    }
};

162
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
163
164
struct intrin_mfma_f32_16x16x4f16;

165
166
template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
zjing14's avatar
zjing14 committed
167
168
169
170
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
171
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
172
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
zjing14's avatar
zjing14 committed
173
174
175
    }
};

176
template <index_t MPerWave, index_t NPerWave>
zjing14's avatar
zjing14 committed
177
178
struct intrin_mfma_f32_4x4x4f16;

179
180
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
zjing14's avatar
zjing14 committed
181
182
183
184
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
185
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
186
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
zjing14's avatar
zjing14 committed
187
188
189
    }
};

190
191
template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
zjing14's avatar
zjing14 committed
192
193
194
195
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
196
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
197
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
198
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
199
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
zjing14's avatar
zjing14 committed
200
201
202
    }
};

203
204
205
// bfp16
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;
zjing14's avatar
zjing14 committed
206

207
208
template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
zjing14's avatar
zjing14 committed
209
{
210
    template <class FloatC>
211
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
212
    {
213
214
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
215
216
217
    }
};

218
219
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;
zjing14's avatar
zjing14 committed
220

221
222
template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
zjing14's avatar
zjing14 committed
223
{
224
    template <class FloatC>
225
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
226
    {
227
228
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
229
230
231
    }
};

232
233
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;
zjing14's avatar
zjing14 committed
234

235
236
template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
zjing14's avatar
zjing14 committed
237
{
238
    template <class FloatC>
239
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
240
    {
241
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
242
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
zjing14's avatar
zjing14 committed
243
244
245
246
    }
};

template <index_t MPerWave, index_t NPerWave>
247
struct intrin_mfma_f32_16x16x8bf16;
zjing14's avatar
zjing14 committed
248
249

template <>
250
struct intrin_mfma_f32_16x16x8bf16<16, 16>
zjing14's avatar
zjing14 committed
251
{
252
    template <class FloatC>
253
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
254
    {
255
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
256
257
258
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
zjing14's avatar
zjing14 committed
259
260

template <index_t MPerWave, index_t NPerWave>
261
struct intrin_mfma_i32_32x32x8i8;
zjing14's avatar
zjing14 committed
262
263

template <>
264
struct intrin_mfma_i32_32x32x8i8<32, 32>
zjing14's avatar
zjing14 committed
265
{
266
267
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
268
    {
269
        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
Chao Liu's avatar
Chao Liu committed
270
271
            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
                                                bit_cast<int32_t>(reg_b),
272
273
274
275
                                                reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                0,
                                                0,
                                                0);
zjing14's avatar
zjing14 committed
276
277
278
    }
};

279
280
281
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

zjing14's avatar
zjing14 committed
282
template <>
283
struct intrin_mfma_i32_16x16x16i8<16, 16>
zjing14's avatar
zjing14 committed
284
{
285
286
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
zjing14's avatar
zjing14 committed
287
    {
288
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
Chao Liu's avatar
Chao Liu committed
289
290
            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
                                                 bit_cast<int32_t>(reg_b),
291
292
293
294
                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                 0,
                                                 0,
                                                 0);
zjing14's avatar
zjing14 committed
295
296
297
    }
};

298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
#ifdef __gfx90a__
        reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
            reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
zjing14's avatar
zjing14 committed
317
318
} // namespace ck
#endif
319
#endif