// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP

#include "ck/utility/amd_inline_asm.hpp"
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {

12
13
/********************************WAVE32 MODE***********************************************/

14
15
16
17
18
19
20
21
22
23
// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
24
25
        // * Inline assembly need to elimate the duplicated data load, compiler won't help you
        // delete them.
26
27
        // amd_assembly_wmma_f32_16x16x16_f16_w32(
        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
Haocong WANG's avatar
Haocong WANG committed
28
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
29
30
31
32
33
34
35
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
36
37
38
39
40
41
42
43
44
45
46
47
48
    }
};

// src: bf16, dst: fp32
// Wave32 WMMA wrapper with bf16 A/B fragments and an fp32 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; any other shape names an
// incomplete type and fails to compile.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32.
    //   reg_a, reg_b : per-lane bf16 operand fragments (bhalf16_t)
    //   reg_c        : accumulator; element Number<0> of its float8_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: fp16, dst: fp16
// Wave32 WMMA wrapper with fp16 A/B fragments and an fp16 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; Opsel is forwarded as-is to
// the builtin's final argument.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_f16_16x16x16_f16_w32.
    //   reg_a, reg_b : per-lane fp16 operand fragments (half16_t)
    //   reg_c        : accumulator; element Number<0> of its half16_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: bf16
// Wave32 WMMA wrapper with bf16 A/B fragments and a bf16 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; Opsel is forwarded as-is to
// the builtin's final argument.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32.
    //   reg_a, reg_b : per-lane bf16 operand fragments (bhalf16_t)
    //   reg_c        : accumulator; element Number<0> of its bhalf16_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: iu8, dst: i32
// Wave32 WMMA wrapper with 8-bit integer A/B fragments and an int32 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; neg_a, neg_b and clamp are
// forwarded as-is to the builtin.
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32.
    //   reg_a, reg_b : per-lane int8x16_t fragments, reinterpreted (bit_cast) as int32x4_t
    //                  because that is the operand type the builtin is called with here
    //   reg_c        : accumulator; element Number<0> of its int32x8_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // NOTE(review): for the IU8 variant the neg_a/neg_b flags are presumably interpreted by
    // the hardware as signed/unsigned selectors for A/B rather than as negation -- confirm
    // against the RDNA3 ISA before relying on the names.
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

/********************************WAVE64 MODE***********************************************/

// src: fp16, dst: fp32
// Wave64 WMMA wrapper with fp16 A/B fragments and an fp32 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; any other shape names an
// incomplete type and fails to compile.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;

template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_f32_16x16x16_f16_w64.
    //   reg_a, reg_b : per-lane fp16 operand fragments (half16_t)
    //   reg_c        : accumulator; element Number<0> of its float4_t view (half the width
    //                  of the wave32 variant's float8_t) is passed to the builtin as the
    //                  input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: fp32
// Wave64 WMMA wrapper with bf16 A/B fragments and an fp32 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; any other shape names an
// incomplete type and fails to compile.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64.
    //   reg_a, reg_b : per-lane bf16 operand fragments (bhalf16_t)
    //   reg_c        : accumulator; element Number<0> of its float4_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: fp16, dst: fp16
// Wave64 WMMA wrapper with fp16 A/B fragments and an fp16 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; Opsel is forwarded as-is to
// the builtin's final argument.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_f16_16x16x16_f16_w64.
    //   reg_a, reg_b : per-lane fp16 operand fragments (half16_t)
    //   reg_c        : accumulator; element Number<0> of its half8_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: bf16
// Wave64 WMMA wrapper with bf16 A/B fragments and a bf16 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; Opsel is forwarded as-is to
// the builtin's final argument.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64.
    //   reg_a, reg_b : per-lane bf16 operand fragments (bhalf16_t)
    //   reg_c        : accumulator; element Number<0> of its bhalf8_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: iu8, dst: i32
// Wave64 WMMA wrapper with 8-bit integer A/B fragments and an int32 accumulator.
// Only the <16, 16> (MPerWave, NPerWave) shape is specialized; neg_a, neg_b and clamp are
// forwarded as-is to the builtin.
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
    // Accumulates reg_a x reg_b into reg_c via __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64.
    //   reg_a, reg_b : per-lane int8x16_t fragments, reinterpreted (bit_cast) as int32x4_t
    //                  because that is the operand type the builtin is called with here
    //   reg_c        : accumulator; element Number<0> of its int32x4_t view is passed to the
    //                  builtin as the input accumulator and overwritten with the result
    // NOTE(review): for the IU8 variant the neg_a/neg_b flags are presumably interpreted by
    // the hardware as signed/unsigned selectors for A/B rather than as negation -- confirm
    // against the RDNA3 ISA before relying on the names.
    // On targets other than gfx1100-gfx1103 no WMMA is issued and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x4_t>()[Number<0>{}],
                clamp);
#else
        // Builtin not available for this target: consume the arguments via `ignore`
        // (presumably to avoid unused-parameter warnings); reg_c is not modified.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

} // namespace ck
#endif