amd_wmma.hpp 8.21 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
4
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
5
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
6
7
8
9

#ifndef CK_AMD_WMMA_HPP
#define CK_AMD_WMMA_HPP

10
#include "ck/utility/amd_inline_asm.hpp"
11
12
13
14
#include "data_type.hpp"
// TODO: Add arch limitation
namespace ck {

15
16
/********************************WAVE32 MODE***********************************************/

17
18
19
20
21
22
23
24
25
26
// src: fp16, dst: fp32
// Wave32 16x16x16 WMMA wrapper: accumulates A(f16) * B(f16) into an f32 accumulator.
// Only the <16, 16> tile is specialized; the primary template is declared but never
// defined, so any other MPerWave/NPerWave combination is an incomplete type.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32;

template <>
struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
{
    // reg_a / reg_b : this lane's 16 f16 operand values.
    // reg_c         : accumulator; must provide AsType<float8_t>(). Element 0 of that
    //                 view is passed to the builtin as the input accumulator and is
    //                 overwritten with the result.
    // On targets other than gfx1100/gfx1101/gfx1102 this is a no-op: the arguments are
    // only consumed through `ignore` (to silence unused warnings) and reg_c is unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // * Inline assembly is needed to eliminate the duplicated data loads; the
        //   compiler will not remove them on its own.
        // amd_assembly_wmma_f32_16x16x16_f16_w32(
        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: fp32
// Wave32 16x16x16 WMMA wrapper: accumulates A(bf16) * B(bf16) into an f32 accumulator.
// Only the <16, 16> tile is specialized; other MPerWave/NPerWave combinations refer to
// the undefined primary template.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
{
    // reg_a / reg_b : this lane's 16 bf16 operand values.
    // reg_c         : accumulator; must provide AsType<float8_t>(). Element 0 of that
    //                 view is both the input accumulator and the output destination.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: fp16, dst: fp16
// Wave32 16x16x16 WMMA wrapper with an f16 accumulator. Opsel is forwarded to the
// builtin as its op-select immediate. Only the <16, 16> tile is specialized.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
    // reg_a / reg_b : this lane's 16 f16 operand values.
    // reg_c         : accumulator; must provide AsType<half16_t>(). Element 0 of that
    //                 view is the input accumulator and receives the result.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: bf16
// Wave32 16x16x16 WMMA wrapper with a bf16 accumulator. Opsel is forwarded to the
// builtin as its op-select immediate. Only the <16, 16> tile is specialized.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
    // reg_a / reg_b : this lane's 16 bf16 operand values.
    // reg_c         : accumulator; must provide AsType<bhalf16_t>(). Element 0 of that
    //                 view is the input accumulator and receives the result.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: iu8, dst: i32
// Wave32 16x16x16 integer WMMA wrapper: 8-bit A/B operands, i32 accumulator.
// neg_a, neg_b and clamp are forwarded to the builtin as compile-time flags.
// NOTE(review): for the IU8 variant the RDNA3 ISA appears to use the "neg" bits to mark
// A/B as signed (true) vs unsigned (false), and clamp to saturate the i32 result —
// confirm against the ISA reference before relying on this.
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
    // reg_a / reg_b : this lane's 16 int8 values, repacked (bit_cast) into the
    //                 int32x4_t operand type the builtin takes.
    // reg_c         : accumulator; must provide AsType<int32x8_t>(). Element 0 of that
    //                 view is the input accumulator and receives the result.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

140
141
142
143
144
145
146
147
148
149
150
/********************************WAVE64 MODE***********************************************/

// src: fp16, dst: fp32
// Wave64 16x16x16 WMMA wrapper: accumulates A(f16) * B(f16) into an f32 accumulator.
// Only the <16, 16> tile is specialized; other MPerWave/NPerWave combinations refer to
// the undefined primary template.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w64;

template <>
struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
{
    // reg_a / reg_b : this lane's 16 f16 operand values.
    // reg_c         : accumulator; must provide AsType<float4_t>() (the wave64 form uses
    //                 a 4-wide accumulator per lane, versus 8 in the wave32 variant).
    //                 Element 0 of that view is the input accumulator and the output.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: fp32
// Wave64 16x16x16 WMMA wrapper: accumulates A(bf16) * B(bf16) into an f32 accumulator.
// Only the <16, 16> tile is specialized.
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w64;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
{
    // reg_a / reg_b : this lane's 16 bf16 operand values.
    // reg_c         : accumulator; must provide AsType<float4_t>(). Element 0 of that
    //                 view is both the input accumulator and the output destination.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: fp16, dst: fp16
// Wave64 16x16x16 WMMA wrapper with an f16 accumulator. Opsel is forwarded to the
// builtin as its op-select immediate. Only the <16, 16> tile is specialized.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64;

template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
{
    // reg_a / reg_b : this lane's 16 f16 operand values.
    // reg_c         : accumulator; must provide AsType<half8_t>(). Element 0 of that
    //                 view is the input accumulator and receives the result.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
            reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: bf16
// Wave64 16x16x16 WMMA wrapper with a bf16 accumulator. Opsel is forwarded to the
// builtin as its op-select immediate. Only the <16, 16> tile is specialized.
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64;

template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
{
    // reg_a / reg_b : this lane's 16 bf16 operand values.
    // reg_c         : accumulator; must provide AsType<bhalf8_t>(). Element 0 of that
    //                 view is the input accumulator and receives the result.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
                reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: iu8, dst: i32
// Wave64 16x16x16 integer WMMA wrapper: 8-bit A/B operands, i32 accumulator.
// neg_a, neg_b and clamp are forwarded to the builtin as compile-time flags.
// NOTE(review): for the IU8 variant the RDNA3 ISA appears to use the "neg" bits to mark
// A/B as signed (true) vs unsigned (false), and clamp to saturate the i32 result —
// confirm against the ISA reference before relying on this.
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
{
    // reg_a / reg_b : this lane's 16 int8 values, repacked (bit_cast) into the
    //                 int32x4_t operand type the builtin takes.
    // reg_c         : accumulator; must provide AsType<int32x4_t>(). Element 0 of that
    //                 view is the input accumulator and receives the result.
    // Outside gfx1100/gfx1101/gfx1102 the call does nothing and reg_c is left unchanged.
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
                neg_a,
                bit_cast<int32x4_t>(reg_a),
                neg_b,
                bit_cast<int32x4_t>(reg_b),
                reg_c.template AsType<int32x4_t>()[Number<0>{}],
                clamp);
#else
        // Consume the arguments only to silence unused-parameter warnings.
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

260
261
} // namespace ck
#endif
Umang Yadav's avatar
Umang Yadav committed
262
263

#pragma clang diagnostic pop