// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP

#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"

// TODO: deprecate all amd_assembly_outer_product_xxx

namespace ck {

// Bitwise and-or in a single instruction: returns (a & b) | d via v_and_or_b32.
// (Operand roles per the AMDGPU ISA, D = (S0 & S1) | S2 -- confirm for the target ISA.)
//   a, b : operands of the AND
//   d    : value OR-ed into the result
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
    int c;
    asm volatile("v_and_or_b32 %0, %1, %2, %3"
                 : "=v"(c)                      // %0: result
                 : "v"(a), "v"(b), "v"(d));     // %1, %2, %3
    return c;
}

// Packed fp16 fused multiply-add on both halves of a half2_t, issued as a single
// v_pk_fma_f16: each lane computes a * b + c.
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
    half2_t d;
    asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
    return d;
}

// Packed fp16 addition of two half2_t values, issued as a single v_pk_add_f16:
// each lane computes a + b.
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
    half2_t c;
    asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
    return c;
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
//
// fp32 scalar variant: each "inner product" is one multiply-accumulate, issued
// as two v_fmac_f32 instructions. c0/c1 are accumulated in place: outputs %0/%1
// are tied to inputs "0"/"1", so their incoming values are the addends.
// Declared inline because this definition lives in a header that may be
// included by several translation units.
inline __device__ void
amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_fmac_f32 %0, %2, %3 \n \
            v_fmac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)                                  // %0, %1
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));       // %2, %3, %4
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// fp32 scalar variant: four v_fmac_f32 instructions sharing the same `a`.
// c0..c3 are accumulated in place (outputs %0..%3 tied to inputs "0".."3").
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
    asm volatile("\n \
            v_fmac_f32 %0, %4, %5 \n \
            v_fmac_f32 %1, %4, %6 \n \
            v_fmac_f32 %2, %4, %7 \n \
            v_fmac_f32 %3, %4, %8 \n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)  // %0..%3
                 : "v"(a),                                 // %4
                   "v"(b0), "v"(b1), "v"(b2), "v"(b3),     // %5..%8
                   "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
//
// half2_t variant: each 2-wide fp16 dot product is accumulated into an fp32
// result with one v_dot2_f32_f16 (the accumulator is also the 4th operand).
// Target must support v_dot2_f32_f16 -- NOTE(review): confirm supported archs.
// Declared inline because this definition lives in a header.
inline __device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %3, %0\n \
            v_dot2_f32_f16 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)                                  // %0, %1
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));       // %2, %3, %4
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
//
// half4_t variant: every operand is viewed as two half2_t halves, and each
// output receives two v_dot2_f32_f16 accumulations (first halves, then second
// halves).
// Declared inline because this definition lives in a header.
inline __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // Split into half2_t halves through vector_type (as the int8 overloads in
    // this file do) rather than reinterpreting &a as a const half2_t*.
    vector_type<half_t, 4> a_vec{a};
    vector_type<half_t, 4> b0_vec{b0};
    vector_type<half_t, 4> b1_vec{b1};

    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %4, %0\n \
            v_dot2_f32_f16 %1, %2, %6, %1\n \
            v_dot2_f32_f16 %0, %3, %5, %0\n \
            v_dot2_f32_f16 %1, %3, %7, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a_vec.AsType<half2_t>()[I0]),  // %2: first half2 of a
                   "v"(a_vec.AsType<half2_t>()[I1]),  // %3: second half2 of a
                   "v"(b0_vec.AsType<half2_t>()[I0]), // %4: first half2 of b0
                   "v"(b0_vec.AsType<half2_t>()[I1]), // %5: second half2 of b0
                   "v"(b1_vec.AsType<half2_t>()[I0]), // %6: first half2 of b1
                   "v"(b1_vec.AsType<half2_t>()[I1]), // %7: second half2 of b1
                   "0"(c0),
                   "1"(c1));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// half2_t variant: four v_dot2_f32_f16 instructions sharing the same `a`,
// accumulating into fp32 c0..c3 in place.
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(half2_t a,
                                                      half2_t b0,
                                                      half2_t b1,
                                                      half2_t b2,
                                                      half2_t b3,
                                                      float& c0,
                                                      float& c1,
                                                      float& c2,
                                                      float& c3)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %5, %0\n \
            v_dot2_f32_f16 %1, %4, %6, %1\n \
            v_dot2_f32_f16 %2, %4, %7, %2\n \
            v_dot2_f32_f16 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)  // %0..%3
                 : "v"(a),                                 // %4
                   "v"(b0), "v"(b1), "v"(b2), "v"(b3),     // %5..%8
                   "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// half4_t variant: every operand is viewed as two half2_t halves. The asm first
