Unverified Commit 49aff300 authored by gushiqiao, committed by GitHub
parent c47dc6e8
......@@ -15,11 +15,12 @@
"cpu_offload": true,
"offload_granularity": "block",
"offload_ratio": 1,
"t5_cpu_offload": true,
"t5_offload_granularity": "model",
"t5_cpu_offload": false,
"t5_quantized": true,
"t5_quant_scheme": "fp8-q8f",
"clip_cpu_offload": false,
"clip_quantized": true,
"clip_quant_scheme": "fp8-q8f",
"audio_encoder_cpu_offload": false,
"audio_adapter_cpu_offload": false,
"adapter_quantized": true,
......
{
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"sample_guide_scale": [
3.5,
3.5
],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [
1000,
750,
500,
250
],
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F"
},
"t5_quantized": true,
"t5_quant_scheme": "fp8-q8f"
}
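A minimal sketch (not part of the commit) of how options like these are consumed: the runner code below reads optional keys with dict-style .get() defaults, so the exact key set shown here is illustrative.

import json

config = json.loads('{"t5_quantized": true, "t5_quant_scheme": "fp8-q8f", "use_lightvae": false}')
t5_quantized = config.get("t5_quantized", False)
t5_quant_scheme = config.get("t5_quant_scheme", None)  # "fp8-q8f" as in the config above
use_lightvae = config.get("use_lightvae", False)        # new option wired into WanVAE in this commit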
......@@ -38,7 +38,7 @@ class WanRunner(DefaultRunner):
super().__init__(config)
self.vae_cls = WanVAE
self.tiny_vae_cls = WanVAE_tiny
self.vae_name = "Wan2.1_VAE.pth"
self.vae_name = config.get("vae_name", "Wan2.1_VAE.pth")
self.tiny_vae_name = "taew2_1.pth"
def load_transformer(self):
......@@ -73,7 +73,7 @@ class WanRunner(DefaultRunner):
clip_quant_scheme = self.config.get("clip_quant_scheme", None)
assert clip_quant_scheme is not None
tmp_clip_quant_scheme = clip_quant_scheme.split("-")[0]
clip_model_name = f"clip-{tmp_clip_quant_scheme}.pth"
clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp_clip_quant_scheme}.pth"
clip_quantized_ckpt = find_torch_model_path(self.config, "clip_quantized_ckpt", clip_model_name)
clip_original_ckpt = None
else:
......@@ -154,6 +154,7 @@ class WanRunner(DefaultRunner):
"cpu_offload": vae_offload,
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
"use_lightvae": self.config.get("use_lightvae", False),
}
if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
return None
......@@ -174,6 +175,7 @@ class WanRunner(DefaultRunner):
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
"use_lightvae": self.config.get("use_lightvae", False),
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
}
......
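A quick worked example of the checkpoint-name derivation above (a sketch, not part of the diff): the prefix before the first dash of the quant scheme selects the CLIP checkpoint variant.

tmp = "fp8-q8f".split("-")[0]  # -> "fp8"
clip_model_name = f"models_clip_open-clip-xlm-roberta-large-vit-huge-14-{tmp}.pth"
# -> "models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8.pth"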
......@@ -263,16 +263,7 @@ class AttentionBlock(nn.Module):
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
......@@ -283,6 +274,7 @@ class Encoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [1] + dim_mult]
dims = [int(d * (1 - pruning_rate)) for d in dims]
scale = 1.0
# init block
......@@ -375,16 +367,7 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0, pruning_rate=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
......@@ -395,6 +378,8 @@ class Decoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
dims = [int(d * (1 - pruning_rate)) for d in dims]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
......@@ -498,16 +483,7 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
def __init__(self, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, pruning_rate=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
......@@ -534,6 +510,7 @@ class WanVAE_(nn.Module):
attn_scales,
self.temperal_downsample,
dropout,
pruning_rate,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
......@@ -545,6 +522,7 @@ class WanVAE_(nn.Module):
attn_scales,
self.temperal_upsample,
dropout,
pruning_rate,
)
def forward(self, x):
......@@ -739,23 +717,6 @@ class WanVAE_(nn.Module):
self.clear_cache()
return out
def cached_decode(self, z, scale):
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
......@@ -778,7 +739,7 @@ class WanVAE_(nn.Module):
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, **kwargs):
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, dtype=torch.float, load_from_rank0=False, pruning_rate=0.0, **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
......@@ -791,6 +752,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0,
pruning_rate=pruning_rate,
)
cfg.update(**kwargs)
......@@ -820,6 +782,7 @@ class WanVAE:
cpu_offload=False,
use_2d_split=True,
load_from_rank0=False,
use_lightvae=False,
):
self.dtype = dtype
self.device = device
......@@ -827,6 +790,10 @@ class WanVAE:
self.use_tiling = use_tiling
self.cpu_offload = cpu_offload
self.use_2d_split = use_2d_split
if use_lightvae:
pruning_rate = 0.75  # LightVAE: keep int(d * 0.25) of each stage's channel width
else:
pruning_rate = 0.0
mean = [
-0.7571,
......@@ -906,7 +873,13 @@ class WanVAE:
}
# init model
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
self.model = (
_video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0, pruning_rate=pruning_rate)
.eval()
.requires_grad_(False)
.to(device)
.to(dtype)
)
def _calculate_2d_grid(self, latent_height, latent_width, world_size):
if (latent_height, latent_width, world_size) in self.grid_table:
......
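To make the new pruning_rate concrete, a worked sketch assuming the default dim=128 and dim_mult=[1, 2, 4, 4]: each stage's channel width is scaled by (1 - pruning_rate), so the LightVAE setting of 0.75 keeps roughly a quarter of the channels.

dim, dim_mult, pruning_rate = 128, [1, 2, 4, 4], 0.75
dims = [dim * u for u in [1] + dim_mult]            # encoder widths: [128, 128, 256, 512, 512]
dims = [int(d * (1 - pruning_rate)) for d in dims]  # pruned widths:  [32, 32, 64, 128, 128]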
......@@ -2,10 +2,12 @@ import argparse
import gc
import glob
import json
import multiprocessing
import os
import re
import shutil
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
from loguru import logger
......@@ -293,7 +295,7 @@ def get_key_mapping_rules(direction, model_type):
raise ValueError(f"Unsupported model type: {model_type}")
def quantize_tensor(w, w_bit=8, dtype=torch.int8):
def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
"""
Quantize a 2D tensor to the specified bit width using symmetric min-max quantization
......@@ -314,14 +316,16 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8):
org_w_shape = w.shape
# Calculate quantization parameters
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
if not comfyui_mode:
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
else:
max_val = w.abs().max()
if dtype == torch.float8_e4m3fn:
finfo = torch.finfo(dtype)
qmin, qmax = finfo.min, finfo.max
elif dtype == torch.int8:
qmin, qmax = -128, 127
# Quantize tensor
scales = max_val / qmax
......@@ -335,21 +339,15 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8):
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w_q).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
if not comfyui_mode:
scales = scales.view(org_w_shape[0], -1)
w_q = w_q.reshape(org_w_shape)
return w_q, scales
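An illustrative sketch of the per-channel symmetric scheme above (the clamp-and-cast step is collapsed in this diff, so that part is an assumption); with comfyui_mode=True the scale collapses to a single per-tensor scalar.

import torch

w = torch.randn(4, 16)                                       # (out_features, in_features)
finfo = torch.finfo(torch.float8_e4m3fn)                     # qmax = 448.0 for fp8-e4m3
max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)  # one scale per output channel
scales = max_val / finfo.max
w_q = (w / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
w_ref = w_q.to(torch.float32) * scales                       # approximate dequantization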
def quantize_model(
weights,
w_bit=8,
target_keys=["attn", "ffn"],
adapter_keys=None,
key_idx=2,
ignore_key=None,
linear_dtype=torch.int8,
non_linear_dtype=torch.float,
weights, w_bit=8, target_keys=["attn", "ffn"], adapter_keys=None, key_idx=2, ignore_key=None, linear_dtype=torch.int8, non_linear_dtype=torch.float, comfyui_mode=False, comfyui_keys=[]
):
"""
Quantize model weights in-place
......@@ -363,7 +361,9 @@ def quantize_model(
Modified state dictionary with quantized weights and scales
"""
total_quantized = 0
total_size = 0
original_size = 0
quantized_size = 0
non_quantized_size = 0
keys = list(weights.keys())
with tqdm(keys, desc="Quantizing weights") as pbar:
......@@ -380,87 +380,237 @@ def quantize_model(
if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
non_quantized_size += weights[key].numel() * weights[key].element_size()
else:
non_quantized_size += tensor.numel() * tensor.element_size()
continue
# Check if key matches target modules
parts = key.split(".")
if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
if adapter_keys is not None and not any(adapter_key in parts for adapter_key in adapter_keys):
if comfyui_mode and key in comfyui_keys:
pass
elif len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
if adapter_keys is None:
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
continue
non_quantized_size += weights[key].numel() * weights[key].element_size()
else:
non_quantized_size += tensor.numel() * tensor.element_size()
elif not any(adapter_key in parts for adapter_key in adapter_keys):
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
non_quantized_size += weights[key].numel() * weights[key].element_size()
else:
non_quantized_size += tensor.numel() * tensor.element_size()
else:
non_quantized_size += tensor.numel() * tensor.element_size()
continue
try:
# Quantize tensor and store results
w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype)
# try:
original_tensor_size = tensor.numel() * tensor.element_size()
original_size += original_tensor_size
# Quantize tensor and store results
w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
# Replace original tensor and store scales
weights[key] = w_q
# Replace original tensor and store scales
weights[key] = w_q
if comfyui_mode:
weights[key.replace(".weight", ".scale_weight")] = scales
else:
weights[key + "_scale"] = scales
total_quantized += 1
total_size += tensor.numel() * tensor.element_size() / (1024**2) # MB
del w_q, scales
quantized_tensor_size = w_q.numel() * w_q.element_size()
scale_size = scales.numel() * scales.element_size()
quantized_size += quantized_tensor_size + scale_size
except Exception as e:
logger.error(f"Error quantizing {key}: {str(e)}")
total_quantized += 1
del w_q, scales
# except Exception as e:
# logger.error(f"Error quantizing {key}: {str(e)}")
gc.collect()
logger.info(f"Quantized {total_quantized} tensors, reduced size by {total_size:.2f} MB")
original_size_mb = original_size / (1024**2)
quantized_size_mb = quantized_size / (1024**2)
non_quantized_size_mb = non_quantized_size / (1024**2)
total_final_size_mb = (quantized_size + non_quantized_size) / (1024**2)
size_reduction_mb = original_size_mb - quantized_size_mb
logger.info(f"Quantized {total_quantized} tensors")
logger.info(f"Original quantized tensors size: {original_size_mb:.2f} MB")
logger.info(f"After quantization size: {quantized_size_mb:.2f} MB (includes scales)")
logger.info(f"Non-quantized tensors size: {non_quantized_size_mb:.2f} MB")
logger.info(f"Total final model size: {total_final_size_mb:.2f} MB")
logger.info(f"Size reduction in quantized tensors: {size_reduction_mb:.2f} MB ({size_reduction_mb / original_size_mb * 100:.1f}%)")
if comfyui_mode:
weights["scaled_fp8"] = torch.zeros(2, dtype=torch.float8_e4m3fn)
return weights
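For reference, a small sketch of the resulting key layout (the layer name is hypothetical; scale shapes follow the code above): the default path stores per-channel scales under "<name>_scale", while comfyui_mode stores a per-tensor scalar under "<name>.scale_weight" plus a "scaled_fp8" marker tensor.

import torch

q = torch.zeros(4, 16, dtype=torch.float8_e4m3fn)
default_layout = {"blocks.0.attn.q.weight": q,
                  "blocks.0.attn.q.weight_scale": torch.ones(4, 1)}
comfyui_layout = {"blocks.0.attn.q.weight": q,
                  "blocks.0.attn.q.scale_weight": torch.tensor(1.0),
                  "scaled_fp8": torch.zeros(2, dtype=torch.float8_e4m3fn)}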
def load_loras(lora_path, weight_dict, alpha):
logger.info(f"Loading LoRA from: {lora_path}")
def load_loras(lora_path, weight_dict, alpha, key_mapping_rules=None):
logger.info(f"Loading LoRA from: {lora_path} with alpha={alpha}")
with safe_open(lora_path, framework="pt") as f:
lora_weights = {k: f.get_tensor(k) for k in f.keys()}
lora_pairs = {}
lora_diffs = {}
prefix = "diffusion_model."
lora_alphas = {} # Store LoRA-specific alpha values
def try_lora_pair(key, suffix_a, suffix_b, target_suffix):
if key.endswith(suffix_a):
base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
pair_key = key.replace(suffix_a, suffix_b)
if pair_key in lora_weights:
lora_pairs[base_name] = (key, pair_key)
def try_lora_diff(key, suffix, target_suffix):
if key.endswith(suffix):
base_name = key[len(prefix) :].replace(suffix, target_suffix)
lora_diffs[base_name] = key
# Extract LoRA alpha values if present
for key in lora_weights.keys():
if key.endswith(".alpha"):
base_key = key[:-6] # Remove .alpha
lora_alphas[base_key] = lora_weights[key].item()
# Handle different prefixes: "diffusion_model." or "transformer_blocks." or no prefix
def get_model_key(lora_key, suffix_to_remove, suffix_to_add):
"""Extract the model weight key from LoRA key"""
# Remove the LoRA-specific suffix
if lora_key.endswith(suffix_to_remove):
base = lora_key[: -len(suffix_to_remove)]
else:
return None
# For Qwen models, keep transformer_blocks prefix
# Check if this is a Qwen-style LoRA (transformer_blocks.NUMBER.)
if base.startswith("transformer_blocks.") and base.split(".")[1].isdigit():
# Keep the full path for Qwen models
model_key = base + suffix_to_add
else:
# Remove common prefixes for other models
prefixes_to_remove = ["diffusion_model.", "model.", "unet."]
for prefix in prefixes_to_remove:
if base.startswith(prefix):
base = base[len(prefix) :]
break
model_key = base + suffix_to_add
# Apply key mapping rules if provided (for converted models)
if key_mapping_rules:
for pattern, replacement in key_mapping_rules:
model_key = re.sub(pattern, replacement, model_key)
return model_key
# Collect all LoRA pairs and diffs
for key in lora_weights.keys():
if not key.startswith(prefix):
# Skip alpha parameters
if key.endswith(".alpha"):
continue
try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
try_lora_diff(key, "diff", "weight")
try_lora_diff(key, "diff_b", "bias")
try_lora_diff(key, "diff_m", "modulation")
# Pattern 1: .lora_down.weight / .lora_up.weight
if key.endswith(".lora_down.weight"):
base = key[: -len(".lora_down.weight")]
up_key = base + ".lora_up.weight"
if up_key in lora_weights:
model_key = get_model_key(key, ".lora_down.weight", ".weight")
if model_key:
lora_pairs[model_key] = (key, up_key)
# Pattern 2: .lora_A.weight / .lora_B.weight
elif key.endswith(".lora_A.weight"):
base = key[: -len(".lora_A.weight")]
b_key = base + ".lora_B.weight"
if b_key in lora_weights:
model_key = get_model_key(key, ".lora_A.weight", ".weight")
if model_key:
lora_pairs[model_key] = (key, b_key)
# Pattern 3: diff weights (direct addition)
elif key.endswith(".diff"):
model_key = get_model_key(key, ".diff", ".weight")
if model_key:
lora_diffs[model_key] = key
elif key.endswith(".diff_b"):
model_key = get_model_key(key, ".diff_b", ".bias")
if model_key:
lora_diffs[model_key] = key
elif key.endswith(".diff_m"):
model_key = get_model_key(key, ".diff_m", ".modulation")
if model_key:
lora_diffs[model_key] = key
applied_count = 0
unused_lora_keys = set()
# Apply LoRA weights by iterating through model weights
for name, param in weight_dict.items():
# Apply LoRA pairs (matmul pattern)
if name in lora_pairs:
name_lora_A, name_lora_B = lora_pairs[name]
lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
param += torch.matmul(lora_B, lora_A) * alpha
applied_count += 1
name_lora_down, name_lora_up = lora_pairs[name]
lora_down = lora_weights[name_lora_down].to(param.device, param.dtype)
lora_up = lora_weights[name_lora_up].to(param.device, param.dtype)
# Get LoRA-specific alpha if available, otherwise use global alpha
base_key = name_lora_down[: -len(".lora_down.weight")] if name_lora_down.endswith(".lora_down.weight") else name_lora_down[: -len(".lora_A.weight")]
lora_alpha = lora_alphas.get(base_key, alpha)
# Calculate rank from dimensions
rank = lora_down.shape[0] # rank is the output dimension of down projection
try:
# Standard LoRA formula: W' = W + (alpha/rank) * BA
# where B = up (rank x out_features), A = down (rank x in_features)
# Note: PyTorch linear layers store weight as (out_features, in_features)
if len(lora_down.shape) == 2 and len(lora_up.shape) == 2:
# For linear layers: down is (rank, in_features), up is (out_features, rank)
lora_delta = torch.mm(lora_up, lora_down) * (lora_alpha / rank)
else:
# Unexpected shapes: warn and skip this entry
logger.warning(f"Unexpected LoRA shape for {name}: down={lora_down.shape}, up={lora_up.shape}")
continue
param.data += lora_delta
applied_count += 1
logger.debug(f"Applied LoRA to {name} with alpha={lora_alpha}, rank={rank}")
except Exception as e:
logger.warning(f"Failed to apply LoRA pair for {name}: {e}")
logger.warning(f" Shapes - param: {param.shape}, down: {lora_down.shape}, up: {lora_up.shape}")
# Apply diff weights (direct addition)
elif name in lora_diffs:
name_diff = lora_diffs[name]
lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
try:
param += lora_diff * alpha
param.data += lora_diff * alpha
applied_count += 1
logger.debug(f"Applied LoRA diff to {name}")
except Exception as e:
continue
logger.info(f"Applied {applied_count} LoRA weight adjustments")
logger.warning(f"Failed to apply LoRA diff for {name}: {e}")
# Check for unused LoRA weights (potential key mismatch issues)
used_lora_keys = set()
for down_key, up_key in lora_pairs.values():
used_lora_keys.add(down_key)
used_lora_keys.add(up_key)
for diff_key in lora_diffs.values():
used_lora_keys.add(diff_key)
all_lora_keys = set(k for k in lora_weights.keys() if not k.endswith(".alpha"))
unused_lora_keys = all_lora_keys - used_lora_keys
if unused_lora_keys:
logger.warning(f"Found {len(unused_lora_keys)} unused LoRA weights - this may indicate key mismatch:")
for key in list(unused_lora_keys)[:10]: # Show first 10
logger.warning(f" Unused: {key}")
if len(unused_lora_keys) > 10:
logger.warning(f" ... and {len(unused_lora_keys) - 10} more")
logger.info(f"Applied {applied_count} LoRA weight adjustments out of {len(lora_pairs) + len(lora_diffs)} possible")
if applied_count == 0 and (lora_pairs or lora_diffs):
logger.error("No LoRA weights were applied! Check for key name mismatches.")
logger.info("Model weight keys sample: " + str(list(weight_dict.keys())[:5]))
logger.info("LoRA pairs keys sample: " + str(list(lora_pairs.keys())[:5]))
logger.info("LoRA diff keys sample: " + str(list(lora_diffs.keys())[:5]))
def convert_weights(args):
......@@ -473,6 +623,8 @@ def convert_weights(args):
merged_weights = {}
logger.info(f"Processing source files: {src_files}")
# Optimize loading for better memory usage
for file_path in tqdm(src_files, desc="Loading weights"):
logger.info(f"Loading weights from: {file_path}")
if file_path.endswith(".pt") or file_path.endswith(".pth"):
......@@ -480,47 +632,117 @@ def convert_weights(args):
if args.model_type == "hunyuan_dit":
weights = weights["module"]
elif file_path.endswith(".safetensors"):
# Use lazy loading for safetensors to reduce memory usage
with safe_open(file_path, framework="pt") as f:
weights = {k: f.get_tensor(k) for k in f.keys()}
# Load tensors one at a time (progress bar shown for large files)
weights = {}
keys = f.keys()
# For large files, show progress
if len(keys) > 100:
for k in tqdm(keys, desc=f"Loading {os.path.basename(file_path)}", leave=False):
weights[k] = f.get_tensor(k)
else:
weights = {k: f.get_tensor(k) for k in keys}
duplicate_keys = set(weights.keys()) & set(merged_weights.keys())
if duplicate_keys:
raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
merged_weights.update(weights)
if args.lora_path is not None:
# Handle alpha list - if single alpha, replicate for all LoRAs
if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
args.lora_alpha = args.lora_alpha * len(args.lora_path)
elif len(args.lora_alpha) != len(args.lora_path):
raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
# Update weights more efficiently
merged_weights.update(weights)
for path, alpha in zip(args.lora_path, args.lora_alpha):
load_loras(path, merged_weights, alpha)
# Clear weights dict to free memory
del weights
if len(src_files) > 1:
gc.collect() # Force garbage collection between files
if args.direction is not None:
rules = get_key_mapping_rules(args.direction, args.model_type)
converted_weights = {}
logger.info("Converting keys...")
for key in tqdm(merged_weights.keys(), desc="Converting keys"):
# Pre-compile regex patterns for better performance
compiled_rules = [(re.compile(pattern), replacement) for pattern, replacement in rules]
def convert_key(key):
"""Convert a single key using compiled rules"""
new_key = key
for pattern, replacement in rules:
new_key = re.sub(pattern, replacement, new_key)
converted_weights[new_key] = merged_weights[key]
for pattern, replacement in compiled_rules:
new_key = pattern.sub(replacement, new_key)
return new_key
# Materialize the key list once before converting
keys_list = list(merged_weights.keys())
# Use parallel processing for large models
if len(keys_list) > 1000 and args.parallel:
logger.info(f"Using parallel processing for {len(keys_list)} keys")
# Use ThreadPoolExecutor to run the regex substitutions concurrently
num_workers = min(8, multiprocessing.cpu_count())
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all conversion tasks
future_to_key = {executor.submit(convert_key, key): key for key in keys_list}
# Process results as they complete with progress bar
for future in tqdm(as_completed(future_to_key), total=len(keys_list), desc="Converting keys (parallel)"):
original_key = future_to_key[future]
new_key = future.result()
converted_weights[new_key] = merged_weights[original_key]
else:
# For smaller models, use simple loop with less overhead
for key in tqdm(keys_list, desc="Converting keys"):
new_key = convert_key(key)
converted_weights[new_key] = merged_weights[key]
else:
converted_weights = merged_weights
# Apply LoRA AFTER key conversion to ensure proper key matching
if args.lora_path is not None:
# Handle alpha list - if single alpha, replicate for all LoRAs
if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
args.lora_alpha = args.lora_alpha * len(args.lora_path)
elif len(args.lora_alpha) != len(args.lora_path):
raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
# Determine if we should apply key mapping rules to LoRA keys
key_mapping_rules = None
if args.lora_key_convert == "convert" and args.direction is not None:
# Apply same conversion as model
key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type)
logger.info("Applying key conversion to LoRA weights")
elif args.lora_key_convert == "same":
# Don't convert LoRA keys
logger.info("Using original LoRA keys without conversion")
else: # auto
# Auto-detect: if model was converted, try with conversion first
if args.direction is not None:
key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type)
logger.info("Auto mode: will try with key conversion first")
for path, alpha in zip(args.lora_path, args.lora_alpha):
# Pass key mapping rules to handle converted keys properly
load_loras(path, converted_weights, alpha, key_mapping_rules)
if args.quantized:
converted_weights = quantize_model(
converted_weights,
w_bit=args.bits,
target_keys=args.target_keys,
adapter_keys=args.adapter_keys,
key_idx=args.key_idx,
ignore_key=args.ignore_key,
linear_dtype=args.linear_dtype,
non_linear_dtype=args.non_linear_dtype,
)
if args.full_quantized and args.comfyui_mode:
logger.info("Quant all tensors...")
for k in converted_weights.keys():
converted_weights[k] = converted_weights[k].float().to(args.linear_dtype)
else:
converted_weights = quantize_model(
converted_weights,
w_bit=args.bits,
target_keys=args.target_keys,
adapter_keys=args.adapter_keys,
key_idx=args.key_idx,
ignore_key=args.ignore_key,
linear_dtype=args.linear_dtype,
non_linear_dtype=args.non_linear_dtype,
comfyui_mode=args.comfyui_mode,
comfyui_keys=args.comfyui_keys,
)
os.makedirs(args.output, exist_ok=True)
......@@ -529,8 +751,33 @@ def convert_weights(args):
else:
index = {"metadata": {"total_size": 0}, "weight_map": {}}
if args.single_file:
output_filename = f"{args.output_name}.safetensors"
output_path = os.path.join(args.output, output_filename)
logger.info(f"Saving model to single file: {output_path}")
if args.save_by_block:
# For memory efficiency with large models
try:
# Warn before attempting a single-file save of a very large model
total_size = sum(tensor.numel() * tensor.element_size() for tensor in converted_weights.values())
total_size_gb = total_size / (1024**3)
if total_size_gb > 10: # Warn if model is larger than 10GB
logger.warning(f"Model size is {total_size_gb:.2f}GB. This will require significant memory to save as a single file.")
logger.warning("Consider using --save_by_block or default chunked saving for better memory efficiency.")
# Save the entire model as a single file
st.save_file(converted_weights, output_path)
logger.info(f"Model saved successfully to: {output_path} ({total_size_gb:.2f}GB)")
except MemoryError:
logger.error("Memory error while saving. The model is too large to save as a single file.")
logger.error("Please use --save_by_block or remove --single_file to use chunked saving.")
raise
except Exception as e:
logger.error(f"Error saving model: {e}")
raise
elif args.save_by_block:
logger.info("Backward conversion: grouping weights by block")
block_groups = defaultdict(dict)
non_block_weights = {}
......@@ -649,6 +896,8 @@ def main():
parser.add_argument("-b", "--save_by_block", action="store_true")
# Quantization
parser.add_argument("--comfyui_mode", action="store_true")
parser.add_argument("--full_quantized", action="store_true")
parser.add_argument("--quantized", action="store_true")
parser.add_argument("--bits", type=int, default=8, choices=[8], help="Quantization bit width")
parser.add_argument(
......@@ -679,8 +928,24 @@ def main():
help="Alpha for LoRA weight scaling",
)
parser.add_argument("--copy_no_weight_files", action="store_true")
parser.add_argument("--single_file", action="store_true", help="Save as a single safetensors file instead of chunking (warning: requires loading entire model in memory)")
parser.add_argument(
"--lora_key_convert",
choices=["auto", "same", "convert"],
default="auto",
help="How to handle LoRA key conversion: 'auto' (detect from LoRA), 'same' (use original keys), 'convert' (apply same conversion as model)",
)
parser.add_argument("--parallel", action="store_true", default=True, help="Use parallel processing for faster conversion (default: True)")
parser.add_argument("--no-parallel", dest="parallel", action="store_false", help="Disable parallel processing")
args = parser.parse_args()
# Validate conflicting arguments
if args.single_file and args.save_by_block:
parser.error("--single_file and --save_by_block cannot be used together. Choose one saving strategy.")
if args.single_file and args.chunk_size > 0 and args.chunk_size != 100:
logger.warning("--chunk_size is ignored when using --single_file option.")
if args.quantized:
args.linear_dtype = eval(args.linear_dtype)
args.non_linear_dtype = eval(args.non_linear_dtype)
......@@ -688,8 +953,16 @@ def main():
model_type_keys_map = {
"qwen_image_dit": {
"key_idx": 2,
"target_keys": ["attn", "img_mlp", "txt_mlp"],
"target_keys": ["attn", "img_mlp", "txt_mlp", "txt_mod", "img_mod"],
"ignore_key": None,
"comfyui_keys": [
"time_text_embed.timestep_embedder.linear_1.weight",
"time_text_embed.timestep_embedder.linear_2.weight",
"img_in.weight",
"txt_in.weight",
"norm_out.linear.weight",
"proj_out.weight",
],
},
"wan_dit": {
"key_idx": 2,
......@@ -726,6 +999,7 @@ def main():
args.adapter_keys = model_type_keys_map[args.model_type]["adapter_keys"] if "adapter_keys" in model_type_keys_map[args.model_type] else None
args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
args.comfyui_keys = model_type_keys_map[args.model_type]["comfyui_keys"] if "comfyui_keys" in model_type_keys_map[args.model_type] else None
if os.path.isfile(args.output):
raise ValueError("Output path must be a directory, not a file")
......
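For orientation, a hypothetical invocation combining the options added in this commit. Only --comfyui_mode, --full_quantized, --quantized, --bits, --single_file, --lora_key_convert, --parallel, --save_by_block and --copy_no_weight_files appear as parser arguments in this diff; the script name and the --model_type, --output and --output_name flags are inferred from the surrounding code, and the input-path arguments are omitted because they are not visible here.

python converter.py --model_type wan_dit --quantized --bits 8 --comfyui_mode \
    --lora_key_convert auto --single_file --output ./converted --output_name wan_dit_fp8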