mma_earlycuda.cuh 6.96 KB
Newer Older
sxtyzhangzk's avatar
sxtyzhangzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#pragma once

#include <cstdint>
#include <type_traits>
#include "common.h"

// CUDA 12.4 and earlier do not support the "C" constraint in inline assembly :(
// so, for now, each PTX type combination is written as an explicit template specialization.


namespace nunchaku::kernels {


// Tag types for the element types used by the MMA wrappers in this file.
// Each tag's `value` is the type name exactly as it is spelled in the PTX
// mnemonics below (e.g. "bf16", "s4"). The tags are used as template
// arguments, e.g. mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>.
namespace mma_helper {
    struct f32 {
        static constexpr const char value[] = "f32";
    };
    struct f16 {
        static constexpr const char value[] = "f16";
    };
    struct bf16 {
        static constexpr const char value[] = "bf16";
    };
    struct s32 {
        static constexpr const char value[] = "s32";
    };
    struct s4 {
        static constexpr const char value[] = "s4";
    };
    struct u4 {
        static constexpr const char value[] = "u4";
    };

    // Selects bf16 when is_bf16 is true, f16 otherwise.
    template<bool is_bf16>
    using f16bf16 = std::conditional_t<is_bf16, bf16, f16>;
    // Selects u4 when is_unsigned is true, s4 otherwise.
    template<bool is_unsigned>
    using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
}  // namespace mma_helper

// D = A * B + C for one m16n8k16 tile with f16 inputs and an f16 accumulator
// (PTX mma.sync ... .f16.f16.f16.f16).
//   a   - the four 32-bit A-fragment registers passed to mma.sync
//   b   - the two 32-bit B-fragment registers
//   c   - the two 32-bit C-fragment registers
// Returns the two 32-bit D-fragment registers.
__device__ __forceinline__
static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
    uint2 out;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
    // Fallback, also taken when __CUDA_ARCH__ is undefined: emulate K=16 with
    // two chained m16n8k8 MMAs. The first consumes a.x/a.y and b.x on top of c;
    // the second consumes a.z/a.w and b.y on top of the first's result.
    // NOTE(review): relies on the PTX fragment layout placing k=0..7 in
    // a.x/a.y/b.x and k=8..15 in a.z/a.w/b.y; m16n8k8 itself is sm_75+.
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1;"
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{tmp0,  tmp1},"
        "{%2,  %3},"
        "{%6},"
        "{%8,  %9};\n"
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0,  %1},"
        "{%4,  %5},"
        "{%7},"
        "{tmp0,  tmp1};"
        "}\n"
        : "=r"(out.x), "=r"(out.y)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y)
    );
#else
    // __CUDA_ARCH__ >= 800: a single native m16n8k16 instruction.
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0,  %1},"
        "{%2,  %3,  %4,  %5},"
        "{%6,  %7},"
        "{%8,  %9};\n"
        : "=r"(out.x), "=r"(out.y)
        : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
          "r"(b.x), "r"(b.y),
          "r"(c.x), "r"(c.y)
    );
#endif
    return out;
}

// D = A * B + C for one m16n8k16 tile with 16-bit float inputs and an f32
// accumulator. is_bf16 selects bf16 (true) or f16 (false) inputs.
//   a - four 32-bit A-fragment registers, b - two B-fragment registers,
//   c - four C-fragment registers; the four D registers are returned.
// f32 values travel as raw 32-bit register bits inside uint4 ("r" constraints).
// The primary template is deleted: only the explicit specializations below can be called.
template<bool is_bf16>
__device__ __forceinline__
static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2 b, uint4 c) = delete;

// bf16 inputs. This specialization exists only when __CUDA_ARCH__ >= 800.
// Under any other condition no <true> specialization is declared, so
// mma_m16n8k16_f32f16f16f32<true> refers to the deleted primary template.
// No pre-sm_80 fallback is provided for bf16.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,  %1,  %2,  %3},"
        "{%4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10,  %11,  %12,  %13};\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
    );
    return d;
}
#endif

// f16 inputs. Native m16n8k16 when __CUDA_ARCH__ >= 800; otherwise (including
// when __CUDA_ARCH__ is undefined) two chained m16n8k8 instructions.
template<>
__device__ __forceinline__
uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,  %1,  %2,  %3},"
        "{%4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10,  %11,  %12,  %13};\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
    );
#else
    // Split K=16 into two m16n8k8 MMAs chained through tmp0..tmp3:
    //   tmp = {a.x, a.y} (%4,%5) * b.x (%8) + c     (%10..%13)
    //   d   = {a.z, a.w} (%6,%7) * b.y (%9) + tmp
    // NOTE(review): relies on the PTX fragment layout placing k=0..7 in
    // a.x/a.y/b.x and k=8..15 in a.z/a.w/b.y; m16n8k8 itself is sm_75+.
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{tmp0,  tmp1,  tmp2,  tmp3},"
        "{%4,  %5},"
        "{%8},"
        "{%10,  %11,  %12,  %13};\n"
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0,  %1,  %2,  %3},"
        "{%6,  %7},"
        "{%9},"
        "{tmp0,  tmp1,  tmp2,  tmp3};"
        "}\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w)
    );
