// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/macro.h"
#include <cassert>
#include <cstdint>
#include <cuda_fp16.h>
#include <type_traits>

namespace turbomind {

#ifndef TURBOMIND_S4_DEQUANT_USE_FMA
#define TURBOMIND_S4_DEQUANT_USE_FMA 0
#endif

// #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
// #define TURBOMIND_ARCH_SM75 1
// #else
// #define TURBOMIND_ARCH_SM75 0
// #endif

// #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
// #define TURBOMIND_ARCH_SM80 1
// #else
// #define TURBOMIND_ARCH_SM80 0
// #endif

// constexpr int WARP_SIZE = 32;  // upstream CUDA value
// NOTE(review): 64 matches a 64-lane wavefront target (e.g. AMD GCN/CDNA);
// confirm this is intended for every backend this header is built for.
constexpr int WARP_SIZE = 64;

#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
#define PRAGMA_UNROLL _Pragma("unroll")
#define PRAGMA_NO_UNROLL _Pragma("unroll 1")
#else
#define PRAGMA_UNROLL #pragma unroll
#define PRAGMA_NO_UNROLL #pragma unroll 1
#endif
#else
#define PRAGMA_UNROLL
#define PRAGMA_NO_UNROLL
#endif

// Modified from NVIDIA FasterTransformer:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// Modified from llm-awq https://github.com/mit-han-lab/llm-awq/blob/main/awq/kernels/csrc/quantization/dequantize.cuh
// Dequantize 8 packed unsigned 4-bit values (one uint32) into 8 fp16 values
// (a uint4 holding four half2 words). Output layout per half2 word:
// h[0]={e0,e1}, h[1]={e2,e3}, h[2]={e4,e5}, h[3]={e6,e7}.
// Requires packed fp16 arithmetic support (SM53+ on CUDA).
__inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
    uint4 result;

    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

    // Masks/magic used to splice each 4-bit field into an fp16 with a known
    // exponent (0x6400 is fp16 1024, so 0x640n encodes 1024 + n exactly).
    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is
    // thanks to the register packing format and the fact that we force our
    // integers to be unsigned, and account for this in the fp16 subtractions.

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide
    // RAW dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;

