dequantize.cuh 9.22 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
/*
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/
#pragma once

fengzch-das's avatar
fengzch-das committed
14
#include <cuda_fp16.h>
fengzch's avatar
fengzch committed
15
#include <cuda_bf16.h>
Zhekai Zhang's avatar
Zhekai Zhang committed
16
17
#include <cstdint>

Muyang Li's avatar
Muyang Li committed
18
/**
 * Unpacks the eight 4-bit fields of one 32-bit word into eight fp16 values.
 *
 * The 32 bits behind `source` are read as raw packed nibbles; they are NOT
 * interpreted as two half-precision numbers. Each nibble is converted as an
 * unsigned integer, so every output value lies in [0, 15] (no re-centering to
 * [-8, 7] is applied).
 *
 * Output layout (each 32-bit word of `*result` is one packed half2, written
 * as {low 16 bits, high 16 bits}):
 *   result->x = { bits  3..0 , bits 19..16 }
 *   result->y = { bits  7..4 , bits 23..20 }
 *   result->z = { bits 11..8 , bits 27..24 }
 *   result->w = { bits 15..12, bits 31..28 }
 *
 * @param source  Reference to the 32-bit word holding the eight packed nibbles.
 * @param result  Destination for the four packed half2 words (16 bytes).
 */
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
    half2 *out         = reinterpret_cast<half2 *>(result);
    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

    // lop3 truth table for (a & b) | c.
    static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
    // Packed half2 {1024, 1024}. OR-ing a nibble into the low mantissa bits of
    // 1024.0 yields 1024 + x; OR-ing it into mantissa bits 7..4 yields 1024 + 16 * x.
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Only one shift is needed for the whole sequence: the nibbles sitting in
    // bits 7..4 of each half are converted in place with an fma instead of
    // being shifted down first. The shift is issued early to hide its latency.
    uint32_t const top_i4s = i4s >> 8;

    uint32_t elt_01, elt_23, elt_45, elt_67;
    // (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(elt_01)
                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(elt_23)
                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(elt_45)
                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(elt_67)
                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

    // Remove the bias with *packed* half2 arithmetic on the reinterpreted bit
    // patterns. The scalar __hsub/__hfma overloads must not be used here: given
    // a uint32_t they convert the integer's numeric value to a single half
    // rather than treating the word as two packed halves.
    half2 const bias_1024     = __float2half2_rn(1024.0f);
    half2 const one_sixteenth = __float2half2_rn(0.0625f);
    half2 const neg_64        = __float2half2_rn(-64.0f);

    // (1024 + x) - 1024 = x
    out[0] = __hsub2(reinterpret_cast<half2 const &>(elt_01), bias_1024);
    // (1024 + 16 * x) / 16 - 64 = x
    out[1] = __hfma2(reinterpret_cast<half2 const &>(elt_23), one_sixteenth, neg_64);
    out[2] = __hsub2(reinterpret_cast<half2 const &>(elt_45), bias_1024);
    out[3] = __hfma2(reinterpret_cast<half2 const &>(elt_67), one_sixteenth, neg_64);
}

fengzch's avatar
fengzch committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Device-side conversion from raw bfloat16 bits to float.
// A bfloat16 is the upper 16 bits of an IEEE-754 binary32 value, so the
// conversion is exact: place the 16 bits in the high half of a 32-bit word
// (low half zero) and reinterpret that word as a float.
// Declared __forceinline__ like the other functions in this header so that
// including the header from several translation units does not produce
// duplicate device symbols.
__forceinline__ __device__ float bf16_to_float_device(uint16_t bf16) {
    uint32_t const widened = static_cast<uint32_t>(bf16) << 16;
    return __uint_as_float(widened);
}