#endif
    return d;
}

// D = A * B + C for one m16n8 tile of 4-bit integer inputs with an s32
// accumulator. AType/BType are mma_helper tags selecting the PTX input types.
// The primary template is deleted, so only explicitly specialized combinations
// are usable (this file specializes <s4, s4> and <u4, s4>).
//   a - four 32-bit A-fragment registers, b - two B-fragment registers,
//   c - four s32 C-fragment registers; the four D registers are returned.
template<typename AType, typename BType>
__device__ __forceinline__
static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b, uint4 c) = delete;

// Signed 4-bit A x signed 4-bit B; each call covers K = 64.
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    // K for the whole call. It is spliced into the mnemonic through the "n"
    // immediate operand %14 (K on the native path, K / 2 on the fallback).
    static constexpr int K = 64;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // Native path: a single m16n8k64 instruction.
    asm volatile(
        "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
        "{%0,  %1,  %2,  %3},"
        "{%4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10,  %11,  %12,  %13};\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
        "n"(K)
    );
#else
    // Fallback (also taken when __CUDA_ARCH__ is undefined): four m8n8k32
    // MMAs (%14 = K / 2 = 32), chained through tmp0..tmp3:
    //   {tmp0, tmp1} = a.x * b.x + {c.x, c.y}
    //   {tmp2, tmp3} = a.y * b.x + {c.z, c.w}
    //   {d.x,  d.y } = a.z * b.y + {tmp0, tmp1}
    //   {d.z,  d.w } = a.w * b.y + {tmp2, tmp3}
    // NOTE(review): relies on the PTX fragment layout (a.x/a.z and a.y/a.w
    // covering the two 8-row halves; a.x/a.y/b.x holding the first K/2).
    // Confirm against the PTX ISA. m8n8k32 itself is sm_75+.
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{tmp0, tmp1},"
        "{%4},"
        "{%8},"
        "{%10,  %11};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{tmp2, tmp3},"
        "{%5},"
        "{%8},"
        "{%12,  %13};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{%0,  %1},"
        "{%6},"
        "{%9},"
        "{tmp0, tmp1};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
        "{%2,  %3},"
        "{%7},"
        "{%9},"
        "{tmp2, tmp3};\n"
        "}\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
        "n"(K / 2)
    );
#endif
    return d;
}

// Unsigned 4-bit A x signed 4-bit B, s32 accumulator; each call covers K = 64.
// Identical in structure to the <s4, s4> specialization; only the PTX input
// type suffixes (.u4.s4) differ.
//   a - four 32-bit A-fragment registers, b - two B-fragment registers,
//   c - four s32 C-fragment registers; the four D registers are returned.
template<>
__device__ __forceinline__
uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helper::s4>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    // K for the whole call. It is spliced into the mnemonic through the "n"
    // immediate operand %14 (K on the native path, K / 2 on the fallback).
    static constexpr int K = 64;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // Native path: a single m16n8k64 instruction.
    asm volatile(
        "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
        "{%0,  %1,  %2,  %3},"
        "{%4,  %5,  %6,  %7},"
        "{%8,  %9},"
        "{%10,  %11,  %12,  %13};\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
        "n"(K)
    );
#else
    // Fallback (also taken when __CUDA_ARCH__ is undefined): four m8n8k32
    // MMAs (%14 = K / 2 = 32), chained through tmp0..tmp3:
    //   {tmp0, tmp1} = a.x * b.x + {c.x, c.y}
    //   {tmp2, tmp3} = a.y * b.x + {c.z, c.w}
    //   {d.x,  d.y } = a.z * b.y + {tmp0, tmp1}
    //   {d.z,  d.w } = a.w * b.y + {tmp2, tmp3}
    // NOTE(review): relies on the PTX fragment layout (a.x/a.z and a.y/a.w
    // covering the two 8-row halves; a.x/a.y/b.x holding the first K/2).
    // Confirm against the PTX ISA. m8n8k32 itself is sm_75+.
    asm volatile(
        "{"
        ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{tmp0, tmp1},"
        "{%4},"
        "{%8},"
        "{%10,  %11};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{tmp2, tmp3},"
        "{%5},"
        "{%8},"
        "{%12,  %13};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{%0,  %1},"
        "{%6},"
        "{%9},"
        "{tmp0, tmp1};\n"
        "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
        "{%2,  %3},"
        "{%7},"
        "{%9},"
        "{tmp2, tmp3};\n"
        "}\n"
        : 
        "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
        : 
        "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w),
        "r"(b.x), "r"(b.y),
        "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
        "n"(K / 2)
    );
#endif
    return d;
}
    

};  // namespace nunchaku::kernels