    // (x & mask) | 0x64006400 builds fp16 pairs holding 1024 + v (bottom
    // nibbles) or 1024 + 16*v (top nibbles). Replaces the original lop3 PTX.
    h[0] = (i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM;      // elt_01
    h[1] = (i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM;         // elt_23
    h[2] = (top_i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM;  // elt_45
    h[3] = (top_i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM;     // elt_67

    // This is the half2 {1024, 1024} represented as an integer.
    // Haotian: subtract {1024, 1024} instead of {1032, 1032}; we do not need
    // to map to [-8, 7].
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // This is the half2 {-64, -64} represented as an integer (1024/16 bias).
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // BUG FIX: the previous port replaced the sub/fma.rn f16x2 PTX with plain
    // integer `-`, `*`, `+` on the packed fp16 bit patterns, which does NOT
    // perform fp16 arithmetic (e.g. 0x640n - 0x6400 = 0x000n, a denormal bit
    // pattern, not the value n). Use packed-half intrinsics instead, which
    // match the original PTX semantics.
    const half2 magic     = reinterpret_cast<const half2&>(FP16_TOP_MAGIC_NUM);
    const half2 sixteenth = reinterpret_cast<const half2&>(ONE_SIXTEENTH);
    const half2 neg_64    = reinterpret_cast<const half2&>(NEG_64);

    // Convert elt_01: (1024 + v) - 1024 == v, exact in fp16.
    reinterpret_cast<half2&>(h[0]) = __hsub2(reinterpret_cast<half2&>(h[0]), magic);
    // Convert elt_23: (1024 + 16*v) * 1/16 - 64 == v.
    reinterpret_cast<half2&>(h[1]) = __hfma2(reinterpret_cast<half2&>(h[1]), sixteenth, neg_64);
    // Convert elt_45
    reinterpret_cast<half2&>(h[2]) = __hsub2(reinterpret_cast<half2&>(h[2]), magic);
    // Convert elt_67
    reinterpret_cast<half2&>(h[3]) = __hfma2(reinterpret_cast<half2&>(h[3]), sixteenth, neg_64);

    return result;
}

// Variant of dequantize_s4_to_fp16x2 that trades the four f16x2 sub/fma ops
// for two integer shifts (integer-only dequant path). The `64`-based magic
// bias is deliberately left inside every output word: the zero-points travel
// through this same routine and carry the same bias, so it cancels when
// (weight - zero) is formed downstream.
__inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
{
    uint4 result;

    uint32_t*       h   = reinterpret_cast<uint32_t*>(&result);
    uint32_t const& i4s = reinterpret_cast<uint32_t const&>(source);

    // Field masks and exponent magics (0x5400 is fp16 `64`).
    static constexpr uint32_t BOT_MASK    = 0x000f000f;
    static constexpr uint32_t TOP_MASK    = 0x00f000f0;
    static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`
    static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4

    // Bring elt_45/elt_67 down by one byte first to hide the RAW dependency.
    const uint32_t shifted = i4s >> 8;

    // Bottom nibbles are assembled against the pre-shifted magic and then
    // moved up 4 bits, so all four words end up with the same `64` bias.
    h[0] = ((i4s & BOT_MASK) | MAGIC_NUM_2) << 4;      // elt_01
    h[2] = ((shifted & BOT_MASK) | MAGIC_NUM_2) << 4;  // elt_45
    h[1] = (i4s & TOP_MASK) | MAGIC_NUM_1;             // elt_23
    h[3] = (shifted & TOP_MASK) | MAGIC_NUM_1;         // elt_67

    return result;
}

// Convert a generic pointer that refers to shared memory into the 32-bit
// shared-state-space address consumed by shared-memory instructions.
//
// BUG FIX: the previous port commented out the `cvta.to.shared` PTX but kept
// returning the (now uninitialized) result, and unconditionally printf'd from
// device code. Use the CUDA intrinsic where available; otherwise fall back to
// truncating the generic address. NOTE(review): on non-CUDA backends confirm
// that shared pointers are representable in the low 32 bits.
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
#if defined(__CUDA_ARCH__)
    // Equivalent to the original `cvta.to.shared.u64` + `cvt.u32.u64` PTX.
    return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
#else
    return static_cast<uint32_t>(reinterpret_cast<uintptr_t>(ptr));
#endif
}

// Warp-cooperative load of four 8x8 b16 tiles from shared memory (ldmatrix).
// `smem_int_ptr` must be a shared-memory address from cast_smem_ptr_to_uint;
// requires SM75+ on CUDA targets.
//
// BUG FIX: the previous port left d0..d3 uninitialized and only emitted a
// debug printf. Restore the PTX where it exists; on unsupported targets zero
// the outputs and assert so the dead path fails loudly instead of handing
// garbage registers to the caller.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
    asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
        : "r"(smem_int_ptr));
#else
    (void)smem_int_ptr;
    d0 = d1 = d2 = d3 = 0;  // deterministic outputs even when unsupported
    assert(!"ldmatrix_m8n8_x4_b16 is not supported on this target");
#endif
}

// Warp-cooperative load of two 8x8 b16 tiles from shared memory (ldmatrix).
// `smem_int_ptr` must be a shared-memory address from cast_smem_ptr_to_uint;
// requires SM75+ on CUDA targets.
//
// BUG FIX: the previous port left d0/d1 uninitialized and only emitted a
// debug printf. Restore the PTX where it exists; on unsupported targets zero
// the outputs and assert so the dead path fails loudly.
__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
    asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
        : "=r"(d0), "=r"(d1)
        : "r"(smem_int_ptr));
#else
    (void)smem_int_ptr;
    d0 = d1 = 0;  // deterministic outputs even when unsupported
    assert(!"ldmatrix_m8n8_x2_b16 is not supported on this target");
#endif
}

// __inline__ __device__ void wait_flag(int* lock, int status, int thread_id)
// {
//     int state = 0;
//     while (__syncthreads_and(state != status)) {
//         if (thread_id == 0) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
//             asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
// #else
//             asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
// #endif
//         }
//     }

//     __syncthreads();  // memory fence
// }

// __inline__ __device__ void release_flag(int* lock, int status, int thread_id)
// {
//     __syncthreads();  // memory fence

//     if (thread_id == 0) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
//         asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
// #else
//         asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
// #endif
//     }
// }

// Apply per-group quantization parameters `q` to a packed fp16 pair `x`.
// q.x holds the zero-point and q.y the scale; each is broadcast to both
// lanes. Requires packed fp16 arithmetic support (SM53+ on CUDA).
//
// BUG FIX: the previous port commented out all of the arithmetic and
// returned an uninitialized register reinterpreted as half2. Reimplement the
// original f16x2 PTX with packed-half intrinsics.
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
    const half2 z = __half2half2(q.x);  // zero-point pair
    const half2 s = __half2half2(q.y);  // scale pair

    half2 v;
    if (TURBOMIND_S4_DEQUANT_USE_FMA) {
        // Fused form x * s + z, matching the original fma.rn.f16x2 path
        // (assumes the zero term was pre-combined by the caller).
        v = __hfma2(x, s, z);
    }
    else {
        // (x - z) * s, matching the original sub/mul f16x2 path.
        v = __hmul2(__hsub2(x, z), s);
    }
    return v;
}

// Fixed-size POD array usable on both host and device; mirrors the interface
// of std::array (aggregate initialization, iterators, element access), with
// `size()`/`empty()` returning type-encoded constants so they can be used in
// compile-time contexts.
template<typename T, int N>
struct Array {

    using value_type      = T;
    using size_type       = int;
    using difference_type = int;
    using reference       = value_type&;
    using const_reference = const value_type&;
    using pointer         = value_type*;
    using const_pointer   = const value_type*;
    using iterator        = pointer;
    using const_iterator  = const_pointer;

    static_assert(N > 0);

    // Kept public so the type remains an aggregate (brace-initializable).
    T __a[N];

    // --- raw access --------------------------------------------------------

    __device__ __host__ constexpr pointer data() noexcept
    {
        return &__a[0];
    }

    __device__ __host__ constexpr const_pointer data() const noexcept
    {
        return &__a[0];
    }

    // --- iterators ---------------------------------------------------------

    __device__ __host__ constexpr iterator begin() noexcept
    {
        return data();
    }

    __device__ __host__ constexpr const_iterator begin() const noexcept
    {
        return data();
    }

    __device__ __host__ constexpr iterator end() noexcept
    {
        return data() + N;
    }

    __device__ __host__ constexpr const_iterator end() const noexcept
    {
        return data() + N;
    }

    // --- element access (unchecked, like std::array::operator[]) -----------

    __device__ __host__ constexpr reference operator[](size_type i) noexcept
    {
        return __a[i];
    }

    __device__ __host__ constexpr const_reference operator[](size_type i) const noexcept
    {
        return __a[i];
    }

    __device__ __host__ constexpr reference front() noexcept
    {
        return *begin();
    }

    __device__ __host__ constexpr const_reference front() const noexcept
    {
        return *begin();
    }

    __device__ __host__ constexpr reference back() noexcept
    {
        return *(end() - 1);
    }

    __device__ __host__ constexpr const_reference back() const noexcept
    {
        return *(end() - 1);
    }

    // --- capacity (encoded in the return type; N > 0, so never empty) ------

    __device__ __host__ constexpr std::integral_constant<int, N> size() const noexcept
    {
        return {};
    }

    __device__ __host__ constexpr std::false_type empty() const noexcept
    {
        return {};
    }
};

// Compile-time integer tuple. Every extent is encoded in the type, and
// `get<I>()` hands it back as std::integral_constant so callers can use
// extents in constant expressions. m/n/k (GEMM-style) and c/s are named
// accessors over the same underlying data.
template<int... Ns>
struct Shape {
    static constexpr Array<int, sizeof...(Ns)> data_{Ns...};

    constexpr Shape() = default;

    // Also constructible from the matching integral-constant values.
    Shape(std::integral_constant<int, Ns>...){};

    template<int I>
    constexpr auto get() const noexcept
    {
        return std::integral_constant<int, data_[I]>{};
    }

    // GEMM-style accessors (require at least 1/2/3 extents respectively).
    constexpr auto m() const noexcept
    {
        return get<0>();
    }

    constexpr auto n() const noexcept
    {
        return get<1>();
    }

    constexpr auto k() const noexcept
    {
        return get<2>();
    }

    // Convolution-style accessors; note these decay to plain int.
    constexpr int c() const noexcept
    {
        return get<0>();
    }

    constexpr int s() const noexcept
    {
        return get<1>();
    }

    // Product of all extents.
    constexpr int count() const noexcept
    {
        return (Ns * ...);
    }
};

}  // namespace turbomind