/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#pragma once

#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Unpacks eight unsigned 4-bit integers stored in the 32 bits of `source` into eight fp16
// values written to `result` (four half2 pairs, one per 32-bit word of the uint4).
//
// Nibble -> output mapping (low half / high half of each output word):
//   result word 0 <- bits [0:4)   / bits [16:20)
//   result word 1 <- bits [4:8)   / bits [20:24)
//   result word 2 <- bits [8:12)  / bits [24:28)
//   result word 3 <- bits [12:16) / bits [28:32)
// Each output is the raw nibble value n in [0, 15]. Unlike upstream FasterTransformer, no -8
// offset is applied (any zero point is presumably applied by the caller -- confirm at call sites).
//
// Uses inline PTX lop3 and packed fp16 arithmetic (__hsub2 / __hfma2), i.e. SM53+.
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
  uint32_t *h = reinterpret_cast<uint32_t *>(result);
  uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

  // lop3 truth table for (a & b) | c.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  // fp16 1024.0 in both halves. Its 10-bit mantissa has a unit step, so OR-ing a nibble n into
  // bits [0:4) yields exactly 1024 + n, and into bits [4:8) yields exactly 1024 + 16 * n.
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Only one shift is needed: elt_23 / elt_67 are left in bits [4:8) and fixed up with an fma
  // below instead of being shifted down. Issue the shift first to hide its latency.
  const uint32_t top_i4s = i4s >> 8;

  // Extract elt_01: (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_23: (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_45: (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_67: (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

  // Remove the magic bias with packed half2 math. The words hold half2 *bit patterns*, so they
  // must be reinterpreted as half2. Passing the uint32_t words to the scalar __hsub / __hfma
  // would instead convert their integer *value* to half, which produces garbage.
  half2 *h2 = reinterpret_cast<half2 *>(h);
  // FasterTransformer subtracts {1032, 1032} / adds {-72, -72} to map into [-8, 7]; AWQ keeps
  // the unsigned nibble, so only the 1024 / 64 bias is removed.
  half2 const k1024 = __float2half2_rn(1024.0f);          // bits 0x64006400
  half2 const kOneSixteenth = __float2half2_rn(0.0625f);  // bits 0x2c002c00
  half2 const kNeg64 = __float2half2_rn(-64.0f);          // bits 0xd400d400

  // elt_01 / elt_45: (1024 + n) - 1024 = n
  h2[0] = __hsub2(h2[0], k1024);
  // elt_23 / elt_67: (1024 + 16n) * (1/16) - 64 = n (exact in fp16)
  h2[1] = __hfma2(h2[1], kOneSixteenth, kNeg64);
  h2[2] = __hsub2(h2[2], k1024);
  h2[3] = __hfma2(h2[3], kOneSixteenth, kNeg64);
}

// Widens a bfloat16 bit pattern to float: a bf16 value is the upper 16 bits of an IEEE-754 binary32.
__forceinline__ __device__ float bf16_to_float_device(uint16_t bf16) {
  return __uint_as_float(static_cast<uint32_t>(bf16) << 16);
}

// Narrows a float to a bfloat16 bit pattern with round-to-nearest-even.
// NaN inputs take a separate path that returns a quiet NaN. Without it, the rounding bias can
// carry a NaN with a small payload into the exponent (turning it into Inf) or wrap around 2^32.
__forceinline__ __device__ uint16_t float_to_bf16_device(float f) {
  uint32_t const bits = __float_as_uint(f);
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    return static_cast<uint16_t>((bits >> 16) | 0x0040u);  // keep sign, set quiet bit
  }
  // Add 0x7FFF, plus 1 when the kept LSB is odd, so exact ties round to even.
  uint32_t const rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

// Portable C++ stand-in for the PTX fma.rn.bf16x2 instruction, operating on packed bf16x2 words.
// Each 16-bit lane (bits [0:16) and [16:32)) independently computes a*b + c in float with a
// single rounding (fmaf), then rounds to bf16. This double rounding can differ from a native
// bf16 fma in rare tie cases; it is exact for the small integer values used in this file.
__forceinline__ __device__ uint32_t fma_bf16x2_cpp(uint32_t a, uint32_t b, uint32_t c) {
  float const lo = fmaf(bf16_to_float_device(static_cast<uint16_t>(a & 0xFFFFu)),
                        bf16_to_float_device(static_cast<uint16_t>(b & 0xFFFFu)),
                        bf16_to_float_device(static_cast<uint16_t>(c & 0xFFFFu)));
  float const hi = fmaf(bf16_to_float_device(static_cast<uint16_t>(a >> 16)),
                        bf16_to_float_device(static_cast<uint16_t>(b >> 16)),
                        bf16_to_float_device(static_cast<uint16_t>(c >> 16)));
  // Repack: high lane in bits [16:32), low lane in bits [0:16).
  return (static_cast<uint32_t>(float_to_bf16_device(hi)) << 16) |
         static_cast<uint32_t>(float_to_bf16_device(lo));
}

// bfloat16 variant: unpacks eight unsigned 4-bit integers from `source` into eight bf16 values in
// `result`. It uses the same nibble -> output-word mapping as the half2 overload. Each output is
// the raw nibble value n in [0, 15]. Written without inline PTX; the packed bf16 arithmetic goes
// through fma_bf16x2_cpp.
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
  uint32_t *h = reinterpret_cast<uint32_t *>(result);
  uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);

  static constexpr uint32_t MASK = 0x000f000f;
  // bf16 128.0 in both halves. Its 7-bit mantissa has a unit step, so OR-ing a nibble n into
  // bits [0:4) yields exactly 128 + n.
  static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

  // Equivalent to lop3.b32 with the (a & b) | c truth table.
  h[0] = (i4s & MASK) | I4s_TO_BF16s_MAGIC_NUM;          // elt_01
  h[1] = ((i4s >> 4) & MASK) | I4s_TO_BF16s_MAGIC_NUM;   // elt_23
  h[2] = ((i4s >> 8) & MASK) | I4s_TO_BF16s_MAGIC_NUM;   // elt_45
  h[3] = ((i4s >> 12) & MASK) | I4s_TO_BF16s_MAGIC_NUM;  // elt_67

  // bf16 {-128, -128}. FasterTransformer uses 0xC308C308 ({-136, -136}) to map into [-8, 7];
  // AWQ keeps the unsigned nibble, so only the 128 bias is removed.
  static constexpr uint32_t BF16_BIAS = 0xC300C300;
  // bf16 {1.0, 1.0}
  static constexpr uint32_t BF16_ONE = 0x3F803F80;

  // (128 + n) * 1 - 128 = n, equivalent to fma.rn.bf16x2 with these operands.
  h[0] = fma_bf16x2_cpp(h[0], BF16_ONE, BF16_BIAS);
  h[1] = fma_bf16x2_cpp(h[1], BF16_ONE, BF16_BIAS);
  h[2] = fma_bf16x2_cpp(h[2], BF16_ONE, BF16_BIAS);
  h[3] = fma_bf16x2_cpp(h[3], BF16_ONE, BF16_BIAS);
}