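// amd_inline_asm.hpp
// Inline-assembly helpers for AMD GPUs that accumulate small outer products using
// v_fmac_f32, v_dot2_f32_f16 and v_dot4_i32_i8.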
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP

#include "data_type.hpp"
namespace ck {

Chao Liu's avatar
Chao Liu committed
8
9
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
10
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
Chao Liu's avatar
Chao Liu committed
11
{
Chao Liu's avatar
Chao Liu committed
12
13
14
15
16
17
    asm volatile("\n \
            v_fmac_f32 %0, %2, %3 \n \
            v_fmac_f32 %1, %2, %4 \n \
            "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
Chao Liu's avatar
Chao Liu committed
18
19
}

Chao Liu's avatar
Chao Liu committed
20
21
22
23
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
24
__device__ void amd_assembly_outer_product_1x4(
25
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
Chao Liu's avatar
Chao Liu committed
26
{
Chao Liu's avatar
Chao Liu committed
27
28
29
30
31
32
33
34
    asm volatile("\n \
            v_fmac_f32 %0, %4, %5 \n \
            v_fmac_f32 %1, %4, %6 \n \
            v_fmac_f32 %2, %4, %7 \n \
            v_fmac_f32 %3, %4, %8 \n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
Jing Zhang's avatar
Jing Zhang committed
35
36
}

Chao Liu's avatar
Chao Liu committed
37
38
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
39
40
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
Chao Liu's avatar
Chao Liu committed
41
{
42
    asm volatile("\n \
43
44
            v_dot2_f32_f16 %0, %2, %3, %0\n \
            v_dot2_f32_f16 %1, %2, %4, %1\n \
45
            "
Chao Liu's avatar
Chao Liu committed
46
47
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
Jing Zhang's avatar
Jing Zhang committed
48
49
}

Chao Liu's avatar
Chao Liu committed
50
51
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
52
53
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
Chao Liu's avatar
Chao Liu committed
54
{
55
    // TODO remove pointer casting
56
57
58
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
Chao Liu's avatar
Chao Liu committed
59

60
61
    // do dot2 two times
    asm volatile("\n \
62
63
64
65
            v_dot2_f32_f16 %0, %2, %4, %0\n \
            v_dot2_f32_f16 %1, %2, %6, %1\n \
            v_dot2_f32_f16 %0, %3, %5, %0\n \
            v_dot2_f32_f16 %1, %3, %7, %1\n \
66
            "
Chao Liu's avatar
Chao Liu committed
67
                 : "=v"(c0), "=v"(c1)
68
                 : "v"(p_a_half2[0]),
Chao Liu's avatar
Chao Liu committed
69
                   "v"(p_a_half2[1]),
70
71
72
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
Chao Liu's avatar
Chao Liu committed
73
                   "v"(p_b1_half2[1]),
74
                   "0"(c0),
Chao Liu's avatar
Chao Liu committed
75
                   "1"(c1));
Jing Zhang's avatar
Jing Zhang committed
76
77
}

Chao Liu's avatar
Chao Liu committed
78
79
80
81
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
82
83
84
85
86
87
88
89
90
__device__ void amd_assembly_outer_product_1x4(half2_t a,
                                               half2_t b0,
                                               half2_t b1,
                                               half2_t b2,
                                               half2_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
Jing Zhang's avatar
Jing Zhang committed
91
{
92
    asm volatile("\n \
93
94
95
96
            v_dot2_f32_f16 %0, %4, %5, %0\n \
            v_dot2_f32_f16 %1, %4, %6, %1\n \
            v_dot2_f32_f16 %2, %4, %7, %2\n \
            v_dot2_f32_f16 %3, %4, %8, %3\n \
97
            "
Chao Liu's avatar
Chao Liu committed
98
99
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
Jing Zhang's avatar
Jing Zhang committed
100
101
}

Chao Liu's avatar
Chao Liu committed
102
103
104
105
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
106
107
108
109
110
111
112
113
114
__device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               half4_t b0,
                                               half4_t b1,
                                               half4_t b2,
                                               half4_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
Jing Zhang's avatar
Jing Zhang committed
115
{
116
    // TODO remove pointer casting
117
118
119
120
121
    const half2_t* p_a_half2  = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
Jing Zhang's avatar
Jing Zhang committed
122

123
124
    // do dot2 two times
    asm volatile("\n \
125
126
127
128
129
130
131
132
            v_dot2_f32_f16 %0, %4, %6,  %0\n \
            v_dot2_f32_f16 %1, %4, %8,  %1\n \
            v_dot2_f32_f16 %2, %4, %10, %2\n \
            v_dot2_f32_f16 %3, %4, %12, %3\n \
            v_dot2_f32_f16 %0, %5, %7,  %0\n \
            v_dot2_f32_f16 %1, %5, %9,  %1\n \
            v_dot2_f32_f16 %2, %5, %11, %2\n \
            v_dot2_f32_f16 %3, %5, %13, %3\n \
Jing Zhang's avatar
Jing Zhang committed
133
            "
Chao Liu's avatar
Chao Liu committed
134
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
135
                 : "v"(p_a_half2[0]),
Chao Liu's avatar
Chao Liu committed
136
                   "v"(p_a_half2[1]),
137
138
139
                   "v"(p_b0_half2[0]),
                   "v"(p_b0_half2[1]),
                   "v"(p_b1_half2[0]),
Chao Liu's avatar
Chao Liu committed
140
                   "v"(p_b1_half2[1]),
141
142
143
                   "v"(p_b2_half2[0]),
                   "v"(p_b2_half2[1]),
                   "v"(p_b3_half2[0]),
Chao Liu's avatar
Chao Liu committed
144
                   "v"(p_b3_half2[1]),
145
146
147
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
Chao Liu's avatar
Chao Liu committed
148
149
150
                   "3"(c3));
}

151
152
153
154
155
156
157
158
159
160
161
__device__ void amd_assembly_outer_product_1x4(half8_t a,
                                               half8_t b0,
                                               half8_t b1,
                                               half8_t b2,
                                               half8_t b3,
                                               float& c0,
                                               float& c1,
                                               float& c2,
                                               float& c3)
{

162
    // TODO remove pointer casting
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    const half4_t* p_a_half4  = reinterpret_cast<const half4_t*>(&a);
    const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
    const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
    const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
    const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);

    amd_assembly_outer_product_1x4(
        p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// for 16-wide fp16 operands with fp32 accumulators.
// Each operand is split into two half8_t halves, and each half is forwarded to the
// half8_t overload of this function. Every accumulator therefore receives the first-half
// contribution followed by the second-half contribution.
// Declared inline: it is defined in a header, so several translation units may include
// it without violating the one-definition rule.
inline __device__ void amd_assembly_outer_product_1x4(half16_t a,
                                                      half16_t b0,
                                                      half16_t b1,
                                                      half16_t b2,
                                                      half16_t b3,
                                                      float& c0,
                                                      float& c1,
                                                      float& c2,
                                                      float& c3)
{
    // View each half16_t as two consecutive half8_t values.
    // TODO remove pointer casting (this is type punning between vector types and relies
    // on the compiler tolerating it)
    const half8_t* p_a_half8  = reinterpret_cast<const half8_t*>(&a);
    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
    const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
    const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);

    amd_assembly_outer_product_1x4(
        p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
}

Chao Liu's avatar
Chao Liu committed
200
201
202
203
204
205
206
207
208
209
210
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %2, %3, %0\n \
            v_dot4_i32_i8 %1, %2, %4, %1\n \
            "
                 : "=v"(c0), "=v"(c1)
211
212
213
214
215
                 : "v"(as_type<int32_t>(a)),
                   "v"(as_type<int32_t>(b0)),
                   "v"(as_type<int32_t>(b1)),
                   "0"(c0),
                   "1"(c1));
Chao Liu's avatar
Chao Liu committed
216
#else
217
218
    c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
Chao Liu's avatar
Chao Liu committed
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#endif
}

// Outer-product-style update with packed int8 operands and int32 accumulators:
//   c0 += inner_product(a, b0)   // sum over the 4 signed int8 lanes
//   c1 += inner_product(a, b1)
//   c2 += inner_product(a, b2)
//   c3 += inner_product(a, b3)
// One v_dot4_i32_i8 per accumulator. The int8x4_t operands are handed to the asm as
// as_type<int32_t>(...), i.e. the four packed bytes reinterpreted as one 32-bit value.
//
// Operand map: %0..%3 = c0..c3, %4 = a, %5..%8 = b0..b3. The "0".."3" tied inputs make
// the update in place.
// The #else branch is an alternative implementation using __builtin_amdgcn_sdot4 with
// clamping disabled. It is currently compiled out by `#if 1`.
//
// NOTE(review): requires a target with v_dot4_i32_i8 (e.g. gfx906/gfx908); confirm.
// Declared inline: it is defined in a header, so several translation units may include
// it without violating the one-definition rule.
inline __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                                                      int8x4_t b0,
                                                      int8x4_t b1,
                                                      int8x4_t b2,
                                                      int8x4_t b3,
                                                      int32_t& c0,
                                                      int32_t& c1,
                                                      int32_t& c2,
                                                      int32_t& c3)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %4, %5, %0\n \
            v_dot4_i32_i8 %1, %4, %6, %1\n \
            v_dot4_i32_i8 %2, %4, %7, %2\n \
            v_dot4_i32_i8 %3, %4, %8, %3\n \
            "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(as_type<int32_t>(a)),
                   "v"(as_type<int32_t>(b0)),
                   "v"(as_type<int32_t>(b1)),
                   "v"(as_type<int32_t>(b2)),
                   "v"(as_type<int32_t>(b3)),
                   "0"(c0),
                   "1"(c1),
                   "2"(c2),
                   "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
#endif
}
260

261
262
263
264
265
266
267
268
269
270
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
                                               int8x8_t b0,
                                               int8x8_t b1,
                                               int8x8_t b2,
                                               int8x8_t b3,
                                               int32_t& c0,
                                               int32_t& c1,
                                               int32_t& c2,
                                               int32_t& c3)
{
271
272
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
273

274
275
276
277
278
    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
279
280
281
282
283
                                   c0,
                                   c1,
                                   c2,
                                   c3);

284
285
286
287
288
    amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
                                   vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
// for 16-wide int8 operands with int32 accumulators.
// Each operand is split into four int8x4_t words, and each word is forwarded to the
// int8x4_t overload of this function. Every accumulator therefore receives four
// v_dot4_i32_i8 updates in word order 0, 1, 2, 3.
// Declared inline: it is defined in a header, so several translation units may include
// it without violating the one-definition rule.
inline __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                                                      int8x16_t b0,
                                                      int8x16_t b1,
                                                      int8x16_t b2,
                                                      int8x16_t b3,
                                                      int32_t& c0,
                                                      int32_t& c1,
                                                      int32_t& c2,
                                                      int32_t& c3)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // Wrap each operand once so its int8x4_t words can be read by compile-time index.
    // This avoids rebuilding a vector_type temporary at every use.
    vector_type<int8_t, 16> a_vec{a};
    vector_type<int8_t, 16> b0_vec{b0};
    vector_type<int8_t, 16> b1_vec{b1};
    vector_type<int8_t, 16> b2_vec{b2};
    vector_type<int8_t, 16> b3_vec{b3};

    // word 0 of every operand
    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I0],
                                   b0_vec.AsType<int8x4_t>()[I0],
                                   b1_vec.AsType<int8x4_t>()[I0],
                                   b2_vec.AsType<int8x4_t>()[I0],
                                   b3_vec.AsType<int8x4_t>()[I0],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    // word 1 of every operand
    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I1],
                                   b0_vec.AsType<int8x4_t>()[I1],
                                   b1_vec.AsType<int8x4_t>()[I1],
                                   b2_vec.AsType<int8x4_t>()[I1],
                                   b3_vec.AsType<int8x4_t>()[I1],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    // word 2 of every operand
    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I2],
                                   b0_vec.AsType<int8x4_t>()[I2],
                                   b1_vec.AsType<int8x4_t>()[I2],
                                   b2_vec.AsType<int8x4_t>()[I2],
                                   b3_vec.AsType<int8x4_t>()[I2],
                                   c0,
                                   c1,
                                   c2,
                                   c3);

    // word 3 of every operand
    amd_assembly_outer_product_1x4(a_vec.AsType<int8x4_t>()[I3],
                                   b0_vec.AsType<int8x4_t>()[I3],
                                   b1_vec.AsType<int8x4_t>()[I3],
                                   b2_vec.AsType<int8x4_t>()[I3],
                                   b3_vec.AsType<int8x4_t>()[I3],
                                   c0,
                                   c1,
                                   c2,
                                   c3);
}

} // namespace ck
#endif