// accumulates the first halves into c0..c3, then the second halves.
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(half4_t a,
                                                      half4_t b0,
                                                      half4_t b1,
                                                      half4_t b2,
                                                      half4_t b3,
                                                      float& c0,
                                                      float& c1,
                                                      float& c2,
                                                      float& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // Split into half2_t halves through vector_type (as the int8 overloads in
    // this file do) rather than reinterpreting the addresses as const half2_t*.
    vector_type<half_t, 4> a_vec{a};
    vector_type<half_t, 4> b0_vec{b0};
    vector_type<half_t, 4> b1_vec{b1};
    vector_type<half_t, 4> b2_vec{b2};
    vector_type<half_t, 4> b3_vec{b3};

    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %6,  %0\n \
            v_dot2_f32_f16 %1, %4, %8,  %1\n \
            v_dot2_f32_f16 %2, %4, %10, %2\n \
            v_dot2_f32_f16 %3, %4, %12, %3\n \
            v_dot2_f32_f16 %0, %5, %7,  %0\n \
            v_dot2_f32_f16 %1, %5, %9,  %1\n \
            v_dot2_f32_f16 %2, %5, %11, %2\n \
            v_dot2_f32_f16 %3, %5, %13, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a_vec.AsType<half2_t>()[I0]),  // %4:  first half2 of a
                   "v"(a_vec.AsType<half2_t>()[I1]),  // %5:  second half2 of a
                   "v"(b0_vec.AsType<half2_t>()[I0]), // %6:  first half2 of b0
                   "v"(b0_vec.AsType<half2_t>()[I1]), // %7:  second half2 of b0
                   "v"(b1_vec.AsType<half2_t>()[I0]), // %8:  first half2 of b1
                   "v"(b1_vec.AsType<half2_t>()[I1]), // %9:  second half2 of b1
                   "v"(b2_vec.AsType<half2_t>()[I0]), // %10: first half2 of b2
                   "v"(b2_vec.AsType<half2_t>()[I1]), // %11: second half2 of b2
                   "v"(b3_vec.AsType<half2_t>()[I0]), // %12: first half2 of b3
                   "v"(b3_vec.AsType<half2_t>()[I1]), // %13: second half2 of b3
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// half8_t variant: every operand is viewed as two half4_t halves and each half
// is accumulated through the half4_t overload (first halves, then second).
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(half8_t a,
                                                      half8_t b0,
                                                      half8_t b1,
                                                      half8_t b2,
                                                      half8_t b3,
                                                      float& c0,
                                                      float& c1,
                                                      float& c2,
                                                      float& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // Split into half4_t halves through vector_type rather than reinterpreting
    // the operand addresses as const half4_t*.
    vector_type<half_t, 8> a_vec{a};
    vector_type<half_t, 8> b0_vec{b0};
    vector_type<half_t, 8> b1_vec{b1};
    vector_type<half_t, 8> b2_vec{b2};
    vector_type<half_t, 8> b3_vec{b3};

    amd_assembly_outer_product_1x4(a_vec.AsType<half4_t>()[I0],
                                   b0_vec.AsType<half4_t>()[I0],
                                   b1_vec.AsType<half4_t>()[I0],
                                   b2_vec.AsType<half4_t>()[I0],
                                   b3_vec.AsType<half4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(a_vec.AsType<half4_t>()[I1],
                                   b0_vec.AsType<half4_t>()[I1],
                                   b1_vec.AsType<half4_t>()[I1],
                                   b2_vec.AsType<half4_t>()[I1],
                                   b3_vec.AsType<half4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// half16_t variant: every operand is viewed as two half8_t halves and each half
// is accumulated through the half8_t overload (first halves, then second).
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(half16_t a,
                                                      half16_t b0,
                                                      half16_t b1,
                                                      half16_t b2,
                                                      half16_t b3,
                                                      float& c0,
                                                      float& c1,
                                                      float& c2,
                                                      float& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // Split into half8_t halves through vector_type rather than reinterpreting
    // the operand addresses as const half8_t*.
    vector_type<half_t, 16> a_vec{a};
    vector_type<half_t, 16> b0_vec{b0};
    vector_type<half_t, 16> b1_vec{b1};
    vector_type<half_t, 16> b2_vec{b2};
    vector_type<half_t, 16> b3_vec{b3};

    amd_assembly_outer_product_1x4(a_vec.AsType<half8_t>()[I0],
                                   b0_vec.AsType<half8_t>()[I0],
                                   b1_vec.AsType<half8_t>()[I0],
                                   b2_vec.AsType<half8_t>()[I0],
                                   b3_vec.AsType<half8_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(a_vec.AsType<half8_t>()[I1],
                                   b0_vec.AsType<half8_t>()[I1],
                                   b1_vec.AsType<half8_t>()[I1],
                                   b2_vec.AsType<half8_t>()[I1],
                                   b3_vec.AsType<half8_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
//
// int8x4_t variant: each 4-wide int8 dot product is accumulated into an int32
// result with one v_dot4_i32_i8. The int8x4_t operands are passed as their
// 32-bit bit pattern. The #else path is the equivalent compiler builtin
// (__builtin_amdgcn_sdot4 with clamp = false).
// Declared inline because this definition lives in a header.
inline __device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %2, %3, %0\n \
            v_dot4_i32_i8 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(bit_cast<int32_t>(a)),   // %2
                   "v"(bit_cast<int32_t>(b0)),  // %3
                   "v"(bit_cast<int32_t>(b1)),  // %4
                   "0"(c0),
                   "1"(c1));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// int8x4_t variant: four v_dot4_i32_i8 instructions sharing the same `a`,
// accumulating into int32 c0..c3 in place. The int8x4_t operands are passed as
// their 32-bit bit pattern. The #else path is the equivalent compiler builtin
// (__builtin_amdgcn_sdot4 with clamp = false).
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                                                      int8x4_t b0,
                                                      int8x4_t b1,
                                                      int8x4_t b2,
                                                      int8x4_t b3,
                                                      int32_t& c0,
                                                      int32_t& c1,
                                                      int32_t& c2,
                                                      int32_t& c3)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %4, %5, %0\n \
            v_dot4_i32_i8 %1, %4, %6, %1\n \
            v_dot4_i32_i8 %2, %4, %7, %2\n \
            v_dot4_i32_i8 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(bit_cast<int32_t>(a)),   // %4
                   "v"(bit_cast<int32_t>(b0)),  // %5
                   "v"(bit_cast<int32_t>(b1)),  // %6
                   "v"(bit_cast<int32_t>(b2)),  // %7
                   "v"(bit_cast<int32_t>(b3)),  // %8
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// int8x8_t variant: every operand is viewed as two int8x4_t chunks and each
// chunk is accumulated through the int8x4_t overload (chunk 0, then chunk 1).
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
                                                      int8x8_t b0,
                                                      int8x8_t b1,
                                                      int8x8_t b2,
                                                      int8x8_t b3,
                                                      int32_t& c0,
                                                      int32_t& c1,
                                                      int32_t& c2,
                                                      int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // Wrap each operand once and reuse the wrapper for both chunks.
    vector_type<int8_t, 8> a_vec{a};
    vector_type<int8_t, 8> b0_vec{b0};
    vector_type<int8_t, 8> b1_vec{b1};
    vector_type<int8_t, 8> b2_vec{b2};
    vector_type<int8_t, 8> b3_vec{b3};

    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I0],
                                   b0_vec.AsType<int8x4_t>()[I0],
                                   b1_vec.AsType<int8x4_t>()[I0],
                                   b2_vec.AsType<int8x4_t>()[I0],
                                   b3_vec.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I1],
                                   b0_vec.AsType<int8x4_t>()[I1],
                                   b1_vec.AsType<int8x4_t>()[I1],
                                   b2_vec.AsType<int8x4_t>()[I1],
                                   b3_vec.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
//
// int8x16_t variant: every operand is viewed as four int8x4_t chunks and each
// chunk is accumulated through the int8x4_t overload, in order 0..3.
// Declared inline because this definition lives in a header.
inline __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                                      int8x16_t b0,
                                                      int8x16_t b1,
                                                      int8x16_t b2,
                                                      int8x16_t b3,
                                                      int32_t& c0,
                                                      int32_t& c1,
                                                      int32_t& c2,
                                                      int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // Wrap each operand once and reuse the wrapper for all four chunks.
    vector_type<int8_t, 16> a_vec{a};
    vector_type<int8_t, 16> b0_vec{b0};
    vector_type<int8_t, 16> b1_vec{b1};
    vector_type<int8_t, 16> b2_vec{b2};
    vector_type<int8_t, 16> b3_vec{b3};

    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I0],
                                   b0_vec.AsType<int8x4_t>()[I0],
                                   b1_vec.AsType<int8x4_t>()[I0],
                                   b2_vec.AsType<int8x4_t>()[I0],
                                   b3_vec.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I1],
                                   b0_vec.AsType<int8x4_t>()[I1],
                                   b1_vec.AsType<int8x4_t>()[I1],
                                   b2_vec.AsType<int8x4_t>()[I1],
                                   b3_vec.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I2],
                                   b0_vec.AsType<int8x4_t>()[I2],
                                   b1_vec.AsType<int8x4_t>()[I2],
                                   b2_vec.AsType<int8x4_t>()[I2],
                                   b3_vec.AsType<int8x4_t>()[I2],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I3],
                                   b0_vec.AsType<int8x4_t>()[I3],
                                   b1_vec.AsType<int8x4_t>()[I3],
                                   b2_vec.AsType<int8x4_t>()[I3],
                                   b3_vec.AsType<int8x4_t>()[I3],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

} // namespace ck
#endif