// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/macro.h"
#include <cassert>
#include <cstdint>
#include <cuda_fp16.h>
#include <type_traits>

namespace turbomind {

#ifndef TURBOMIND_S4_DEQUANT_USE_FMA
#define TURBOMIND_S4_DEQUANT_USE_FMA 0
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#define TURBOMIND_ARCH_SM75 1
#else
#define TURBOMIND_ARCH_SM75 0
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define TURBOMIND_ARCH_SM80 1
#else
#define TURBOMIND_ARCH_SM80 0
#endif

constexpr int WARP_SIZE = 32;

#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
#define PRAGMA_UNROLL _Pragma("unroll")
#define PRAGMA_NO_UNROLL _Pragma("unroll 1")
#else
#define PRAGMA_UNROLL #pragma unroll
#define PRAGMA_NO_UNROLL #pragma unroll 1
#endif
#else
#define PRAGMA_UNROLL
#define PRAGMA_NO_UNROLL
#endif

// Modified from NVIDIA FasterTransformer:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// Modified from llm-awq https://github.com/mit-han-lab/llm-awq/blob/main/awq/kernels/csrc/quantization/dequantize.cuh
__inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
    uint4 result;

    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is
    // thanks to the register packing format and the fact that we force our
    // integers to be unsigned, and account for this in the fp16 subtractions. In
    // addition, I exploit the fact that sub and fma have the same throughput in
    // order to convert elt_23 and elt_67 to fp16 without having to shift them to
    // the bottom bits before hand.

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
    // dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[0])
    //     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[1])
    //     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[2])
    //     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    // asm("lop3.b32 %0, %1, %2, %3, %4;\n"
    //     : "=r"(h[3])
    //     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    printf("=========common.h 86\n");
    // I use inline PTX below because I am not sure if the compiler will emit
    // float2half instructions if I use the half2 ctor. In this case, I chose
    // performance reliability over code readability.

    // This is the half2 {1032, 1032} represented as an integer.
    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // This is the half2 {-72, -72} represented as an integer.
    // static constexpr uint32_t NEG_72 = 0xd480d480;
    // Haotian: Let's use {-64, -64}.
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, we construct the output numbers.
    // Convert elt_01
    // asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // // Convert elt_23
    // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // // Convert elt_45
    // asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // // Convert elt_67
    // asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
}
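
// Illustrative usage sketch (an addition for exposition; `dequant_s4x8_example`
// is hypothetical): unpack one 32-bit word holding 8 interleaved unsigned 4-bit
// values into four half2 registers. Following the extraction comments above,
// r.x holds elt_01, r.y elt_23, r.z elt_45 and r.w elt_67, assuming the source
// word was pre-interleaved to match. Worked example: a bottom nibble of 5 is
// first built as the fp16 pattern 0x6405 (= 1029.0), and the final subtraction
// of 1024 recovers 5.0.
__inline__ __device__ void dequant_s4x8_example(const uint32_t& packed, half2 (&out)[4])
{
    const uint4 r = dequantize_s4_to_fp16x2(packed);
    out[0] = (const half2&)r.x;  // elt_01
    out[1] = (const half2&)r.y;  // elt_23
    out[2] = (const half2&)r.z;  // elt_45
    out[3] = (const half2&)r.w;  // elt_67
}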

__inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
{
    uint4 result;

    uint32_t*       h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const& i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut      = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOT_MASK    = 0x000f000f;
    static constexpr uint32_t TOP_MASK    = 0x00f000f0;
    static constexpr uint32_t MAGIC_NUM_0 = 0x64006400;        // `1024`
    static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`
    static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
    // dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    printf("=========common.h 133\n");
    // if (0) {  // 1024 & 64
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
    //     asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
    // }
    // else {  //  64 only, trade 4 hfma2 with 2 shifts
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
    //     asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
    //     h[0] <<= 4;
    //     h[2] <<= 4;
    //     // we don't need to subtract the magic nums because zeros will go through the same dequant function
    //     // and carry the same magic constant, the magic num will be canceled out after subtracting zeros
    // }

    return result;
}
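
// Worked example of the bias cancellation above (an illustrative addition): in
// the `64 only` path every nibble n dequantizes to the fp16 value 64 + n. When
// a weight nibble w and its zero-point nibble z both pass through this
// function, (64 + w) - (64 + z) == w - z, e.g. w = 11, z = 3 gives
// 75.0 - 67.0 == 8.0, so no explicit magic-number subtraction is needed before
// the scales are applied.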

__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
    uint32_t smem_int_ptr;
    printf("=========common.h 161\n");
    // asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
    //     : "=r"(smem_int_ptr)
    //     : "l"(ptr));

    return smem_int_ptr;
}

__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
    printf("=========common.h 171\n");
// #if TURBOMIND_ARCH_SM75
//     asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
//         : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
//         : "r"(smem_int_ptr));
// #else
//     assert(TURBOMIND_ARCH_SM75);
// #endif
}

__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
    printf("=========common.h 183\n");
// #if TURBOMIND_ARCH_SM75
//     asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(d0), "=r"(d1) : "r"(smem_int_ptr));
// #else
//     assert(TURBOMIND_ARCH_SM75);
// #endif
}
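
// Illustrative usage sketch (an addition; `ldmatrix_x4_example` is
// hypothetical): convert a generic shared-memory pointer to the 32-bit form
// ldmatrix expects and issue the load. Each of the 32 threads in the warp must
// pass the address of one 8x8 b16 row fragment; computing that per-thread
// address is the caller's responsibility and is omitted here.
__inline__ __device__ void ldmatrix_x4_example(const half* smem_frag_ptr, uint (&d)[4])
{
    const uint32_t addr = cast_smem_ptr_to_uint(smem_frag_ptr);
    ldmatrix_m8n8_x4_b16(d[0], d[1], d[2], d[3], addr);
}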

__inline__ __device__ void wait_flag(int* lock, int status, int thread_id)
{
    int state = 0;
    while (__syncthreads_and(state != status)) {
        if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
            asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
        }
    }

    __syncthreads();  // memory fence
}

__inline__ __device__ void release_flag(int* lock, int status, int thread_id)
{
    __syncthreads();  // memory fence

    if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
        asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
    }
}
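
// Illustrative usage sketch (an addition; the phase split is hypothetical):
// serializing dependent phases across thread blocks with the flags above, e.g.
//   if (blockIdx.x == 0) { /* produce */ release_flag(lock, 1, threadIdx.x); }
//   else                 { wait_flag(lock, 1, threadIdx.x); /* consume */ }
// All threads of a block must reach wait_flag together, since it spins on
// __syncthreads_and.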

__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
    uint s, z;
    (half2&)z = __halves2half2(q.x, q.x);
    (half2&)s = __halves2half2(q.y, q.y);

    auto& t = (const uint&)x;
    uint  u, v;
    if (TURBOMIND_S4_DEQUANT_USE_FMA) {
        asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(v) : "r"(t), "r"(s), "r"(z));
    }
    else {
        asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
        asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
    }
    return (half2&)v;
}
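
// Illustrative sketch (an addition; `example_dequant_apply_Q` is hypothetical):
// combining dequantize_s4_to_fp16x2_v2 with apply_Q under the default
// TURBOMIND_S4_DEQUANT_USE_FMA == 0 path, where apply_Q computes (x - z) * s.
// Weights and zeros carry the same `64` bias out of the dequant routine, so the
// bias cancels inside apply_Q and only (w - z) * scale survives. Only the first
// element pair is handled, for brevity.
__inline__ __device__ half2 example_dequant_apply_Q(uint32_t w_packed, uint32_t z_packed, half scale)
{
    const uint4 w = dequantize_s4_to_fp16x2_v2(w_packed);
    const uint4 z = dequantize_s4_to_fp16x2_v2(z_packed);
    const half2 q = __halves2half2(((const half2&)z.x).x, scale);  // {zero, scale}
    return apply_Q((const half2&)w.x, q);
}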

template<typename T, int N>
struct Array {

    using value_type      = T;
    using size_type       = int;
    using difference_type = int;
    using reference       = value_type&;
    using const_reference = const value_type&;
    using pointer         = value_type*;
    using const_pointer   = const value_type*;
    using iterator        = pointer;
    using const_iterator  = const_pointer;

    static_assert(N > 0);

    T __a[N];

    __device__ __host__ constexpr reference operator[](size_type i) noexcept
    {
        return __a[i];
    }
    __device__ __host__ constexpr const_reference operator[](size_type i) const noexcept
    {
        return __a[i];
    }

    __device__ __host__ constexpr reference front() noexcept
    {
        return *begin();
    }

    __device__ __host__ constexpr const_reference front() const noexcept
    {
        return *begin();
    }

    __device__ __host__ constexpr reference back() noexcept
    {
        return *(end() - 1);
    }

    __device__ __host__ constexpr const_reference back() const noexcept
    {
        return *(end() - 1);
    }

    __device__ __host__ constexpr pointer data() noexcept
    {
        return &__a[0];
    }

    __device__ __host__ constexpr const_pointer data() const noexcept
    {
        return &__a[0];
    }

    __device__ __host__ constexpr iterator begin() noexcept
    {
        return data();
    }

    __device__ __host__ constexpr const_iterator begin() const noexcept
    {
        return data();
    }

    __device__ __host__ constexpr iterator end() noexcept
    {
        return data() + N;
    }

    __device__ __host__ constexpr const_iterator end() const noexcept
    {
        return data() + N;
    }

    __device__ __host__ constexpr std::integral_constant<int, N> size() const noexcept
    {
        return {};
    }

    __device__ __host__ constexpr std::false_type empty() const noexcept
    {
        return {};
    }
};
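
// Illustrative usage sketch (an addition; `fill_example` is hypothetical):
// Array is a trivially-copyable aggregate usable from host and device code, and
// size() yields a std::integral_constant so the extent stays a constant
// expression.
template<typename T, int N>
__host__ __device__ Array<T, N> fill_example(T value)
{
    Array<T, N> a{};
    PRAGMA_UNROLL
    for (int i = 0; i < a.size(); ++i) {
        a[i] = value;
    }
    return a;
}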

template<int... Ns>
struct Shape {
    static constexpr Array<int, sizeof...(Ns)> data_{Ns...};

    constexpr Shape() = default;

    Shape(std::integral_constant<int, Ns>...){};

    template<int index>
    constexpr auto get() const noexcept
    {
        return std::integral_constant<int, data_[index]>{};
    }

    constexpr auto m() const noexcept
    {
        return get<0>();
    }

    constexpr auto n() const noexcept
    {
        return get<1>();
    }

    constexpr auto k() const noexcept
    {
        return get<2>();
    }

    constexpr int c() const noexcept
    {
        return get<0>();
    }

    constexpr int s() const noexcept
    {
        return get<1>();
    }

    constexpr int count() const noexcept
    {
        return (Ns * ...);
    }
};
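
// Illustrative usage (an addition for exposition): a compile-time GEMM tile
// shape; the extents remain constant expressions through the integral_constant
// accessors. The 128x128x32 tile below is an arbitrary example.
static_assert(Shape<128, 128, 32>{}.m() == 128, "");
static_assert(Shape<128, 128, 32>{}.count() == 128 * 128 * 32, "");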

}  // namespace turbomind