Commit 873a35be authored by muyangli

v0.1.4 ready to release


Co-authored-by: Zhekai Zhang <sxtyzhangzk@gmail.com>
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
Co-authored-by: Yujun Lin <16437040+synxlin@users.noreply.github.com>
parent d9cd6858
#pragma once
#include <torch/extension.h>
torch::Tensor awq_gemv_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size);
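For orientation, a minimal Python-side sketch of how this GEMV entry point might be driven through the compiled extension. The binding name gemv_awq appears elsewhere in this commit, but the argument order is simply assumed to mirror the C++ signature above, and the tensor shapes and packed layout are illustrative, not verified:

import torch
from nunchaku._C.ops import gemv_awq  # binding name as imported elsewhere in this commit

m, k, n, group_size = 1, 4096, 4096, 64                                        # assumed: m x k input, n outputs
x = torch.randn(m, k, dtype=torch.float16, device="cuda")                      # _in_feats
qweight = torch.zeros(n // 4, k, dtype=torch.int16, device="cuda")             # _kernel, packed 4-bit (layout assumed)
scales = torch.ones(k // group_size, n, dtype=torch.float16, device="cuda")    # _scaling_factors
zeros = torch.zeros(k // group_size, n, dtype=torch.float16, device="cuda")    # _zeros
out = gemv_awq(x, qweight, scales, zeros, m, n, k, group_size)                 # follows the declaration above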
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#define ENABLE_BF16 1
template <typename T>
struct to_cpp_t;
template <>
struct to_cpp_t<at::Half>
{
using type = half;
};
template <>
struct to_cpp_t<at::BFloat16>
{
using type = __nv_bfloat16;
};
template <typename T>
struct num_elems;
template <>
struct num_elems<float>
{
static constexpr int value = 1;
};
template <>
struct num_elems<float2>
{
static constexpr int value = 2;
};
template <>
struct num_elems<float4>
{
static constexpr int value = 4;
};
template <>
struct num_elems<half>
{
static constexpr int value = 1;
};
template <>
struct num_elems<half2>
{
static constexpr int value = 2;
};
#ifdef ENABLE_BF16
template <>
struct num_elems<__nv_bfloat16>
{
static constexpr int value = 1;
};
template <>
struct num_elems<__nv_bfloat162>
{
static constexpr int value = 2;
};
#endif
template <typename T, int num>
struct packed_as;
template <typename T>
struct packed_as<T, 1>
{
using type = T;
};
template <>
struct packed_as<half, 2>
{
using type = half2;
};
template <>
struct packed_as<float, 2>
{
using type = float2;
};
template <>
struct packed_as<int8_t, 2>
{
using type = int16_t;
};
template <>
struct packed_as<int32_t, 2>
{
using type = int2;
};
template <>
struct packed_as<half2, 1>
{
using type = half;
};
template <>
struct packed_as<float2, 1>
{
using type = float;
};
#ifdef ENABLE_BF16
template <>
struct packed_as<__nv_bfloat16, 2>
{
using type = __nv_bfloat162;
};
template <>
struct packed_as<__nv_bfloat162, 1>
{
using type = __nv_bfloat16;
};
#endif
#ifdef ENABLE_FP8
template <>
struct packed_as<__nv_fp8_e4m3, 2>
{
using type = __nv_fp8x2_e4m3;
};
template <>
struct packed_as<__nv_fp8x2_e4m3, 1>
{
using type = __nv_fp8_e4m3;
};
template <>
struct packed_as<__nv_fp8_e5m2, 2>
{
using type = __nv_fp8x2_e5m2;
};
template <>
struct packed_as<__nv_fp8x2_e5m2, 1>
{
using type = __nv_fp8_e5m2;
};
#endif
template <typename f16_t>
__device__ __forceinline__
typename packed_as<f16_t, 2>::type
f162f162(f16_t x);
template <>
__device__ __forceinline__
packed_as<half, 2>::type
f162f162<half>(half x)
{
return __half2half2(x);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
packed_as<__nv_bfloat16, 2>::type
f162f162<__nv_bfloat16>(__nv_bfloat16 x)
{
return __bfloat162bfloat162(x);
}
#endif
template <typename T>
__device__ __forceinline__
float2
f1622float2(T val);
template <>
__device__ __forceinline__
float2
f1622float2<half2>(half2 val)
{
return __half22float2(val);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
float2
f1622float2<__nv_bfloat162>(__nv_bfloat162 val)
{
return __bfloat1622float2(val);
}
#endif
inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); }
inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); }
inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); }
inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); }
inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); }
static inline __device__ int8_t float_to_int8_rn(float x)
{
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t &>(dst);
}
template <typename T>
inline __device__ T ldg(const T *val)
{
return __ldg(val);
}
#if ENABLE_BF16
#define float22bf162 __float22bfloat162_rn
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
#endif
#if ENABLE_BF16
template <>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
template <>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0];
#else
return __ldg(val);
#endif
}
#endif // ENABLE_BF16
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val)
{
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val)
{
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val)
{
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val)
{
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val)
{
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val)
{
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val)
{
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val)
{
union
{
int8_t int8[2];
int16_t int16;
};
union
{
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val)
{
union
{
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val)
{
return __bfloat1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val)
{
return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val)
{
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
return __bfloat162bfloat162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
union
{
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
template <typename To, typename Ti>
__device__ inline To cuda_sum(Ti val)
{
return cuda_cast<To>(val);
};
template <typename To>
__device__ inline To cuda_sum(float2 val)
{
return cuda_cast<To>(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val)
{
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#endif
}
#endif
// Binary maximum: compute the max of two scalar types
template <typename T>
__device__ inline T cuda_max(T val1, T val2)
{
return (val1 > val2) ? val1 : val2;
}
template <typename T>
__device__ inline T cuda_abs(T val)
{
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
return __habs2(val);
}
#ifdef ENABLE_BF16
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
return __habs2(val);
}
#endif
#endif
from .comfyui_converter import comfyui2diffusers
from .diffusers_converter import convert_to_nunchaku_flux_lowrank_dict
from .utils import detect_format
from .xlab_converter import xlab2diffusers
@@ -362,6 +362,11 @@ def convert_to_nunchaku_flux_lowrank_dict(
    else:
        extra_lora_dict = filter_state_dict(lora, filter_prefix="transformer.")

+    unquantized_lora_dict = {}
+    for k in list(extra_lora_dict.keys()):
+        if "transformer_blocks" not in k:
+            unquantized_lora_dict[k] = extra_lora_dict.pop(k)
+
    for k in extra_lora_dict.keys():
        fc1_k = k
        if "ff.net.0.proj" in k:
@@ -408,4 +413,5 @@ def convert_to_nunchaku_flux_lowrank_dict(
            prefix=block_name,
        )
+    converted.update(unquantized_lora_dict)
    return converted
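The added lines split the LoRA state dict into keys that touch the quantized transformer blocks and keys that do not; the latter are carried through unchanged and merged back at the end. A tiny self-contained illustration of that split, with made-up key names:

extra_lora_dict = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": "handled by the quantized path",
    "x_embedder.lora_A.weight": "handled by the unquantized path",
}
unquantized_lora_dict = {}
for k in list(extra_lora_dict.keys()):
    if "transformer_blocks" not in k:
        unquantized_lora_dict[k] = extra_lora_dict.pop(k)
print(sorted(unquantized_lora_dict))  # ['x_embedder.lora_A.weight']
print(sorted(extra_lora_dict))        # ['transformer_blocks.0.attn.to_q.lora_A.weight']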
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel
# -*- coding: utf-8 -*-
"""QServe state dict converter module."""
import argparse
import os
import safetensors.torch
import torch
import tqdm
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Returns:
`int`:
ceiling division result.
"""
return (x + divisor - 1) // divisor
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
"""Calculate the ceiling number of quantization groups.
Args:
in_features (`int`):
input channel size.
group_size (`int`):
quantization group size.
weight_bits (`int`, *optional*, defaults to `4`):
quantized weight bits.
Returns:
`int`:
ceiling number of quantization groups.
"""
assert in_features % group_size == 0, "input channel size should be divisible by group size."
num_groups = in_features // group_size
assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
pack_size = 32 // weight_bits # one INT32 contains `pack_size` elements of weights
num_packs = ceil_divide(num_groups, pack_size)
if group_size >= 128:
num_packs_factor = 1
elif group_size == 64:
num_packs_factor = 2
elif group_size == 32:
num_packs_factor = 4
else:
raise NotImplementedError
# make sure num_packs is a multiple of num_packs_factor
num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
num_groups = num_packs * pack_size
return num_groups
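A quick worked example of the padding this function applies, using the definitions above (numbers chosen only for illustration):

# 1344 input channels with 64-wide groups gives 21 real groups; pack_size = 32 // 4 = 8
# and num_packs_factor = 2 for group_size 64, so the count is padded up to 4 * 8 = 32 groups.
assert ceil_divide(21, 8) == 3          # 21 groups -> 3 packs of 8
assert ceil_num_groups(1344, 64) == 32  # 3 packs padded to 4 -> 32 groups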
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
oc, ic = weight.shape
assert ic % 32 == 0, "input channel size should be divisible by 32."
# [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
weight = weight.view(-1, 4, 8)
weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
return weight.to(torch.int16)
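A shape-level sanity check of the packing above (illustrative values): four 4-bit codes are folded into each int16, so a tensor of codes with shape (oc, ic) packs into an int16 tensor of shape (oc // 4, ic).

import torch

w_q = torch.randint(0, 16, (64, 128), dtype=torch.int32)  # 4-bit codes stored in int32
packed = pack_w4(w_q)
assert packed.dtype == torch.int16 and packed.shape == (64 // 4, 128)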
def convert_to_tinychat_w4x16y16_linear_weight(
weight: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format.
Args:
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
packed quantized weight tensor, scale tensor, and zero point tensor.
"""
dtype, device = weight.dtype, weight.device
assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
assert scale is not None, "scale tensor is required for quantization."
assert zero is not None, "zero point tensor is required for quantization."
weight = weight.to(dtype=torch.float32)
scale = scale.to(dtype=torch.float32, device=device)
zero = zero.to(dtype=torch.float32, device=device)
if zero_pre_scaled:
zero = zero * scale
oc, ic = weight.shape
group_size = ic if group_size <= 0 else group_size
assert group_size <= ic, "group size should be less than or equal to input channel size."
assert ic % group_size == 0, "input channel size should be divisible by group size."
ng = ic // group_size
if scale.numel() == 1:
scale = scale.view(1, 1).expand(oc, ng)
scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
if zero.numel() == 1:
zero = zero.view(1, 1).expand(oc, ng)
zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
_weight = pack_w4(weight.to(torch.int32))
_ng = ceil_num_groups(ic, group_size, weight_bits=4)
_scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
_zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
_scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
_zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
return _weight, _scale, _zero
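A minimal end-to-end sketch of this conversion on fabricated tensors: per-group absmax scales and a mid-point zero, chosen only so the 4-bit codes stay in range. The shapes follow the definitions above.

import torch

oc, ic, g = 128, 256, 64
w = torch.randn(oc, ic, dtype=torch.float16)
scale = w.view(oc, ic // g, g).abs().amax(dim=-1, keepdim=True).float() / 7.0  # (oc, ng, 1)
zero = torch.full((oc, ic // g, 1), 8.0)                                       # pre-scaled zero point
qweight, scales, zeros = convert_to_tinychat_w4x16y16_linear_weight(
    w, scale=scale, zero=zero, group_size=g, zero_pre_scaled=True
)
print(qweight.shape, scales.shape, zeros.shape)
# torch.Size([32, 256]) torch.Size([16, 128]) torch.Size([16, 128])  -- packed weight plus padded (ng, oc) scales/zeros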
def convert_to_tinychat_w4x16y16_linear_state_dict(
param_name: str,
weight: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> dict[str, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear state dictionary.
Args:
param_name (`str`):
parameter name.
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`dict[str, torch.Tensor]`:
state dictionary for the quantized weight tensor.
"""
module_name = param_name[:-7]
weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
weight, scale=scale, zero=zero, group_size=group_size, zero_pre_scaled=zero_pre_scaled
)
state_dict: dict[str, torch.Tensor] = {}
state_dict[f"{module_name}.qweight"] = weight.cpu()
state_dict[f"{module_name}.scales"] = scale.cpu()
state_dict[f"{module_name}.scaled_zeros"] = zero.cpu()
return state_dict
def convert_to_tinychat_state_dict(
state_dict: dict[str, torch.Tensor],
scale_dict: dict[str, torch.Tensor],
group_size: int = -1,
) -> dict[str, torch.Tensor]:
scales: dict[str, dict[tuple[int, ...], torch.Tensor]] = {}
zeros: dict[str, tuple[torch.Tensor | None, bool]] = {}
print("Loading scale tensors...")
for name, tensor in tqdm.tqdm(scale_dict.items(), desc="Loading scale tensors", leave=False, dynamic_ncols=True):
print(f" - Loading tensor {name} (dtype: {tensor.dtype}, shape: {tensor.shape}, device: {tensor.device})")
if name.endswith("zero"):
# this is a zero point tensor
zero = None if tensor is None or all(t.item() == 0 for t in tensor.flatten()) else tensor
if name.endswith(".scaled_zero"):
zeros[name[:-12]] = (zero, False) # zero point tensor is post-scaled
else:
zeros[name[:-5]] = (zero, True) # zero point tensor is pre-scaled
else:
assert ".weight.scale" in name
# this is a scale tensor
idx = name.index(".weight.scale")
param_name = name[: idx + 7]
scale_level = tuple(map(int, name[idx + 14 :].split(".")))
scales.setdefault(param_name, {})[scale_level] = tensor
for param_name in zeros.keys():
assert param_name in state_dict, f"zero point tensor {param_name} not found in state dict."
assert param_name in scales, f"scale tensor {param_name} not found in scale dict."
converted: dict[str, torch.Tensor] = {}
print("Converting state dict...")
for param_name, param in tqdm.tqdm(state_dict.items(), desc="Converting state dict", dynamic_ncols=True):
if param_name in scales:
print(f" - Converting {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
weight = param.data.clone()
if param_name in zeros:
zero, zero_pre_scaled = zeros[param_name]
zero = zero.clone() if zero is not None else None
else:
zero, zero_pre_scaled = None, False
level_scales = sorted(scales[param_name].items(), key=lambda x: x[0])
assert len(level_scales) == 1, "more than one scale level is not supported."
scale = level_scales[0][1].clone()
converted.update(
convert_to_tinychat_w4x16y16_linear_state_dict(
param_name,
weight,
scale=scale,
zero=zero,
group_size=group_size,
zero_pre_scaled=zero_pre_scaled,
)
)
else:
if isinstance(param, torch.Tensor):
print(f" - Copying {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
converted[param_name] = param.clone().cpu()
else:
print(f" - Copying {param_name} (type: {type(param)}, value: {param})")
converted[param_name] = param
return converted
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--quant-path", type=str, required=True, help="path to the quantization checkpoint directory.")
parser.add_argument("--group-size", type=int, default=-1, help="quantization group size.")
parser.add_argument("--output-root", type=str, default="", help="root to the output checkpoint directory.")
parser.add_argument("--model-name", type=str, default=None, help="model name.")
parser.add_argument("--model-path", type=str, default=None, help="path to the huggingface model directory.")
parser.add_argument("--copy-on-save", action="store_true", help="copy files on save.")
args = parser.parse_args()
if not args.output_root:
args.output_root = args.quant_path
if args.model_name is None:
assert args.model_path is not None, "model name or path is required."
model_name = args.model_path.rstrip(os.sep).split(os.sep)[-1]
print(f"Model name not provided. Using model name {model_name}.")
else:
model_name = args.model_name
state_dict = torch.load(
os.path.join(args.quant_path, "model.pt"),
map_location="cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu",
)
scale_dict = torch.load(os.path.join(args.quant_path, "scale.pt"), map_location="cpu")
converted = convert_to_tinychat_state_dict(state_dict, scale_dict, group_size=args.group_size)
model_name = f"{args.model_name}-w4a16"
model_name += f"-g{args.group_size}" if args.group_size > 0 else "-gchn"
output_dirpath = os.path.join(args.output_root, model_name)
os.makedirs(output_dirpath, exist_ok=True)
if args.model_path and os.path.exists(args.model_path):
output_path = os.path.join(output_dirpath, "model.safetensors")
safetensors.torch.save_file(converted, output_path)
print(f"Quantized model checkpoint saved to {output_path}.")
for filename in os.listdir(args.model_path):
if filename == "tokenizer.model" or (
filename.endswith(".json") and filename != "pytorch_model.bin.index.json"
):
filepath = os.path.abspath(os.path.join(args.model_path, filename))
if args.copy_on_save:
os.system(f"cp {filepath} {output_dirpath}/")
else:
os.system(f"ln -s {filepath} {output_dirpath}/{filename}")
else:
output_path = os.path.join(output_dirpath, "tinychat-v2.pt")
torch.save(converted, output_path)
print(f"Quantized model checkpoint saved to {output_path}.")
print(f"Quantized model saved to {output_dirpath}.")
# -*- coding: utf-8 -*-
"""TinyChat Quantized Linear Module"""
-import warnings
import torch
import torch.nn as nn

-from nunchaku._C.ops import gemm_cuda, gemv_awq
from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
+from ..._C.ops import gemm_awq, gemv_awq

__all__ = ["W4Linear"]

-warnings.warn(
-    "Module `tinychat.linear` will be moved to `Nunchaku` and deprecated in the future release.",
-    DeprecationWarning,
-    stacklevel=2,
-)

class W4Linear(nn.Module):
    def __init__(
@@ -89,7 +81,7 @@ class W4Linear(nn.Module):
                self.group_size,
            )
        else:
-            out = gemm_cuda(x, self.qweight, self.scales, self.scaled_zeros)
+            out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
        out = out + self.bias if self.bias is not None else out
        return out
...
import os

import torch
-from nunchaku.models.linear import W4Linear
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig, T5EncoderModel

+from .linear import W4Linear

def quantize_t5_encoder(
    t5_encoder: nn.Module,
...
from .transformer_flux import NunchakuFluxTransformer2dModel
from .transformer_sana import NunchakuSanaTransformer2DModel
@@ -8,9 +8,10 @@ from huggingface_hub import utils
from packaging.version import Version
from torch import nn

+from nunchaku.utils import fetch_or_download
from .utils import NunchakuModelLoaderMixin, pad_tensor
-from .._C import QuantizedFluxModel, utils as cutils
-from ..utils import fetch_or_download
+from ..._C import QuantizedFluxModel, utils as cutils
+from ...utils import load_state_dict_in_safetensors

SVD_RANK = 32
@@ -88,8 +89,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    else:
        out = out.view(batch_size, -1, dim // 2, 1, 1)
-    # stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    # out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()
@@ -108,12 +107,14 @@ class EmbedND(nn.Module):
        return emb.unsqueeze(1)

-def load_quantized_module(path: str, device: str | torch.device = "cuda", use_fp4: bool = False) -> QuantizedFluxModel:
+def load_quantized_module(
+    path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False
+) -> QuantizedFluxModel:
    device = torch.device(device)
    assert device.type == "cuda"

    m = QuantizedFluxModel()
    cutils.disable_memory_auto_release()
-    m.init(use_fp4, True, 0 if device.index is None else device.index)
+    m.init(use_fp4, offload, True, 0 if device.index is None else device.index)
    m.load(path)
    return m
@@ -147,19 +148,49 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
            guidance_embeds=guidance_embeds,
            axes_dims_rope=axes_dims_rope,
        )
+        self.unquantized_loras = {}
+        self.unquantized_state_dict = None

    @classmethod
    @utils.validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
        device = kwargs.get("device", "cuda")
        precision = kwargs.get("precision", "int4")
+        offload = kwargs.get("offload", False)
        assert precision in ["int4", "fp4"]
        transformer, transformer_block_path = cls._build_model(pretrained_model_name_or_path, **kwargs)
-        m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4")
+        m = load_quantized_module(transformer_block_path, device=device, use_fp4=precision == "fp4", offload=offload)
        transformer.inject_quantized_module(m, device)
        return transformer

+    def update_unquantized_lora_params(self, strength: float = 1):
+        new_state_dict = {}
+        for k in self.unquantized_state_dict.keys():
+            v = self.unquantized_state_dict[k]
+            if k.replace(".weight", ".lora_B.weight") in self.unquantized_loras:
+                new_state_dict[k] = v + strength * (
+                    self.unquantized_loras[k.replace(".weight", ".lora_B.weight")]
+                    @ self.unquantized_loras[k.replace(".weight", ".lora_A.weight")]
+                )
+            else:
+                new_state_dict[k] = v
+        self.load_state_dict(new_state_dict, strict=True)
+
    def update_lora_params(self, path: str):
+        state_dict = load_state_dict_in_safetensors(path)
+
+        unquantized_loras = {}
+        for k in state_dict.keys():
+            if "transformer_blocks" not in k:
+                unquantized_loras[k] = state_dict[k]
+        self.unquantized_loras = unquantized_loras
+        if len(unquantized_loras) > 0:
+            if self.unquantized_state_dict is None:
+                unquantized_state_dict = self.state_dict()
+                self.unquantized_state_dict = {k: v.cpu() for k, v in unquantized_state_dict.items()}
+            self.update_unquantized_lora_params(1)
+
        path = fetch_or_download(path)
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxTransformerBlocks)
@@ -169,6 +200,8 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxTransformerBlocks)
        block.m.setLoraScale(SVD_RANK, strength)
+        if len(self.unquantized_loras) > 0:
+            self.update_unquantized_lora_params(strength)

    def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
...
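The unquantized-LoRA path added above keeps a CPU copy of the original weights and, whenever the strength changes, reloads W + strength * (lora_B @ lora_A) for every key that has a matching LoRA pair. A toy version of that merge (shapes and names are illustrative, not taken from a real checkpoint):

import torch

strength = 0.8
W = torch.randn(64, 64)                      # an unquantized base weight, e.g. an embedder projection
lora_A = torch.randn(4, 64) * 0.01           # low-rank factors as stored in the LoRA file
lora_B = torch.randn(64, 4) * 0.01
W_merged = W + strength * (lora_B @ lora_A)  # what update_unquantized_lora_params loads back
print(W_merged.shape)                        # torch.Size([64, 64])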
@@ -9,7 +9,7 @@ from huggingface_hub import utils
from torch import nn

from .utils import NunchakuModelLoaderMixin
-from .._C import QuantizedSanaModel, utils as cutils
+from ..._C import QuantizedSanaModel, utils as cutils

SVD_RANK = 32
...
import torch
from diffusers import FluxPipeline

-from .models.transformer_flux import NunchakuFluxTransformer2dModel
+from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel

if __name__ == "__main__":
    capability = torch.cuda.get_device_capability(0)
...
@@ -5,11 +5,11 @@ import torch
from huggingface_hub import hf_hub_download

-def fetch_or_download(path: str) -> str:
+def fetch_or_download(path: str, repo_type: str = "model") -> str:
    if not os.path.exists(path):
        hf_repo_id = os.path.dirname(path)
        filename = os.path.basename(path)
-        path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
+        path = hf_hub_download(repo_id=hf_repo_id, filename=filename, repo_type=repo_type)
    return path
...
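Behaviour is unchanged for local paths; anything else is split into "<repo_id>/<filename>" and fetched from the Hugging Face Hub, now forwarding the new repo_type argument. A small hedged example; the Hub repo in the comment is illustrative and would need network access:

from nunchaku.utils import fetch_or_download  # module path as of this commit

print(fetch_or_download(__file__))  # existing local paths are returned unchanged
# Remote case: downloads via hf_hub_download(repo_id=..., filename=..., repo_type="model"), e.g.
# fetch_or_download("mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors")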
@@ -21,4 +21,4 @@ dependencies = [
    "protobuf",
    "huggingface_hub",
]
-requires-python = ">=3.11, <3.13"
+requires-python = ">=3.10, <3.13"
#!/bin/bash
# Modified from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/build.sh
set -ex
PYTHON_VERSION=$1
TORCH_VERSION=$2
CUDA_VERSION=$3
MAX_JOBS=${4:-} # optional
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
docker run --rm \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda${CUDA_VERSION} \
bash -c "
cd /nunchaku && \
rm -rf build && \
yum install -y devtoolset-11 && \
source scl_source enable devtoolset-11 && \
gcc --version && g++ --version && \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==${TORCH_VERSION} numpy --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
${PYTHON_ROOT_PATH}/bin/pip install build ninja wheel setuptools && \
export NUNCHAKU_INSTALL_MODE=ALL && \
export NUNCHAKU_BUILD_WHEELS=1 && \
export MAX_JOBS=${MAX_JOBS} && \
${PYTHON_ROOT_PATH}/bin/python -m build --wheel --no-isolation
"
\ No newline at end of file
#!/bin/bash
set -ex
#docker run --rm \
# -v "$(pwd)":/nunchaku \
# pytorch/manylinux-builder:cuda12.4 \
# bash -c "cd /nunchaku && rm -r *"
docker run --rm -it \
-v "$(pwd)":/nunchaku \
pytorch/manylinux-builder:cuda12.4 \
bash
\ No newline at end of file
@@ -8,6 +8,7 @@ import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension

+
class CustomBuildExtension(BuildExtension):
    def build_extensions(self):
        for ext in self.extensions:
@@ -74,8 +75,6 @@ if __name__ == "__main__":
        "third_party/mio/include",
        "third_party/spdlog/include",
        "third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
-        "src/interop",
-        "src/kernels",
    ]
    INCLUDE_DIRS = [os.path.join(ROOT_DIR, dir) for dir in INCLUDE_DIRS]
@@ -94,8 +93,13 @@ if __name__ == "__main__":
        else:
            return []

+    sm_targets = get_sm_targets()
+    print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
+    assert len(sm_targets) > 0, "No SM targets found"
+
    GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++20", "-UNDEBUG", "-Og"]
-    MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus"]
+    MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
    NVCC_FLAGS = [
        "-DENABLE_BF16=1",
        "-DBUILD_NUNCHAKU=1",
@@ -113,7 +117,7 @@ if __name__ == "__main__":
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--threads=3",
+        f"--threads={len(sm_targets)}",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--ptxas-options=--allow-expensive-optimizations=true",
@@ -122,15 +126,10 @@ if __name__ == "__main__":
    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
        NVCC_FLAGS.append("--generate-line-info")

-    sm_targets = get_sm_targets()
-    print(f"Detected SM targets: {sm_targets}", file=sys.stderr)
-    assert len(sm_targets) > 0, "No SM targets found"
-
    for target in sm_targets:
        NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]

-    NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus"]
+    NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus", "-Xcompiler", "/FS"]

    nunchaku_extension = CUDAExtension(
        name="nunchaku._C",
@@ -165,16 +164,12 @@ if __name__ == "__main__":
            "src/kernels/dwconv.cu",
            "src/kernels/gemm_batched.cu",
            "src/kernels/gemm_f16.cu",
-            "src/kernels/awq/gemm_cuda.cu",
+            "src/kernels/awq/gemm_awq.cu",
            "src/kernels/awq/gemv_awq.cu",
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
        ],
-        extra_compile_args={
-            "gcc": GCC_FLAGS,
-            "msvc": MSVC_FLAGS,
-            "nvcc": NVCC_FLAGS,
-        },
+        extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
        include_dirs=INCLUDE_DIRS,
    )
...
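The hunk above ties nvcc's --threads count to the number of detected SM targets and emits one -gencode pair per target. A standalone illustration with a hypothetical target list (the real values come from get_sm_targets(), which is not shown here):

sm_targets = ["80", "86", "89"]  # hypothetical result of get_sm_targets()
NVCC_FLAGS = [f"--threads={len(sm_targets)}"]
for target in sm_targets:
    NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]
print(NVCC_FLAGS)
# ['--threads=3', '-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_86,code=sm_86',
#  '-gencode', 'arch=compute_89,code=sm_89']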