"vscode:/vscode.git/clone" did not exist on "25cee5810e8da6c2ce4611b413b0fb14c853b4a8"
Commit ebe520c5 authored by Samuel Tesfai

Migrating tinychat 4-bit textencoder from deepcompressor to nunchaku

parent 9900050d
# -*- coding: utf-8 -*-
"""TinyChat Extension."""
import os
from torch.utils.cpp_extension import load
__all__ = ["_C"]
dirpath = os.path.dirname(__file__)
_C = load(
name="nunchaku_tinychat_C",
sources=[
f"{dirpath}/tinychat_pybind.cpp",
f"{dirpath}/quantization/gemv/gemv_cuda.cu",
f"{dirpath}/quantization/gemm/gemm_cuda.cu",
],
extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"],
extra_cuda_cflags=[
"-O3",
"-std=c++20",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_HALF2_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=--allow-expensive-optimizations=true",
"--threads=8",
],
)
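# Note (added): importing this module JIT-compiles the sources above via
# torch.utils.cpp_extension.load (ninja + nvcc), so a CUDA toolkit compatible with
# the installed PyTorch build must be available; the resulting `_C` module exposes
# the awq_gemm_forward_cuda and awq_gemv_forward_cuda bindings declared in
# tinychat_pybind.cpp below.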
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
template <typename T>
__device__ __forceinline__ void dequantize_s4_to_f16x2(T const &source, uint4 *result);
template <>
__device__ __forceinline__ void dequantize_s4_to_f16x2<half2>(half2 const &source, uint4 *result)
{
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits beforehand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
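// Worked example (added for clarity; not part of the upstream kernel): for a low
// nibble x in [0, 15], OR-ing it into 0x6400 produces the fp16 bit pattern of
// 1024 + x, since 0x6400 encodes 1024.0 and the nibble lands in the low mantissa
// bits; the sub.f16x2 with FP16_TOP_MAGIC_NUM then recovers x exactly. A nibble
// taken from bits [4:8) instead yields 1024 + 16 * x, and the fma with
// ONE_SIXTEENTH and NEG_64 computes (1024 + 16 * x) / 16 - 64 = x.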
template <>
__device__ __forceinline__ void dequantize_s4_to_f16x2<__nv_bfloat162>(__nv_bfloat162 const &source, uint4 *result)
{
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const source_i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate bf16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
uint32_t i4s = source_i4s;
// Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x43004300
i4s >>= 4;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x43004300
i4s >>= 4;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x43004300
i4s >>= 4;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
// This is the BF16 {-136, -136} represented as an integer.
// static constexpr uint32_t BF16_BIAS = 0xC308C308;
// This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
static constexpr uint32_t NEG_128 = 0xC300C300;
static constexpr uint32_t ONE = 0x3F803F80;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(ONE), "r"(NEG_128));
// Convert elt_23
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE), "r"(NEG_128));
// Convert elt_45
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(ONE), "r"(NEG_128));
// Convert elt_67
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE), "r"(NEG_128));
}
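// Worked example (added for clarity): the bf16 magic constant 0x4300 encodes 128.0
// and bf16 has a 7-bit mantissa, so OR-ing a nibble x in [0, 15] into the low
// mantissa bits yields 128 + x; the fma.rn.bf16x2 with ONE (1.0) and NEG_128
// (-128.0) then recovers x. Shifting i4s right by 4 between extractions reuses
// the same MASK for each of the four nibble pairs.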
This diff is collapsed.
#include <torch/extension.h>
torch::Tensor awq_gemm_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scales,
torch::Tensor _zeros);
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
// namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
public:
int *lock;
bool wait_thread;
int state;
public:
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),
wait_thread(thread_id < 0 || thread_id == 0),
state(-1)
{
}
/// Permit fetching the synchronization mechanism early
__device__ void fetch()
{
if (wait_thread)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
/// Gets the internal state
__device__ int get_state() const
{
return state;
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0)
{
while (__syncthreads_and(state != status))
{
fetch();
}
__syncthreads();
}
/// Updates the lock with the given result
__device__ void release(int status = 0)
{
__syncthreads();
if (wait_thread)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
};
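// Added note: in the split-K pattern this class is taken from (CUTLASS), a CTA
// working on slice k calls fetch() early to hide the load latency, wait(k) before
// reading or updating the partial output tile, and release(k + 1) once its
// contribution is written, so the slices accumulate into the output in order.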
/////////////////////////////////////////////////////////////////////////////////////////////////
// } // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <stdio.h>
#include <torch/extension.h>
#include "gemv_cuda.h"
#include "../dequantize.cuh"
#include "../../utils.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
// Reduce sum within the warp using the tree reduction algorithm.
template <typename fp_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(fp_t *psum, float (*out_smem)[Num * 4])
{
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = static_cast<float>(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i)
{
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
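// Added note: after the xor-shuffles with masks 16, 8, and 1, every lane in the
// coset {l, l^1, l^8, l^9, l^16, l^17, l^24, l^25} holds the same partial sum, so
// lanes 0/2/4/6 carry the reduced sums of the four interleaved rows; only those
// lanes write to shared memory below, indexed by lane / 2.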
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
{
#pragma unroll
for (int i = 0; i < Num; ++i)
{
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
};
__device__ __forceinline__ int make_divisible(int c, int divisor)
{
return (c + divisor - 1) / divisor;
}
template <typename f16_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const f16_t *inputs, const uint32_t *weight, const f16_t *scales, const f16_t *zeros, f16_t *outputs,
const int IC, const int OC)
{
using f162_t = typename packed_as<f16_t, 2>::type;
using accum_t = float;
using accum2_t = typename packed_as<accum_t, 2>::type;
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
alignas(16) f16_t local_inputs[kElemsPerThread];
alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
alignas(16) f16_t half_weight_buffer[kElemsPerThread];
alignas(16) f16_t dequantized_weight[kElemsPerThread * NPerBlock];
alignas(16) f16_t local_scale[NPerBlock];
alignas(16) f16_t local_scaled_zeros[NPerBlock];
accum_t psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = static_cast<accum_t>(0.f);
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride + (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const f16_t *scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const f16_t *zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const f16_t *inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4 *)(local_qweights)) =
*((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_f16x2(*reinterpret_cast<f162_t *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
f162_t w =
*reinterpret_cast<f162_t *>(
half_weight_buffer + (i + j * kShuffleContinous) * kShuffleBasicTile);
w = __hfma2(w, f162f162(local_scale[idx]), f162f162(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const f16_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
// load activation, 8 halves (128 bits) / step.
*((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
accum2_t prod = cuda_cast<accum2_t>(__hmul2(
*reinterpret_cast<f162_t *>(dequantized_weight + y * NPerBlock + x * 2),
f162f162(local_inputs[y])));
*reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) = prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
}
}
}
inputs_ptr += act_forward_step;
scale_ptr += scale_forward_step;
zeros_ptr += scale_forward_step;
}
warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);
// Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<f16_t>(acc);
}
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: int tensor of shape [OC, IC // 8];
_zeros: int tensor of shape [OC, IC // G // 8];
_scaling_factors: tensor of shape [OC, IC // G];
blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
Returns:
out_feats: tensor of shape [B, OC];
*/
torch::Tensor awq_gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size)
{
std::vector<int64_t> output_shape = _in_feats.sizes().vec();
output_shape.back() = n;
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty(output_shape, options);
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
AT_DISPATCH_REDUCED_FLOATING_TYPES(
_in_feats.scalar_type(),
"awq_gemv_forward_cuda",
[&]
{
using f16_t = typename to_cpp_t<scalar_t>::type;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
auto zeros = reinterpret_cast<f16_t *>(_zeros.data_ptr());
auto scaling_factors = reinterpret_cast<f16_t *>(_scaling_factors.data_ptr());
auto out_feats = reinterpret_cast<f16_t *>(_out_feats.data_ptr());
if (group_size == 128)
{
switch (m)
{
case 1:
gemv_kernel<f16_t, N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 2:
gemv_kernel<f16_t, N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 3:
gemv_kernel<f16_t, N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 4:
gemv_kernel<f16_t, N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 5:
gemv_kernel<f16_t, N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 6:
gemv_kernel<f16_t, N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
case 7:
gemv_kernel<f16_t, N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
break;
default:
throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
}
}
else
{
throw std::runtime_error("Unsupported group size for gemv kernel.\n");
}
});
return _out_feats;
}
#pragma once
#include <torch/extension.h>
torch::Tensor awq_gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size);
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "quantization/gemm/gemm_cuda.h"
#include "quantization/gemv/gemv_cuda.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
m.def("awq_gemv_forward_cuda", &awq_gemv_forward_cuda, "AWQ quantized GEMV kernel.");
}
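A minimal usage sketch of these bindings (illustrative; the tensor shapes follow the awq_gemv_forward_cuda docstring above and the W4Linear buffer layouts further below, and the `nunchaku.csrc.load` import path mirrors the one used in the linear module):
import torch
from nunchaku.csrc.load import _C  # the JIT-built extension from load.py above

B, IC, OC, G = 4, 4096, 4096, 128  # the GEMV path supports batch sizes 1-7 and group size 128
x = torch.randn(B, IC, dtype=torch.float16, device="cuda")
qweight = torch.zeros(OC // 4, IC, dtype=torch.int16, device="cuda")  # packed 4-bit weights, interleave 4
scales = torch.ones(IC // G, OC, dtype=torch.float16, device="cuda")
scaled_zeros = torch.zeros(IC // G, OC, dtype=torch.float16, device="cuda")
out = _C.awq_gemv_forward_cuda(x, qweight, scales, scaled_zeros, B, OC, IC, G)  # -> [B, OC]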
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#define ENABLE_BF16 1
template <typename T>
struct to_cpp_t;
template <>
struct to_cpp_t<at::Half>
{
using type = half;
};
template <>
struct to_cpp_t<at::BFloat16>
{
using type = __nv_bfloat16;
};
template <typename T>
struct num_elems;
template <>
struct num_elems<float>
{
static constexpr int value = 1;
};
template <>
struct num_elems<float2>
{
static constexpr int value = 2;
};
template <>
struct num_elems<float4>
{
static constexpr int value = 4;
};
template <>
struct num_elems<half>
{
static constexpr int value = 1;
};
template <>
struct num_elems<half2>
{
static constexpr int value = 2;
};
#ifdef ENABLE_BF16
template <>
struct num_elems<__nv_bfloat16>
{
static constexpr int value = 1;
};
template <>
struct num_elems<__nv_bfloat162>
{
static constexpr int value = 2;
};
#endif
template <typename T, int num>
struct packed_as;
template <typename T>
struct packed_as<T, 1>
{
using type = T;
};
template <>
struct packed_as<half, 2>
{
using type = half2;
};
template <>
struct packed_as<float, 2>
{
using type = float2;
};
template <>
struct packed_as<int8_t, 2>
{
using type = int16_t;
};
template <>
struct packed_as<int32_t, 2>
{
using type = int2;
};
template <>
struct packed_as<half2, 1>
{
using type = half;
};
template <>
struct packed_as<float2, 1>
{
using type = float;
};
#ifdef ENABLE_BF16
template <>
struct packed_as<__nv_bfloat16, 2>
{
using type = __nv_bfloat162;
};
template <>
struct packed_as<__nv_bfloat162, 1>
{
using type = __nv_bfloat16;
};
#endif
#ifdef ENABLE_FP8
template <>
struct packed_as<__nv_fp8_e4m3, 2>
{
using type = __nv_fp8x2_e4m3;
};
template <>
struct packed_as<__nv_fp8x2_e4m3, 1>
{
using type = __nv_fp8_e4m3;
};
template <>
struct packed_as<__nv_fp8_e5m2, 2>
{
using type = __nv_fp8x2_e5m2;
};
template <>
struct packed_as<__nv_fp8x2_e5m2, 1>
{
using type = __nv_fp8_e5m2;
};
#endif
template <typename f16_t>
__device__ __forceinline__
packed_as<f16_t, 2>::type
f162f162(f16_t x);
template <>
__device__ __forceinline__
packed_as<half, 2>::type
f162f162<half>(half x)
{
return __half2half2(x);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type
f162f162<__nv_bfloat16>(__nv_bfloat16 x)
{
return __bfloat162bfloat162(x);
}
#endif
template <typename T>
__device__ __forceinline__
float2
f1622float2(T val);
template <>
__device__ __forceinline__
float2
f1622float2<half2>(half2 val)
{
return __half22float2(val);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
float2
f1622float2<__nv_bfloat162>(__nv_bfloat162 val)
{
return __bfloat1622float2(val);
}
#endif
inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); }
inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); }
inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); }
inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); }
inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); }
static inline __device__ int8_t float_to_int8_rn(float x)
{
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t &>(dst);
}
template <typename T>
inline __device__ T ldg(const T *val)
{
return __ldg(val);
}
#if ENABLE_BF16
#define float22bf162 __float22bfloat162_rn
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
#endif
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
template <>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union
{
int8_t int8[2];
int16_t int16;
};
union
{
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union
{
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return __bfloat1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162bfloat162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
return cuda_cast<To>(val);
};
template <typename To>
__device__ inline To cuda_sum(float2 val)
{
return cuda_cast<To>(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
// Fallback for architectures without native bf16 __hmax (added): compare via float conversion.
return __float2bfloat16(fmaxf(__bfloat162float(val.x), __bfloat162float(val.y)));
#endif
}
#endif
// Binary maximum: compute the max of two scalar types
template <typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
template <typename T>
__device__ inline T cuda_abs(T val)
{
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#endif
#endif
# -*- coding: utf-8 -*-
"""QServe state dict converter module."""
import argparse
import os
import safetensors.torch
import torch
import tqdm
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Returns:
`int`:
ceiling division result.
"""
return (x + divisor - 1) // divisor
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
"""Calculate the ceiling number of quantization groups.
Args:
in_features (`int`):
input channel size.
group_size (`int`):
quantization group size.
weight_bits (`int`, *optional*, defaults to `4`):
quantized weight bits.
Returns:
`int`:
ceiling number of quantization groups.
"""
assert in_features % group_size == 0, "input channel size should be divisible by group size."
num_groups = in_features // group_size
assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights
num_packs = ceil_divide(num_groups, pack_size)
if group_size >= 128:
num_packs_factor = 1
elif group_size == 64:
num_packs_factor = 2
elif group_size == 32:
num_packs_factor = 4
else:
raise NotImplementedError
# make sure num_packs is a multiple of num_packs_factor
num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
num_groups = num_packs * pack_size
return num_groups
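# Worked example (added for clarity): with in_features=1344 and group_size=64 there
# are 21 groups and pack_size is 8, so num_packs = ceil_divide(21, 8) = 3; the
# group_size == 64 branch rounds num_packs up to a multiple of 2, giving 4 packs and
# ceil_num_groups(1344, 64) == 32. With in_features=4096 and group_size=128 no
# padding is needed and the result is simply 32.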
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
oc, ic = weight.shape
assert ic % 32 == 0, "input channel size should be divisible by 32."
# [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
weight = weight.view(-1, 4, 8)
weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
return weight.to(torch.int16)
def convert_to_tinychat_w4x16y16_linear_weight(
weight: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format.
Args:
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
packed quantized weight tensor, scale tensor, and zero point tensor.
"""
dtype, device = weight.dtype, weight.device
assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
assert scale is not None, "scale tensor is required for quantization."
assert zero is not None, "zero point tensor is required for quantization."
weight = weight.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32, device=device)
zero = zero.to(dtype=torch.float32, device=device)
if zero_pre_scaled:
zero = zero * scale
oc, ic = weight.shape
group_size = ic if group_size <= 0 else group_size
assert group_size <= ic, "group size should be less than or equal to input channel size."
assert ic % group_size == 0, "input channel size should be divisible by group size."
ng = ic // group_size
if scale.numel() == 1:
scale = scale.view(1, 1).expand(oc, ng)
scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
if zero.numel() == 1:
zero = zero.view(1, 1).expand(oc, ng)
zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
_weight = pack_w4(weight.to(torch.int32))
_ng = ceil_num_groups(ic, group_size, weight_bits=4)
_scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
_zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
_scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
_zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
return _weight, _scale, _zero
def convert_to_tinychat_w4x16y16_linear_state_dict(
param_name: str,
weight: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> dict[str, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear state dictionary.
Args:
param_name (`str`):
parameter name.
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`dict[str, torch.Tensor]`:
state dictionary for the quantized weight tensor.
"""
module_name = param_name[:-7]
weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
weight, scale=scale, zero=zero, group_size=group_size, zero_pre_scaled=zero_pre_scaled
)
state_dict: dict[str, torch.Tensor] = {}
state_dict[f"{module_name}.qweight"] = weight.cpu()
state_dict[f"{module_name}.scales"] = scale.cpu()
state_dict[f"{module_name}.scaled_zeros"] = zero.cpu()
return state_dict
def convert_to_tinychat_state_dict(
state_dict: dict[str, torch.Tensor],
scale_dict: dict[str, torch.Tensor],
group_size: int = -1,
) -> dict[str, torch.Tensor]:
scales: dict[str, dict[tuple[int, ...], torch.Tensor]] = {}
zeros: dict[str, tuple[torch.Tensor | None, bool]] = {}
print("Loading scale tensors...")
for name, tensor in tqdm.tqdm(scale_dict.items(), desc="Loading scale tensors", leave=False, dynamic_ncols=True):
print(f" - Loading tensor {name} (dtype: {tensor.dtype}, shape: {tensor.shape}, device: {tensor.device})")
if name.endswith("zero"):
# this is a zero point tensor
zero = None if tensor is None or all(t.item() == 0 for t in tensor.flatten()) else tensor
if name.endswith(".scaled_zero"):
zeros[name[:-12]] = (zero, False) # zero point tensor is post-scaled
else:
zeros[name[:-5]] = (zero, True) # zero point tensor is pre-scaled
else:
assert ".weight.scale" in name
# this is a scale tensor
idx = name.index(".weight.scale")
param_name = name[: idx + 7]
scale_level = tuple(map(int, name[idx + 14 :].split(".")))
scales.setdefault(param_name, {})[scale_level] = tensor
for param_name in zeros.keys():
assert param_name in state_dict, f"zero point tensor {param_name} not found in state dict."
assert param_name in scales, f"scale tensor {param_name} not found in scale dict."
converted: dict[str, torch.Tensor] = {}
print("Converting state dict...")
for param_name, param in tqdm.tqdm(state_dict.items(), desc="Converting state dict", dynamic_ncols=True):
if param_name in scales:
print(f" - Converting {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
weight = param.data.clone()
if param_name in zeros:
zero, zero_pre_scaled = zeros[param_name]
zero = zero.clone() if zero is not None else None
else:
zero, zero_pre_scaled = None, False
level_scales = sorted(scales[param_name].items(), key=lambda x: x[0])
assert len(level_scales) == 1, "more than one scale level is not supported."
scale = level_scales[0][1].clone()
converted.update(
convert_to_tinychat_w4x16y16_linear_state_dict(
param_name,
weight,
scale=scale,
zero=zero,
group_size=group_size,
zero_pre_scaled=zero_pre_scaled,
)
)
else:
if isinstance(param, torch.Tensor):
print(f" - Copying {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
converted[param_name] = param.clone().cpu()
else:
print(f" - Copying {param_name} (type: {type(param)}, value: {param})")
converted[param_name] = param
return converted
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--quant-path", type=str, required=True, help="path to the quantization checkpoint directory.")
parser.add_argument("--group-size", type=int, default=-1, help="quantization group size.")
parser.add_argument("--output-root", type=str, default="", help="root to the output checkpoint directory.")
parser.add_argument("--model-name", type=str, default=None, help="model name.")
parser.add_argument("--model-path", type=str, default=None, help="path to the huggingface model directory.")
parser.add_argument("--copy-on-save", action="store_true", help="copy files on save.")
args = parser.parse_args()
if not args.output_root:
args.output_root = args.quant_path
if args.model_name is None:
assert args.model_path is not None, "model name or path is required."
model_name = args.model_path.rstrip(os.sep).split(os.sep)[-1]
print(f"Model name not provided. Using model name {model_name}.")
else:
model_name = args.model_name
state_dict = torch.load(
os.path.join(args.quant_path, "model.pt"),
map_location="cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu",
)
scale_dict = torch.load(os.path.join(args.quant_path, "scale.pt"), map_location="cpu")
converted = convert_to_tinychat_state_dict(state_dict, scale_dict, group_size=args.group_size)
model_name = f"{args.model_name}-w4a16"
model_name += f"-g{args.group_size}" if args.group_size > 0 else "-gchn"
output_dirpath = os.path.join(args.output_root, model_name)
os.makedirs(output_dirpath, exist_ok=True)
if args.model_path and os.path.exists(args.model_path):
output_path = os.path.join(output_dirpath, "model.safetensors")
safetensors.torch.save_file(converted, output_path)
print(f"Quantized model checkpoint saved to {output_path}.")
for filename in os.listdir(args.model_path):
if filename == "tokenizer.model" or (
filename.endswith(".json") and filename != "pytorch_model.bin.index.json"
):
filepath = os.path.abspath(os.path.join(args.model_path, filename))
if args.copy_on_save:
os.system(f"cp {filepath} {output_dirpath}/")
else:
os.system(f"ln -s {filepath} {output_dirpath}/{filename}")
else:
output_path = os.path.join(output_dirpath, "tinychat-v2.pt")
torch.save(converted, output_path)
print(f"Quantized model checkpoint saved to {output_path}.")
print(f"Quantized model saved to {output_dirpath}.")
# -*- coding: utf-8 -*-
"""TinyChat Quantized Linear Module"""
import warnings
import torch
import torch.nn as nn
from nunchaku.csrc.load import _C
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
__all__ = ["W4Linear"]
warnings.warn(
"Module `tinychat.linear` will be moved to `Nunchaku` and deprecated in the future release.",
DeprecationWarning,
stacklevel=2,
)
class W4Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
group_size: int = 128,
dtype: torch.dtype = torch.float16,
device: str | torch.device = "cuda",
):
super().__init__()
assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}"
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.weight_bits) == 0
self.ceil_num_groups = ceil_num_groups(
in_features=self.in_features,
group_size=self.group_size,
weight_bits=self.weight_bits,
)
assert out_features % (self.interleave) == 0
self.register_buffer(
"qweight",
torch.zeros(
(
self.out_features // self.interleave,
self.in_features // (16 // self.weight_bits) * self.interleave,
),
dtype=torch.int16,
device=device,
),
)
self.register_buffer(
"scales",
torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
)
self.register_buffer(
"scaled_zeros",
torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
)
if bias:
self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
else:
self.bias = None
@property
def weight_bits(self) -> int:
return 4
@property
def interleave(self) -> int:
return 4
@torch.no_grad()
def forward(self, x):
if x.numel() / x.shape[-1] < 8:
out = _C.awq_gemv_forward_cuda(
x,
self.qweight,
self.scales,
self.scaled_zeros,
x.numel() // x.shape[-1],
self.out_features,
self.in_features,
self.group_size,
)
else:
out = _C.awq_gemm_forward_cuda(x, self.qweight, self.scales, self.scaled_zeros)
out = out + self.bias if self.bias is not None else out
return out
@staticmethod
def from_linear(
linear: nn.Linear,
group_size: int,
init_only: bool = False,
weight: torch.Tensor | None = None,
scale: torch.Tensor | None = None,
zero: torch.Tensor | None = None,
zero_pre_scaled: bool = False,
) -> "W4Linear":
"""Convert a linear layer to a TinyChat 4-bit weight-only quantized linear layer.
Args:
linear (`nn.Linear`):
linear layer to be converted.
group_size (`int`):
quantization group size.
init_only (`bool`, *optional*, defaults to `False`):
whether to only initialize the quantized linear layer.
weight (`torch.Tensor`, *optional*, defaults to `None`):
weight tensor for the quantized linear layer.
scale (`torch.Tensor`, *optional*, defaults to `None`):
scale tensor for the quantized linear layer.
zero (`torch.Tensor`, *optional*, defaults to `None`):
zero point tensor for the quantized linear layer.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`W4Linear`:
quantized linear layer.
"""
assert isinstance(linear, nn.Linear)
weight = linear.weight.data if weight is None else weight.data
dtype, device = weight.dtype, weight.device
oc, ic = linear.out_features, linear.in_features
_linear = W4Linear(
in_features=ic,
out_features=oc,
bias=linear.bias is not None,
group_size=group_size,
dtype=dtype,
device=device,
)
if init_only:
return _linear
if linear.bias is not None:
_linear.bias.data.copy_(linear.bias.data)
if scale is None:
assert zero is None, "scale and zero point tensors should be provided together."
group_size = ic if group_size <= 0 else group_size
assert group_size <= ic, "group size should be less than or equal to input channel size."
assert ic % group_size == 0, "input channel size should be divisible by group size."
ng, gs = ic // group_size, group_size
weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs)
vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True)
scale = (vmax - vmin).div_(15)
scale[scale == 0] = 1.0
if zero_pre_scaled:
zero = vmin.neg_().div_(scale).round_().clamp_(0, 15)
weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale)
else:
zero = vmin.neg_().clamp_min(0)
weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero)
weight = weight.to(dtype=dtype).view(oc, ic)
scale = scale.to(dtype=dtype)
zero = zero.to(dtype=dtype)
weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
weight=weight,
scale=scale,
zero=zero,
group_size=group_size,
zero_pre_scaled=zero_pre_scaled,
)
_linear.qweight.data.copy_(weight)
_linear.scales.data.copy_(scale)
_linear.scaled_zeros.data.copy_(zero)
return _linear
def extra_repr(self) -> str:
return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.weight_bits,
self.group_size,
)
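A minimal usage sketch (illustrative, assuming a CUDA device and a successful JIT build of the `_C` extension): quantize an existing fp16 layer on the fly and run it through the 4-bit path.
linear = nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device="cuda")
qlinear = W4Linear.from_linear(linear, group_size=128)
x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = qlinear(x)  # batch of 2 (< 8), so forward() dispatches to awq_gemv_forward_cuda
Larger batches fall through to awq_gemm_forward_cuda, which only takes the input, packed weight, scales, and scaled zero points.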
import os
import torch
-from deepcompressor.backend.tinychat.linear import W4Linear
+from nunchaku.models.linear import W4Linear
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
...
# -*- coding: utf-8 -*-
"""TinyChat backend utilities."""
import torch
__all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Returns:
`int`:
ceiling division result.
"""
return (x + divisor - 1) // divisor
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
"""Calculate the ceiling number of quantization groups.
Args:
in_features (`int`):
input channel size.
group_size (`int`):
quantization group size.
weight_bits (`int`, *optional*, defaults to `4`):
quantized weight bits.
Returns:
`int`:
ceiling number of quantization groups.
"""
assert in_features % group_size == 0, "input channel size should be divisible by group size."
num_groups = in_features // group_size
assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights
num_packs = ceil_divide(num_groups, pack_size)
if group_size >= 128:
num_packs_factor = 1
elif group_size == 64:
num_packs_factor = 2
elif group_size == 32:
num_packs_factor = 4
else:
raise NotImplementedError
# make sure num_packs is a multiple of num_packs_factor
num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
num_groups = num_packs * pack_size
return num_groups
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
oc, ic = weight.shape
assert ic % 32 == 0, "input channel size should be divisible by 32."
# [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
weight = weight.view(-1, 4, 8)
weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
return weight.to(torch.int16)
def convert_to_tinychat_w4x16y16_linear_weight(
weight: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format.
Args:
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
packed quantized weight tensor, scale tensor, and zero point tensor.
"""
dtype, device = weight.dtype, weight.device
assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
assert scale is not None, "scale tensor is required for quantization."
assert zero is not None, "zero point tensor is required for quantization."
weight = weight.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32, device=device)
zero = zero.to(dtype=torch.float32, device=device)
if zero_pre_scaled:
zero = zero * scale
oc, ic = weight.shape
group_size = ic if group_size <= 0 else group_size
assert group_size <= ic, "group size should be less than or equal to input channel size."
assert ic % group_size == 0, "input channel size should be divisible by group size."
ng = ic // group_size
if scale.numel() == 1:
scale = scale.view(1, 1).expand(oc, ng)
scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
if zero.numel() == 1:
zero = zero.view(1, 1).expand(oc, ng)
zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
_weight = pack_w4(weight.to(torch.int32))
_ng = ceil_num_groups(ic, group_size, weight_bits=4)
_scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
_zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
_scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
_zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
return _weight, _scale, _zero
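A minimal usage sketch (illustrative): derive per-group min-max scales and zero points, similar to the calibration-free path in `W4Linear.from_linear`, then pack a random fp16 weight into the TinyChat W4-X16-Y16 layout.
oc, ic, gs = 256, 1024, 128
w = torch.randn(oc, ic, dtype=torch.float16)
grouped = w.float().view(oc, ic // gs, gs)
vmin, vmax = grouped.amin(dim=-1, keepdim=True), grouped.amax(dim=-1, keepdim=True)
scale = (vmax - vmin) / 15
scale[scale == 0] = 1.0
zero = vmin.neg()  # zero point in weight units, so zero_pre_scaled stays False
qweight, scales, scaled_zeros = convert_to_tinychat_w4x16y16_linear_weight(
    w, scale=scale, zero=zero, group_size=gs
)
# qweight: [oc // 4, ic] int16; scales / scaled_zeros: [ceil_num_groups(ic, gs), oc] fp16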