Unverified commit 49aff300 authored by gushiqiao, committed by GitHub
parent c47dc6e8
@@ -15,11 +15,12 @@
    "cpu_offload": true,
    "offload_granularity": "block",
    "offload_ratio": 1,
-   "t5_cpu_offload": true,
-   "t5_offload_granularity": "model",
+   "t5_cpu_offload": false,
    "t5_quantized": true,
    "t5_quant_scheme": "fp8-q8f",
    "clip_cpu_offload": false,
+   "clip_quantized": true,
+   "clip_quant_scheme": "fp8-q8f",
    "audio_encoder_cpu_offload": false,
    "audio_adapter_cpu_offload": false,
    "adapter_quantized": true,
......
{
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"sample_guide_scale": [
3.5,
3.5
],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [
1000,
750,
500,
250
],
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quant_scheme": "fp8-q8f"
}
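The block above is a standalone 4-step distilled config. As a quick illustration (not part of the commit; the stage-splitting interpretation of boundary_step_index is an assumption on my part), here is a small Python sketch of how the step list divides around the boundary index:

# Standalone sanity check for a step-distilled config like the one above.
# Assumption (mine, not stated in the diff): boundary_step_index splits
# denoising_step_list into a high-noise stage and a low-noise stage.
cfg = {
    "infer_steps": 4,
    "denoising_step_list": [1000, 750, 500, 250],
    "boundary_step_index": 2,
}

steps = cfg["denoising_step_list"]
boundary = cfg["boundary_step_index"]
assert cfg["infer_steps"] == len(steps), "infer_steps should match the step list"
assert 0 < boundary < len(steps), "boundary_step_index must fall inside the step list"

high_noise_steps, low_noise_steps = steps[:boundary], steps[boundary:]
print(high_noise_steps, low_noise_steps)  # [1000, 750] [500, 250]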
@@ -38,7 +38,7 @@ class WanRunner(DefaultRunner):
        super().__init__(config)
        self.vae_cls = WanVAE
        self.tiny_vae_cls = WanVAE_tiny
-       self.vae_name = "Wan2.1_VAE.pth"
+       self.vae_name = config.get("vae_name", "Wan2.1_VAE.pth")
        self.tiny_vae_name = "taew2_1.pth"

    def load_transformer(self):
@@ -73,7 +73,7 @@ class WanRunner(DefaultRunner):
            clip_quant_scheme = self.config.get("clip_quant_scheme", None)
            assert clip_quant_scheme is not None
            tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
-           clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
+           clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth"
            clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
            clip_original_ckpt = None
        else:
@@ -154,6 +154,7 @@ class WanRunner(DefaultRunner):
            "cpu_offload": vae_offload,
            "dtype": GET_DTYPE(),
            "load_from_rank0": self.config.get("load_from_rank0", False),
+           "use_lightvae": self.config.get("use_lightvae", False),
        }
        if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
            return None
@@ -174,6 +175,7 @@ class WanRunner(DefaultRunner):
            "parallel": self.config["parallel"],
            "use_tiling": self.config.get("use_tiling_vae", False),
            "cpu_offload": vae_offload,
+           "use_lightvae": self.config.get("use_lightvae", False),
            "dtype": GET_DTYPE(),
            "load_from_rank0": self.config.get("load_from_rank0", False),
        }
......
@@ -263,16 +263,7 @@ class AttentionBlock(nn.Module):


class Encoder3d(nn.Module):
-   def __init__(
-       self,
-       dim=128,
-       z_dim=4,
-       dim_mult=[1, 2, 4, 4],
-       num_res_blocks=2,
-       attn_scales=[],
-       temperal_downsample=[True, True, False],
-       dropout=0.0,
-   ):
+   def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
@@ -283,6 +274,7 @@ class Encoder3d(nn.Module):

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
+       dims = [int(d * (1 - pruning_rate)) for d in dims]
        scale = 1.0

        # init block
@@ -375,16 +367,7 @@ class Encoder3d(nn.Module):


class Decoder3d(nn.Module):
-   def __init__(
-       self,
-       dim=128,
-       z_dim=4,
-       dim_mult=[1, 2, 4, 4],
-       num_res_blocks=2,
-       attn_scales=[],
-       temperal_upsample=[False, True, True],
-       dropout=0.0,
-   ):
+   def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0, pruning_rate=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
@@ -395,6 +378,8 @@ class Decoder3d(nn.Module):

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+       dims = [int(d * (1 - pruning_rate)) for d in dims]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block
@@ -498,16 +483,7 @@ def count_conv3d(model):


class WanVAE_(nn.Module):
-   def __init__(
-       self,
-       dim=128,
-       z_dim=4,
-       dim_mult=[1, 2, 4, 4],
-       num_res_blocks=2,
-       attn_scales=[],
-       temperal_downsample=[True, True, False],
-       dropout=0.0,
-   ):
+   def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
@@ -534,6 +510,7 @@ class WanVAE_(nn.Module):
            attn_scales,
            self.temperal_downsample,
            dropout,
+           pruning_rate,
        )
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
@@ -545,6 +522,7 @@ class WanVAE_(nn.Module):
            attn_scales,
            self.temperal_upsample,
            dropout,
+           pruning_rate,
        )

    def forward(self, x):
@@ -739,23 +717,6 @@ class WanVAE_(nn.Module):
        self.clear_cache()
        return out

-   def cached_decode(self, z, scale):
-       # z: [b,c,t,h,w]
-       if isinstance(scale[0], torch.Tensor):
-           z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
-       else:
-           z = z / scale[1] + scale[0]
-       iter_ = z.shape[2]
-       x = self.conv2(z)
-       for i in range(iter_):
-           self._conv_idx = [0]
-           if i == 0:
-               out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-           else:
-               out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
-               out = torch.cat([out, out_], 2)
-       return out
-
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
@@ -778,7 +739,7 @@ class WanVAE_(nn.Module):
        self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, pruning_rate=0.0, **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
@@ -791,6 +752,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0,
+       pruning_rate=pruning_rate,
    )
    cfg.update(**kwargs)
@@ -820,6 +782,7 @@ class WanVAE:
        cpu_offload=False,
        use_2d_split=True,
        load_from_rank0=False,
+       use_lightvae=False,
    ):
        self.dtype = dtype
        self.device = device
@@ -827,6 +790,10 @@ class WanVAE:
        self.use_tiling = use_tiling
        self.cpu_offload = cpu_offload
        self.use_2d_split = use_2d_split
+       if use_lightvae:
+           pruning_rate = 0.75
+       else:
+           pruning_rate = 0.0

        mean = [
            -0.7571,
@@ -906,7 +873,13 @@ class WanVAE:
        }

        # init model
-       self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
+       self.model = (
+           _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
+           .eval()
+           .requires_grad_(False)
+           .to(device)
+           .to(dtype)
+       )

    def _calculate_2d_grid(self, latent_height, latent_width, world_size):
        if (latent_height, latent_width, world_size) in self.grid_table:
......
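For intuition about the pruning hook added above, a standalone sketch (not from the repository) of how pruning_rate shrinks the channel widths produced by `dims = [int(d * (1 - pruning_rate)) for d in dims]`; the 0.75 value mirrors the use_lightvae branch:

# Standalone illustration of the channel pruning used by the light-VAE path.
dim = 128
dim_mult = [1, 2, 4, 4]

def pruned_dims(pruning_rate):
    dims = [dim * u for u in [1] + dim_mult]        # encoder widths before pruning
    return [int(d * (1 - pruning_rate)) for d in dims]

print(pruned_dims(0.0))   # [128, 128, 256, 512, 512] -> full WanVAE
print(pruned_dims(0.75))  # [32, 32, 64, 128, 128]    -> use_lightvae=True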
@@ -2,10 +2,12 @@ import argparse
import gc
import glob
import json
+import multiprocessing
import os
import re
import shutil
from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
from loguru import logger
@@ -293,7 +295,7 @@ def get_key_mapping_rules(direction, model_type):
        raise ValueError(f"Unsupported model type: {model_type}")


-def quantize_tensor(w, w_bit=8, dtype=torch.int8):
+def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
    """
    Quantize a 2D tensor to specified bit width using symmetric min-max quantization
@@ -314,14 +316,16 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8):
    org_w_shape = w.shape

    # Calculate quantization parameters
-   max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+   if not comfyui_mode:
+       max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+   else:
+       max_val = w.abs().max()

    if dtype == torch.float8_e4m3fn:
        finfo = torch.finfo(dtype)
        qmin, qmax = finfo.min, finfo.max
    elif dtype == torch.int8:
        qmin, qmax = -128, 127

    # Quantize tensor
    scales = max_val / qmax
@@ -335,21 +339,15 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8):
    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w_q).sum() == 0

-   scales = scales.view(org_w_shape[0], -1)
+   if not comfyui_mode:
+       scales = scales.view(org_w_shape[0], -1)
    w_q = w_q.reshape(org_w_shape)

    return w_q, scales
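To make the two scale layouts concrete (per-output-channel by default, one per-tensor scale in comfyui_mode), here is a minimal standalone sketch of the same symmetric min-max idea; the round/clamp step is a standard choice and is not copied from the elided part of the function:

import torch

w = torch.randn(4, 8)  # toy (out_features, in_features) weight

# Default path: one scale per output channel (row), as in the non-comfyui branch.
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)   # shape (4, 1)
scales = max_val / 127                                        # int8 qmax
w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
print(w_q.shape, scales.shape)                                # (4, 8) (4, 1)

# comfyui_mode path: a single per-tensor scale.
scale = w.abs().max() / 127
w_q_t = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
print(w_q_t.shape, scale)                                     # (4, 8) tensor(...)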
def quantize_model(
-   weights,
-   w_bit=8,
-   target_keys=["attn", "ffn"],
-   adapter_keys=None,
-   key_idx=2,
-   ignore_key=None,
-   linear_dtype=torch.int8,
-   non_linear_dtype=torch.float,
+   weights, w_bit=8, target_keys=["attn", "ffn"], adapter_keys=None, key_idx=2, ignore_key=None, linear_dtype=torch.int8, non_linear_dtype=torch.float, comfyui_mode=False, comfyui_keys=[]
):
    """
    Quantize model weights in-place
@@ -363,7 +361,9 @@ def quantize_model(
        Modified state dictionary with quantized weights and scales
    """
    total_quantized = 0
-   total_size = 0
+   original_size = 0
+   quantized_size = 0
+   non_quantized_size = 0

    keys = list(weights.keys())
    with tqdm(keys, desc="Quantizing weights") as pbar:
@@ -380,87 +380,237 @@ def quantize_model(
            if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
                if tensor.dtype != non_linear_dtype:
                    weights[key] = tensor.to(non_linear_dtype)
+                   non_quantized_size += weights[key].numel() * weights[key].element_size()
+               else:
+                   non_quantized_size += tensor.numel() * tensor.element_size()
                continue

            # Check if key matches target modules
            parts = key.split(".")
-           if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
-               if adapter_keys is not None and not any(adapter_key in parts for adapter_key in adapter_keys):
-                   if tensor.dtype != non_linear_dtype:
-                       weights[key] = tensor.to(non_linear_dtype)
-               continue
+           if comfyui_mode and key in comfyui_keys:
+               pass
+           elif len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
+               if adapter_keys is None:
+                   if tensor.dtype != non_linear_dtype:
+                       weights[key] = tensor.to(non_linear_dtype)
+                       non_quantized_size += weights[key].numel() * weights[key].element_size()
+                   else:
+                       non_quantized_size += tensor.numel() * tensor.element_size()
+               elif not any(adapter_key in parts for adapter_key in adapter_keys):
+                   if tensor.dtype != non_linear_dtype:
+                       weights[key] = tensor.to(non_linear_dtype)
+                       non_quantized_size += weights[key].numel() * weights[key].element_size()
+                   else:
+                       non_quantized_size += tensor.numel() * tensor.element_size()
+               else:
+                   non_quantized_size += tensor.numel() * tensor.element_size()
+               continue
-           try:
-               # Quantize tensor and store results
-               w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype)
-
-               # Replace original tensor and store scales
-               weights[key] = w_q
-               weights[key + "_scale"] = scales
-
-               total_quantized += 1
-               total_size += tensor.numel() * tensor.element_size() / (1024**2)  # MB
-               del w_q, scales
-           except Exception as e:
-               logger.error(f"Error quantizing {key}: {str(e)}")
+           # try:
+           original_tensor_size = tensor.numel() * tensor.element_size()
+           original_size += original_tensor_size
+           # Quantize tensor and store results
+           w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
+
+           # Replace original tensor and store scales
+           weights[key] = w_q
+           if comfyui_mode:
+               weights[key.replace(".weight", ".scale_weight")] = scales
+           else:
+               weights[key + "_scale"] = scales
+
+           quantized_tensor_size = w_q.numel() * w_q.element_size()
+           scale_size = scales.numel() * scales.element_size()
+           quantized_size += quantized_tensor_size + scale_size
+           total_quantized += 1
+           del w_q, scales
+           # except Exception as e:
+           #     logger.error(f"Error quantizing {key}: {str(e)}")

    gc.collect()

-   logger.info(f"Quantized {total_quantized} tensors, reduced size by {total_size:.2f} MB")
+   original_size_mb = original_size / (1024**2)
+   quantized_size_mb = quantized_size / (1024**2)
+   non_quantized_size_mb = non_quantized_size / (1024**2)
+   total_final_size_mb = (quantized_size + non_quantized_size) / (1024**2)
+   size_reduction_mb = original_size_mb - quantized_size_mb
+
+   logger.info(f"Quantized {total_quantized} tensors")
+   logger.info(f"Original quantized tensors size: {original_size_mb:.2f} MB")
+   logger.info(f"After quantization size: {quantized_size_mb:.2f} MB (includes scales)")
+   logger.info(f"Non-quantized tensors size: {non_quantized_size_mb:.2f} MB")
+   logger.info(f"Total final model size: {total_final_size_mb:.2f} MB")
+   logger.info(f"Size reduction in quantized tensors: {size_reduction_mb:.2f} MB ({size_reduction_mb / original_size_mb * 100:.1f}%)")
+
+   if comfyui_mode:
+       weights["scaled_fp8"] = torch.zeros(2, dtype=torch.float8_e4m3fn)

    return weights
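A tiny standalone illustration (the layer name is hypothetical) of how the comfyui_mode branch stores scales: the scale tensor goes under a `.scale_weight` key next to the original `.weight`, and a `scaled_fp8` marker tensor is added to the state dict:

import torch

key = "blocks.0.attn.q.weight"  # hypothetical layer name, for illustration only
state_dict = {key: torch.zeros(8, 8, dtype=torch.float8_e4m3fn)}
scale = torch.tensor(0.02)

state_dict[key.replace(".weight", ".scale_weight")] = scale            # comfyui-style scale key
state_dict["scaled_fp8"] = torch.zeros(2, dtype=torch.float8_e4m3fn)   # format marker tensor

print(sorted(state_dict.keys()))
# ['blocks.0.attn.q.scale_weight', 'blocks.0.attn.q.weight', 'scaled_fp8']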
def load_loras(lora_path, weight_dict, alpha): def load_loras(lora_path, weight_dict, alpha, key_mapping_rules=None):
logger.info(f"Loading LoRA from: {lora_path}") logger.info(f"Loading LoRA from: {lora_path} with alpha={alpha}")
with safe_open(lora_path, framework="pt") as f: with safe_open(lora_path, framework="pt") as f:
lora_weights = {k: f.get_tensor(k) for k in f.keys()} lora_weights = {k: f.get_tensor(k) for k in f.keys()}
lora_pairs = {} lora_pairs = {}
lora_diffs = {} lora_diffs = {}
prefix = "diffusion_model." lora_alphas = {} # Store LoRA-specific alpha values
def try_lora_pair(key, suffix_a, suffix_b, target_suffix): # Extract LoRA alpha values if present
if key.endswith(suffix_a): for key in lora_weights.keys():
base_name = key[len(prefix) :].replace(suffix_a, target_suffix) if key.endswith(".alpha"):
pair_key = key.replace(suffix_a, suffix_b) base_key = key[:-6] # Remove .alpha
if pair_key in lora_weights: lora_alphas[base_key] = lora_weights[key].item()
lora_pairs[base_name] = (key, pair_key)
# Handle different prefixes: "diffusion_model." or "transformer_blocks." or no prefix
def try_lora_diff(key, suffix, target_suffix): def get_model_key(lora_key, suffix_to_remove, suffix_to_add):
if key.endswith(suffix): """Extract the model weight key from LoRA key"""
base_name = key[len(prefix) :].replace(suffix, target_suffix) # Remove the LoRA-specific suffix
lora_diffs[base_name] = key if lora_key.endswith(suffix_to_remove):
base = lora_key[: -len(suffix_to_remove)]
else:
return None
# For Qwen models, keep transformer_blocks prefix
# Check if this is a Qwen-style LoRA (transformer_blocks.NUMBER.)
if base.startswith("transformer_blocks.") and base.split(".")[1].isdigit():
# Keep the full path for Qwen models
model_key = base + suffix_to_add
else:
# Remove common prefixes for other models
prefixes_to_remove = ["diffusion_model.", "model.", "unet."]
for prefix in prefixes_to_remove:
if base.startswith(prefix):
base = base[len(prefix) :]
break
model_key = base + suffix_to_add
# Apply key mapping rules if provided (for converted models)
if key_mapping_rules:
for pattern, replacement in key_mapping_rules:
model_key = re.sub(pattern, replacement, model_key)
return model_key
# Collect all LoRA pairs and diffs
    for key in lora_weights.keys():
-       if not key.startswith(prefix):
+       # Skip alpha parameters
+       if key.endswith(".alpha"):
            continue
-       try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
-       try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
-       try_lora_diff(key, "diff", "weight")
-       try_lora_diff(key, "diff_b", "bias")
-       try_lora_diff(key, "diff_m", "modulation")
+
+       # Pattern 1: .lora_down.weight / .lora_up.weight
+       if key.endswith(".lora_down.weight"):
+           base = key[: -len(".lora_down.weight")]
+           up_key = base + ".lora_up.weight"
+           if up_key in lora_weights:
+               model_key = get_model_key(key, ".lora_down.weight", ".weight")
+               if model_key:
+                   lora_pairs[model_key] = (key, up_key)
+       # Pattern 2: .lora_A.weight / .lora_B.weight
+       elif key.endswith(".lora_A.weight"):
+           base = key[: -len(".lora_A.weight")]
+           b_key = base + ".lora_B.weight"
+           if b_key in lora_weights:
+               model_key = get_model_key(key, ".lora_A.weight", ".weight")
+               if model_key:
+                   lora_pairs[model_key] = (key, b_key)
+       # Pattern 3: diff weights (direct addition)
+       elif key.endswith(".diff"):
+           model_key = get_model_key(key, ".diff", ".weight")
+           if model_key:
+               lora_diffs[model_key] = key
+       elif key.endswith(".diff_b"):
+           model_key = get_model_key(key, ".diff_b", ".bias")
+           if model_key:
+               lora_diffs[model_key] = key
+       elif key.endswith(".diff_m"):
+           model_key = get_model_key(key, ".diff_m", ".modulation")
+           if model_key:
+               lora_diffs[model_key] = key
    applied_count = 0
+   unused_lora_keys = set()
+
+   # Apply LoRA weights by iterating through model weights
    for name, param in weight_dict.items():
+       # Apply LoRA pairs (matmul pattern)
        if name in lora_pairs:
-           name_lora_A, name_lora_B = lora_pairs[name]
-           lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
-           lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
-           param += torch.matmul(lora_B, lora_A) * alpha
-           applied_count += 1
+           name_lora_down, name_lora_up = lora_pairs[name]
+           lora_down = lora_weights[name_lora_down].to(param.device, param.dtype)
+           lora_up = lora_weights[name_lora_up].to(param.device, param.dtype)
+
+           # Get LoRA-specific alpha if available, otherwise use global alpha
+           base_key = name_lora_down[: -len(".lora_down.weight")] if name_lora_down.endswith(".lora_down.weight") else name_lora_down[: -len(".lora_A.weight")]
+           lora_alpha = lora_alphas.get(base_key, alpha)
+
+           # Calculate rank from dimensions
+           rank = lora_down.shape[0]  # rank is the output dimension of down projection
+
+           try:
+               # Standard LoRA formula: W' = W + (alpha/rank) * BA
+               # where B = up (rank x out_features), A = down (rank x in_features)
+               # Note: PyTorch linear layers store weight as (out_features, in_features)
+               if len(lora_down.shape) == 2 and len(lora_up.shape) == 2:
+                   # For linear layers: down is (rank, in_features), up is (out_features, rank)
+                   lora_delta = torch.mm(lora_up, lora_down) * (lora_alpha / rank)
+               else:
+                   # For other shapes, try element-wise multiplication or skip
+                   logger.warning(f"Unexpected LoRA shape for {name}: down={lora_down.shape}, up={lora_up.shape}")
+                   continue
+
+               param.data += lora_delta
+               applied_count += 1
+               logger.debug(f"Applied LoRA to {name} with alpha={lora_alpha}, rank={rank}")
+           except Exception as e:
+               logger.warning(f"Failed to apply LoRA pair for {name}: {e}")
+               logger.warning(f"  Shapes - param: {param.shape}, down: {lora_down.shape}, up: {lora_up.shape}")
+
+       # Apply diff weights (direct addition)
        elif name in lora_diffs:
            name_diff = lora_diffs[name]
            lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
            try:
-               param += lora_diff * alpha
+               param.data += lora_diff * alpha
                applied_count += 1
+               logger.debug(f"Applied LoRA diff to {name}")
            except Exception as e:
-               continue
-   logger.info(f"Applied {applied_count} LoRA weight adjustments")
+               logger.warning(f"Failed to apply LoRA diff for {name}: {e}")
+   # Check for unused LoRA weights (potential key mismatch issues)
+   used_lora_keys = set()
+   for down_key, up_key in lora_pairs.values():
+       used_lora_keys.add(down_key)
+       used_lora_keys.add(up_key)
+   for diff_key in lora_diffs.values():
+       used_lora_keys.add(diff_key)
+
+   all_lora_keys = set(k for k in lora_weights.keys() if not k.endswith(".alpha"))
+   unused_lora_keys = all_lora_keys - used_lora_keys
+
+   if unused_lora_keys:
+       logger.warning(f"Found {len(unused_lora_keys)} unused LoRA weights - this may indicate key mismatch:")
+       for key in list(unused_lora_keys)[:10]:  # Show first 10
+           logger.warning(f"  Unused: {key}")
+       if len(unused_lora_keys) > 10:
+           logger.warning(f"  ... and {len(unused_lora_keys) - 10} more")
+
+   logger.info(f"Applied {applied_count} LoRA weight adjustments out of {len(lora_pairs) + len(lora_diffs)} possible")
+
+   if applied_count == 0 and (lora_pairs or lora_diffs):
+       logger.error("No LoRA weights were applied! Check for key name mismatches.")
+       logger.info("Model weight keys sample: " + str(list(weight_dict.keys())[:5]))
+       logger.info("LoRA pairs keys sample: " + str(list(lora_pairs.keys())[:5]))
+       logger.info("LoRA diff keys sample: " + str(list(lora_diffs.keys())[:5]))
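For reference, a minimal standalone sketch of the merge rule the rewritten loader applies to linear layers, W' = W + (alpha / rank) * up @ down; shapes and the alpha value are illustrative:

import torch

out_features, in_features, rank, lora_alpha = 16, 32, 4, 8.0

weight = torch.randn(out_features, in_features)       # base linear weight (out, in)
lora_down = torch.randn(rank, in_features) * 0.01     # A / "down": (rank, in)
lora_up = torch.randn(out_features, rank) * 0.01      # B / "up": (out, rank)

# Same scaling as the loader: alpha divided by the rank of the down projection.
delta = torch.mm(lora_up, lora_down) * (lora_alpha / rank)
merged = weight + delta
print(merged.shape)  # torch.Size([16, 32])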
def convert_weights(args):
@@ -473,6 +623,8 @@ def convert_weights(args):
    merged_weights = {}
    logger.info(f"Processing source files: {src_files}")

+   # Optimize loading for better memory usage
    for file_path in tqdm(src_files, desc="Loading weights"):
        logger.info(f"Loading weights from: {file_path}")
        if file_path.endswith(".pt") or file_path.endswith(".pth"):
@@ -480,47 +632,117 @@ def convert_weights(args):
            if args.model_type == "hunyuan_dit":
                weights = weights["module"]
        elif file_path.endswith(".safetensors"):
+           # Use lazy loading for safetensors to reduce memory usage
            with safe_open(file_path, framework="pt") as f:
-               weights = {k: f.get_tensor(k) for k in f.keys()}
+               # Only load tensors when needed (lazy loading)
+               weights = {}
+               keys = f.keys()
+               # For large files, show progress
+               if len(keys) > 100:
+                   for k in tqdm(keys, desc=f"Loading {os.path.basename(file_path)}", leave=False):
+                       weights[k] = f.get_tensor(k)
+               else:
+                   weights = {k: f.get_tensor(k) for k in keys}

        duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
        if duplicate_keys:
            raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")

-       merged_weights.update(weights)
-
-   if args.lora_path is not None:
-       # Handle alpha list - if single alpha, replicate for all LoRAs
-       if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
-           args.lora_alpha = args.lora_alpha * len(args.lora_path)
-       elif len(args.lora_alpha) != len(args.lora_path):
-           raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
-
-       for path, alpha in zip(args.lora_path, args.lora_alpha):
-           load_loras(path, merged_weights, alpha)
+       # Update weights more efficiently
+       merged_weights.update(weights)
+
+       # Clear weights dict to free memory
+       del weights
+       if len(src_files) > 1:
+           gc.collect()  # Force garbage collection between files
    if args.direction is not None:
        rules = get_key_mapping_rules(args.direction, args.model_type)
        converted_weights = {}
        logger.info("Converting keys...")
-       for key in tqdm(merged_weights.keys(), desc="Converting keys"):
-           new_key = key
-           for pattern, replacement in rules:
-               new_key = re.sub(pattern, replacement, new_key)
-           converted_weights[new_key] = merged_weights[key]
+
+       # Pre-compile regex patterns for better performance
+       compiled_rules = [(re.compile(pattern), replacement) for pattern, replacement in rules]
+
+       def convert_key(key):
+           """Convert a single key using compiled rules"""
+           new_key = key
+           for pattern, replacement in compiled_rules:
+               new_key = pattern.sub(replacement, new_key)
+           return new_key
+
+       # Batch convert keys using list comprehension (faster than loop)
+       keys_list = list(merged_weights.keys())
+
+       # Use parallel processing for large models
+       if len(keys_list) > 1000 and args.parallel:
+           logger.info(f"Using parallel processing for {len(keys_list)} keys")
+           # Use ThreadPoolExecutor for I/O bound regex operations
+           num_workers = min(8, multiprocessing.cpu_count())
+
+           with ThreadPoolExecutor(max_workers=num_workers) as executor:
+               # Submit all conversion tasks
+               future_to_key = {executor.submit(convert_key, key): key for key in keys_list}
+
+               # Process results as they complete with progress bar
+               for future in tqdm(as_completed(future_to_key), total=len(keys_list), desc="Converting keys (parallel)"):
+                   original_key = future_to_key[future]
+                   new_key = future.result()
+                   converted_weights[new_key] = merged_weights[original_key]
+       else:
+           # For smaller models, use simple loop with less overhead
+           for key in tqdm(keys_list, desc="Converting keys"):
+               new_key = convert_key(key)
+               converted_weights[new_key] = merged_weights[key]
    else:
        converted_weights = merged_weights
+   # Apply LoRA AFTER key conversion to ensure proper key matching
+   if args.lora_path is not None:
+       # Handle alpha list - if single alpha, replicate for all LoRAs
+       if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
+           args.lora_alpha = args.lora_alpha * len(args.lora_path)
+       elif len(args.lora_alpha) != len(args.lora_path):
+           raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
+
+       # Determine if we should apply key mapping rules to LoRA keys
+       key_mapping_rules = None
+       if args.lora_key_convert == "convert" and args.direction is not None:
+           # Apply same conversion as model
+           key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type)
+           logger.info("Applying key conversion to LoRA weights")
+       elif args.lora_key_convert == "same":
+           # Don't convert LoRA keys
+           logger.info("Using original LoRA keys without conversion")
+       else:  # auto
+           # Auto-detect: if model was converted, try with conversion first
+           if args.direction is not None:
+               key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type)
+               logger.info("Auto mode: will try with key conversion first")
+
+       for path, alpha in zip(args.lora_path, args.lora_alpha):
+           # Pass key mapping rules to handle converted keys properly
+           load_loras(path, converted_weights, alpha, key_mapping_rules)

    if args.quantized:
-       converted_weights = quantize_model(
-           converted_weights,
-           w_bit=args.bits,
-           target_keys=args.target_keys,
-           adapter_keys=args.adapter_keys,
-           key_idx=args.key_idx,
-           ignore_key=args.ignore_key,
-           linear_dtype=args.linear_dtype,
-           non_linear_dtype=args.non_linear_dtype,
-       )
+       if args.full_quantized and args.comfyui_mode:
+           logger.info("Quant all tensors...")
+           for k in converted_weights.keys():
+               converted_weights[k] = converted_weights[k].float().to(args.linear_dtype)
+       else:
+           converted_weights = quantize_model(
+               converted_weights,
+               w_bit=args.bits,
+               target_keys=args.target_keys,
+               adapter_keys=args.adapter_keys,
+               key_idx=args.key_idx,
+               ignore_key=args.ignore_key,
+               linear_dtype=args.linear_dtype,
+               non_linear_dtype=args.non_linear_dtype,
+               comfyui_mode=args.comfyui_mode,
+               comfyui_keys=args.comfyui_keys,
+           )
    os.makedirs(args.output, exist_ok=True)
@@ -529,8 +751,33 @@ def convert_weights(args):
    else:
        index = {"metadata": {"total_size": 0}, "weight_map": {}}

-   if args.save_by_block:
+   if args.single_file:
+       output_filename = f"{args.output_name}.safetensors"
+       output_path = os.path.join(args.output, output_filename)
+       logger.info(f"Saving model to single file: {output_path}")
+
+       # For memory efficiency with large models
+       try:
+           # If model is very large (over threshold), consider warning
+           total_size = sum(tensor.numel() * tensor.element_size() for tensor in converted_weights.values())
+           total_size_gb = total_size / (1024**3)
+           if total_size_gb > 10:  # Warn if model is larger than 10GB
+               logger.warning(f"Model size is {total_size_gb:.2f}GB. This will require significant memory to save as a single file.")
+               logger.warning("Consider using --save_by_block or default chunked saving for better memory efficiency.")
+
+           # Save the entire model as a single file
+           st.save_file(converted_weights, output_path)
+           logger.info(f"Model saved successfully to: {output_path} ({total_size_gb:.2f}GB)")
+       except MemoryError:
+           logger.error("Memory error while saving. The model is too large to save as a single file.")
+           logger.error("Please use --save_by_block or remove --single_file to use chunked saving.")
+           raise
+       except Exception as e:
+           logger.error(f"Error saving model: {e}")
+           raise
+   elif args.save_by_block:
        logger.info("Backward conversion: grouping weights by block")
        block_groups = defaultdict(dict)
        non_block_weights = {}
@@ -649,6 +896,8 @@ def main():
    parser.add_argument("-b", "--save_by_block", action="store_true")

    # Quantization
+   parser.add_argument("--comfyui_mode", action="store_true")
+   parser.add_argument("--full_quantized", action="store_true")
    parser.add_argument("--quantized", action="store_true")
    parser.add_argument("--bits", type=int, default=8, choices=[8], help="Quantization bit width")
    parser.add_argument(
@@ -679,8 +928,24 @@ def main():
        help="Alpha for LoRA weight scaling",
    )
    parser.add_argument("--copy_no_weight_files", action="store_true")
+   parser.add_argument("--single_file", action="store_true", help="Save as a single safetensors file instead of chunking (warning: requires loading entire model in memory)")
+   parser.add_argument(
+       "--lora_key_convert",
+       choices=["auto", "same", "convert"],
+       default="auto",
+       help="How to handle LoRA key conversion: 'auto' (detect from LoRA), 'same' (use original keys), 'convert' (apply same conversion as model)",
+   )
+   parser.add_argument("--parallel", action="store_true", default=True, help="Use parallel processing for faster conversion (default: True)")
+   parser.add_argument("--no-parallel", dest="parallel", action="store_false", help="Disable parallel processing")

    args = parser.parse_args()

+   # Validate conflicting arguments
+   if args.single_file and args.save_by_block:
+       parser.error("--single_file and --save_by_block cannot be used together. Choose one saving strategy.")
+   if args.single_file and args.chunk_size > 0 and args.chunk_size != 100:
+       logger.warning("--chunk_size is ignored when using --single_file option.")
+
    if args.quantized:
        args.linear_dtype = eval(args.linear_dtype)
        args.non_linear_dtype = eval(args.non_linear_dtype)
@@ -688,8 +953,16 @@ def main():
    model_type_keys_map = {
        "qwen_image_dit": {
            "key_idx": 2,
-           "target_keys": ["attn", "img_mlp", "txt_mlp"],
+           "target_keys": ["attn", "img_mlp", "txt_mlp", "txt_mod", "img_mod"],
            "ignore_key": None,
+           "comfyui_keys": [
+               "time_text_embed.timestep_embedder.linear_1.weight",
+               "time_text_embed.timestep_embedder.linear_2.weight",
+               "img_in.weight",
+               "txt_in.weight",
+               "norm_out.linear.weight",
+               "proj_out.weight",
+           ],
        },
        "wan_dit": {
            "key_idx": 2,
@@ -726,6 +999,7 @@ def main():
    args.adapter_keys = model_type_keys_map[args.model_type]["adapter_keys"] if "adapter_keys" in model_type_keys_map[args.model_type] else None
    args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
    args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
+   args.comfyui_keys = model_type_keys_map[args.model_type]["comfyui_keys"] if "comfyui_keys" in model_type_keys_map[args.model_type] else None
    if os.path.isfile(args.output):
        raise ValueError("Output path must be a directory, not a file")
......