Unverified commit bbd164c6 authored by Bilang ZHANG, committed by GitHub

update convert (#481)

--linear_dtype and --linear_quant_type are unified into --linear_type
parent 4beb6ebc
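In practice, the change collapses two flags into one. A minimal before/after sketch (other required arguments are elided here; full invocations appear in the README examples further down in this diff):

```bash
# Before: the quantization type was split across two flags
python converter.py --quantized --linear_dtype torch.int8 ...
python converter.py --quantized --linear_quant_type NVFP4 ...

# After: a single flag selects any supported linear quant type
python converter.py --quantized --linear_type int8 ...
python converter.py --quantized --linear_type nvfp4 ...
```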
@@ -27,6 +27,11 @@ sys.path.append(str(Path(__file__).parent.parent.parent))
from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
from tools.convert.quant import *
dtype_mapping = {
"int8": torch.int8,
"fp8": torch.float8_e4m3fn,
}
def get_key_mapping_rules(direction, model_type):
if model_type == "wan_dit":
@@ -306,59 +311,6 @@ def get_key_mapping_rules(direction, model_type):
raise ValueError(f"Unsupported model type: {model_type}")
def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
"""
Quantize a 2D tensor to specified bit width using symmetric min-max quantization
Args:
w: Input tensor to quantize (must be 2D)
w_bit: Quantization bit width (default: 8)
Returns:
quantized: Quantized tensor (int8)
scales: Scaling factors per row
"""
if w.dim() != 2:
raise ValueError(f"Only 2D tensors supported. Got {w.dim()}D tensor")
if torch.isnan(w).any():
raise ValueError("Tensor contains NaN values")
if w_bit != 8:
raise ValueError("Only support 8 bits")
org_w_shape = w.shape
# Calculate quantization parameters
if not comfyui_mode:
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
else:
max_val = w.abs().max()
if dtype == torch.float8_e4m3fn:
finfo = torch.finfo(dtype)
qmin, qmax = finfo.min, finfo.max
elif dtype == torch.int8:
qmin, qmax = -128, 127
# Quantize tensor
scales = max_val / qmax
if dtype == torch.float8_e4m3fn:
from qtorch.quant import float_quantize
scaled_tensor = w / scales
scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(dtype)
else:
w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(dtype)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
if not comfyui_mode:
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales
def quantize_model(
weights,
w_bit=8,
@@ -366,11 +318,10 @@ def quantize_model(
adapter_keys=None,
key_idx=2,
ignore_key=None,
linear_dtype=torch.int8,
linear_type="int8",
non_linear_dtype=torch.float,
comfyui_mode=False,
comfyui_keys=[],
linear_quant_type=None,
):
"""
Quantize model weights in-place
@@ -435,13 +386,9 @@ def quantize_model(
original_size += original_tensor_size
# Quantize tensor and store results
if linear_quant_type:
quantizer = CONVERT_WEIGHT_REGISTER[linear_quant_type](tensor)
w_q, scales, extra = quantizer.weight_quant_func(tensor)
weight_global_scale = extra.get("weight_global_scale", None) # For nvfp4
else:
w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
weight_global_scale = None
quantizer = CONVERT_WEIGHT_REGISTER[linear_type](tensor)
w_q, scales, extra = quantizer.weight_quant_func(tensor, comfyui_mode)
weight_global_scale = extra.get("weight_global_scale", None) # For nvfp4
# Replace original tensor and store scales
weights[key] = w_q
@@ -637,6 +584,7 @@ def convert_weights(args):
if args.quantized:
if args.full_quantized and args.comfyui_mode:
logger.info("Quant all tensors...")
assert args.linear_dtype is not None, "Error: full quantization in ComfyUI mode only supports 'int8' and 'fp8' linear types."
for k in converted_weights.keys():
converted_weights[k] = converted_weights[k].float().to(args.linear_dtype)
else:
@@ -647,11 +595,10 @@ def convert_weights(args):
adapter_keys=args.adapter_keys,
key_idx=args.key_idx,
ignore_key=args.ignore_key,
linear_dtype=args.linear_dtype,
linear_type=args.linear_type,
non_linear_dtype=args.non_linear_dtype,
comfyui_mode=args.comfyui_mode,
comfyui_keys=args.comfyui_keys,
linear_quant_type=args.linear_quant_type,
)
os.makedirs(args.output, exist_ok=True)
@@ -818,16 +765,10 @@ def main():
help="Device to use for quantization (cpu/cuda)",
)
parser.add_argument(
"--linear_dtype",
type=str,
choices=["torch.int8", "torch.float8_e4m3fn"],
help="Data type for linear",
)
parser.add_argument(
"--linear_quant_type",
"--linear_type",
type=str,
choices=["INT8", "FP8", "NVFP4", "MXFP4", "MXFP6", "MXFP8"],
help="Data type for linear",
choices=["int8", "fp8", "nvfp4", "mxfp4", "mxfp6", "mxfp8"],
help="Quant type for linear",
)
parser.add_argument(
"--non_linear_dtype",
@@ -870,7 +811,7 @@ def main():
logger.warning("--chunk_size is ignored when using --single_file option.")
if args.quantized:
args.linear_dtype = eval(args.linear_dtype)
args.linear_dtype = dtype_mapping.get(args.linear_type, None)
args.non_linear_dtype = eval(args.non_linear_dtype)
model_type_keys_map = {
......
@@ -22,16 +22,19 @@ class QuantTemplate(metaclass=ABCMeta):
self.extra = {}
@CONVERT_WEIGHT_REGISTER("INT8")
@CONVERT_WEIGHT_REGISTER("int8")
class QuantWeightINT8(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_int8_weight
@torch.no_grad()
def load_int8_weight(self, w):
def load_int8_weight(self, w, comfyui_mode=False):
org_w_shape = w.shape
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
if not comfyui_mode:
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
else:
max_val = w.abs().max()
qmin, qmax = -128, 127
scales = max_val / qmax
w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
@@ -39,22 +42,26 @@ class QuantWeightINT8(QuantTemplate):
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
if not comfyui_mode:
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("FP8")
@CONVERT_WEIGHT_REGISTER("fp8")
class QuantWeightFP8(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_fp8_weight
@torch.no_grad()
def load_fp8_weight(self, w):
def load_fp8_weight(self, w, comfyui_mode=False):
org_w_shape = w.shape
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
if not comfyui_mode:
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
else:
max_val = w.abs().max()
finfo = torch.finfo(torch.float8_e4m3fn)
qmin, qmax = finfo.min, finfo.max
scales = max_val / qmax
@@ -65,20 +72,21 @@ class QuantWeightFP8(QuantTemplate):
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
if not comfyui_mode:
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("MXFP4")
@CONVERT_WEIGHT_REGISTER("mxfp4")
class QuantWeightMxFP4(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_mxfp4_weight
@torch.no_grad()
def load_mxfp4_weight(self, w):
def load_mxfp4_weight(self, w, comfyui_mode=False):
device = w.device
w = w.cuda().to(torch.bfloat16)
w_q, scales = scaled_mxfp4_quant(w)
@@ -86,14 +94,14 @@ class QuantWeightMxFP4(QuantTemplate):
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("MXFP6")
@CONVERT_WEIGHT_REGISTER("mxfp6")
class QuantWeightMxFP6(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_mxfp6_weight
@torch.no_grad()
def load_mxfp6_weight(self, w):
def load_mxfp6_weight(self, w, comfyui_mode=False):
device = w.device
w = w.cuda().to(torch.bfloat16)
w_q, scales = scaled_mxfp6_quant(w)
@@ -101,14 +109,14 @@ class QuantWeightMxFP6(QuantTemplate):
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("MXFP8")
@CONVERT_WEIGHT_REGISTER("mxfp8")
class QuantWeightMxFP8(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_mxfp8_weight
@torch.no_grad()
def load_mxfp8_weight(self, w):
def load_mxfp8_weight(self, w, comfyui_mode=False):
device = w.device
w = w.cuda().to(torch.bfloat16)
w_q, scales = scaled_mxfp8_quant(w)
@@ -116,14 +124,14 @@ class QuantWeightMxFP8(QuantTemplate):
return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("NVFP4")
@CONVERT_WEIGHT_REGISTER("nvfp4")
class QuantWeightNVFP4(QuantTemplate):
def __init__(self, weight):
super().__init__(weight)
self.weight_quant_func = self.load_fp4_weight
@torch.no_grad()
def load_fp4_weight(self, w):
def load_fp4_weight(self, w, comfyui_mode=False):
device = w.device
w = w.cuda().to(torch.bfloat16)
weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
......
@@ -5,7 +5,7 @@ A powerful model weight conversion tool that supports format conversion, quantiz
## Main Features
- **Format Conversion**: Support PyTorch (.pth) and SafeTensors (.safetensors) format conversion
- **Model Quantization**: Support INT8 and FP8 quantization to significantly reduce model size
- **Model Quantization**: Support INT8, FP8, NVFP4, MXFP4, MXFP6 and MXFP8 quantization to significantly reduce model size
- **Architecture Conversion**: Support conversion between LightX2V and Diffusers architectures
- **LoRA Merging**: Support loading and merging multiple LoRA formats
- **Multi-Model Support**: Support Wan DiT, Qwen Image DiT, T5, CLIP, etc.
@@ -42,16 +42,21 @@ A powerful model weight conversion tool that supports format conversion, quantiz
- `--quantized`: Enable quantization
- `--bits`: Quantization bit width, currently only supports 8-bit
- `--linear_dtype`: Linear layer quantization type
- `torch.int8`: INT8 quantization
- `torch.float8_e4m3fn`: FP8 quantization
- `--linear_type`: Linear layer quantization type
- `int8`: INT8 quantization (torch.int8)
- `fp8`: FP8 quantization (torch.float8_e4m3fn)
- `nvfp4`: NVFP4 quantization
- `mxfp4`: MXFP4 quantization
- `mxfp6`: MXFP6 quantization
- `mxfp8`: MXFP8 quantization
- `--non_linear_dtype`: Non-linear layer data type
- `torch.bfloat16`: BF16
- `torch.float16`: FP16
- `torch.float32`: FP32 (default)
- `--device`: Device for quantization, `cpu` or `cuda` (default)
- `--comfyui_mode`: ComfyUI compatible mode
- `--comfyui_mode`: ComfyUI-compatible mode (supports `int8` and `fp8` only)
- `--full_quantized`: Full quantization mode (effective in ComfyUI mode)
For nvfp4, mxfp4, mxfp6, and mxfp8, please install the required kernels by following LightX2V/lightx2v_kernel/README.md.
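As a sketch of how the newer types are invoked (mirroring the INT8/FP8 DiT examples below; `wan_nvfp4` is a hypothetical output name, and the model/source arguments are omitted here), an NVFP4 conversion could look like:

```bash
python converter.py \
    --output /path/to/output \
    --output_ext .safetensors \
    --output_name wan_nvfp4 \
    --linear_type nvfp4 \
    --non_linear_dtype torch.bfloat16 \
    --model_type wan_dit \
    --quantized \
    --save_by_block
```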
### LoRA Parameters
@@ -105,7 +110,7 @@ python converter.py \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_int8 \
--linear_dtype torch.int8 \
--linear_type int8 \
--model_type wan_dit \
--quantized \
--save_by_block
@@ -118,7 +123,7 @@ python converter.py \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan2.1_i2v_480p_int8_lightx2v \
--linear_dtype torch.int8 \
--linear_type int8 \
--model_type wan_dit \
--quantized \
--single_file
@@ -133,7 +138,7 @@ python converter.py \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_fp8 \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
@@ -147,7 +152,7 @@ python converter.py \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
@@ -161,7 +166,7 @@ python converter.py \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
@@ -176,7 +181,7 @@ python converter.py \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
@@ -196,7 +201,7 @@ python converter.py \
--output /path/to/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-int8 \
--linear_dtype torch.int8 \
--linear_type int8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_t5 \
--quantized
@@ -209,7 +214,7 @@ python converter.py \
--output /path/to/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-fp8 \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--non_linear_dtype torch.bfloat16 \
--model_type wan_t5 \
--quantized
@@ -224,7 +229,7 @@ python converter.py \
--output /path/to/output \
--output_ext .pth \
--output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \
--linear_dtype torch.int8 \
--linear_type int8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
@@ -237,7 +242,7 @@ python converter.py \
--output /path/to/output \
--output_ext .pth \
--output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
@@ -318,7 +323,7 @@ python converter.py \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--single_file
```
@@ -333,7 +338,7 @@ python converter.py \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--single_file \
--comfyui_mode
```
@@ -349,7 +354,7 @@ python converter.py \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--linear_type fp8 \
--single_file \
--comfyui_mode \
--full_quantized
......