// Device-side conversion from float to raw bfloat16 bits (the upper 16 bits of
// the float), using round-to-nearest, ties-to-even.
// NaN inputs are returned as a quiet NaN with the sign preserved; without this
// guard the rounding add could carry out of the mantissa and turn a NaN into
// an infinity or a zero.
// Declared __forceinline__ like the other functions in this header so that
// including the header from several translation units does not produce
// duplicate device symbols.
__forceinline__ __device__ uint16_t float_to_bf16_device(float f) {
    uint32_t const float_bits = __float_as_uint(f);

    // NaN: exponent all ones and a non-zero mantissa. Truncate and force the
    // quiet bit so the result stays a NaN even if only low mantissa bits were set.
    if ((float_bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<uint16_t>((float_bits >> 16) | 0x0040u);
    }

    // Round to nearest even: add 0x7FFF plus the LSB of the kept part, so an
    // exact tie (discarded bits == 0x8000) rounds up only when that LSB is 1.
    uint32_t const rounding_bias = ((float_bits >> 16) & 1u) + 0x7FFFu;
    return static_cast<uint16_t>((float_bits + rounding_bias) >> 16);
}

// Software implementation of a packed bfloat16x2 fused multiply-add.
// Each argument carries two raw bfloat16 bit patterns: one in bits 31..16 and
// one in bits 15..0. For each of the two lanes independently the function
// computes a * b + c in fp32 and rounds the result back to bfloat16 with
// float_to_bf16_device; the two results are re-packed in the same lane order.
//
// fmaf is used so the multiply-add is rounded once in fp32 regardless of the
// compiler's contraction setting. (A product of two bfloat16 values has at
// most 16 significant bits, so it is exactly representable in fp32 unless it
// overflows or underflows.)
//
// Declared __forceinline__ like the other functions in this header so that
// including the header from several translation units does not produce
// duplicate device symbols.
__forceinline__ __device__ uint32_t fma_bf16x2_cpp(uint32_t a, uint32_t b, uint32_t c) {
    // High lane: bits 31..16 of each operand.
    float const hi = fmaf(bf16_to_float_device(static_cast<uint16_t>(a >> 16)),
                          bf16_to_float_device(static_cast<uint16_t>(b >> 16)),
                          bf16_to_float_device(static_cast<uint16_t>(c >> 16)));

    // Low lane: bits 15..0 of each operand.
    float const lo = fmaf(bf16_to_float_device(static_cast<uint16_t>(a & 0xFFFF)),
                          bf16_to_float_device(static_cast<uint16_t>(b & 0xFFFF)),
                          bf16_to_float_device(static_cast<uint16_t>(c & 0xFFFF)));

    // Re-pack the two rounded lanes.
    uint32_t const packed_hi = static_cast<uint32_t>(float_to_bf16_device(hi)) << 16;
    uint32_t const packed_lo = static_cast<uint32_t>(float_to_bf16_device(lo));
    return packed_hi | packed_lo;
}

fengzch-das's avatar
fengzch-das committed
130
/**
 * Unpacks the eight 4-bit fields of one 32-bit word into eight bfloat16 values.
 *
 * The 32 bits behind `source` are read as raw packed nibbles; they are NOT
 * interpreted as two bfloat16 numbers. Each nibble is converted as an unsigned
 * integer, so every output value lies in [0, 15] (no re-centering to [-8, 7]
 * is applied).
 *
 * Output layout (each 32-bit word of `*result` is one packed bfloat16 pair,
 * written as {low 16 bits, high 16 bits}):
 *   result->x = { bits  3..0 , bits 19..16 }
 *   result->y = { bits  7..4 , bits 23..20 }
 *   result->z = { bits 11..8 , bits 27..24 }
 *   result->w = { bits 15..12, bits 31..28 }
 *
 * This is a plain C++ formulation: the bit extraction that the earlier PTX
 * version did with `lop3.b32` is written as `(x & mask) | magic`, and the
 * `fma.rn.bf16x2` instruction is replaced by the fma_bf16x2_cpp helper.
 *
 * @param source  Reference to the 32-bit word holding the eight packed nibbles.
 * @param result  Destination for the four packed bfloat16x2 words (16 bytes).
 */
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
    uint32_t *h        = reinterpret_cast<uint32_t *>(result);
    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

    // Selects the low nibble of each 16-bit half.
    static constexpr uint32_t MASK                   = 0x000f000f;
    // Packed bfloat16 {128, 128}. OR-ing a nibble into the low mantissa bits
    // of 128.0 yields 128 + x.
    static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
    // Packed bfloat16 {-128, -128}: removes the bias without mapping to [-8, 7].
    static constexpr uint32_t BF16_BIAS              = 0xC300C300;
    // Packed bfloat16 {1, 1}.
    static constexpr uint32_t BF16_ONE               = 0x3F803F80;

#pragma unroll
    for (int k = 0; k < 4; ++k) {
        // Bring nibble k (and nibble k + 4 in the upper half) into the low four
        // bits of each 16-bit half, then splice it into the mantissa of 128.0.
        uint32_t const biased = ((i4s >> (4 * k)) & MASK) | I4s_TO_BF16s_MAGIC_NUM;
        // (128 + x) * 1 + (-128) = x
        h[k] = fma_bf16x2_cpp(biased, BF16_ONE, BF16_BIAS);
    }
}