Unverified Commit bf19c132 authored by gushiqiao, committed by GitHub

update cvt and lora loader (#391)


Co-authored-by: gaclove <peng.gaoc@gmail.com>
parent 3d5b147b
import os
import queue
import signal
import socket
import subprocess
import threading
......@@ -368,7 +367,6 @@ class VARecorder:
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
......@@ -378,32 +376,91 @@ class VARecorder:
if self.video_thread.is_alive():
logger.warning("Video push thread did not stop gracefully")
# Close TCP connections, sockets
if self.audio_conn:
self.audio_conn.close()
try:
self.audio_conn.getpeername()
self.audio_conn.shutdown(socket.SHUT_WR)
logger.info("Audio connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.video_conn:
self.video_conn.close()
try:
self.video_conn.getpeername()
self.video_conn.shutdown(socket.SHUT_WR)
logger.info("Video connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.ffmpeg_process:
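# Local-file outputs get a longer grace period than live streams so FFmpeg can finish writing and finalize the container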
is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
timeout_seconds = 15 if is_local_file else 10
logger.info(f"Waiting for FFmpeg to finalize (timeout={timeout_seconds}s, local_file={is_local_file})")
try:
self.ffmpeg_process.wait(timeout=timeout_seconds)
logger.info("FFmpeg process exited gracefully")
except subprocess.TimeoutExpired:
logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
try:
self.ffmpeg_process.terminate() # SIGTERM
self.ffmpeg_process.wait(timeout=3)
logger.warning("FFmpeg process terminated with SIGTERM")
except subprocess.TimeoutExpired:
logger.error("FFmpeg process still running, killing with SIGKILL...")
self.ffmpeg_process.kill()
finally:
self.ffmpeg_process = None
if self.audio_conn:
try:
self.audio_conn.close()
except Exception as e:
logger.debug(f"Error closing audio connection: {e}")
finally:
self.audio_conn = None
if self.video_conn:
try:
self.video_conn.close()
except Exception as e:
logger.debug(f"Error closing video connection: {e}")
finally:
self.video_conn = None
if self.audio_socket:
self.audio_socket.close()
try:
self.audio_socket.close()
except Exception as e:
logger.debug(f"Error closing audio socket: {e}")
finally:
self.audio_socket = None
if self.video_socket:
self.video_socket.close()
try:
self.video_socket.close()
except Exception as e:
logger.debug(f"Error closing video socket: {e}")
finally:
self.video_socket = None
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
while self.video_queue and self.video_queue.qsize() > 0:
self.video_queue.get_nowait()
if self.audio_queue:
while self.audio_queue.qsize() > 0:
try:
self.audio_queue.get_nowait()
except: # noqa
break
if self.video_queue:
while self.video_queue.qsize() > 0:
try:
self.video_queue.get_nowait()
except: # noqa
break
self.audio_queue = None
self.video_queue = None
logger.warning("Cleaned audio and video queues")
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg recorder process stopped")
logger.info("VARecorder stopped and resources cleaned up")
def __del__(self):
self.stop(wait=False)
......
......@@ -364,6 +364,8 @@ class WanAudioRunner(WanRunner): # type:ignore
monitor_cli.lightx2v_input_audio_len.observe(audio_len)
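# Clamp to at least one frame and to no more than the available audio frames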
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
if expected_frames < int(video_duration * target_fps):
logger.warning(f"Input video duration is greater than actual audio duration, using audio duration instead: audio_duration={audio_len / target_fps}, video_duration={video_duration}")
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
......
......@@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
from loguru import logger
from qtorch.quant import float_quantize
from lora_loader import LoRALoader
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm
......@@ -330,6 +330,8 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
scales = max_val / qmax
if dtype == torch.float8_e4m3fn:
from qtorch.quant import float_quantize
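# float8_e4m3fn has 4 exponent bits and 3 mantissa bits; scale into range, clip, then round to the nearest representable value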
scaled_tensor = w / scales
scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(dtype)
......@@ -452,165 +454,33 @@ def quantize_model(
return weights
def load_loras(lora_path, weight_dict, alpha, key_mapping_rules=None):
logger.info(f"Loading LoRA from: {lora_path} with alpha={alpha}")
with safe_open(lora_path, framework="pt") as f:
lora_weights = {k: f.get_tensor(k) for k in f.keys()}
lora_pairs = {}
lora_diffs = {}
lora_alphas = {} # Store LoRA-specific alpha values
# Extract LoRA alpha values if present
for key in lora_weights.keys():
if key.endswith(".alpha"):
base_key = key[:-6] # Remove .alpha
lora_alphas[base_key] = lora_weights[key].item()
# Handle different prefixes: "diffusion_model." or "transformer_blocks." or no prefix
def get_model_key(lora_key, suffix_to_remove, suffix_to_add):
"""Extract the model weight key from LoRA key"""
# Remove the LoRA-specific suffix
if lora_key.endswith(suffix_to_remove):
base = lora_key[: -len(suffix_to_remove)]
else:
return None
# For Qwen models, keep transformer_blocks prefix
# Check if this is a Qwen-style LoRA (transformer_blocks.NUMBER.)
if base.startswith("transformer_blocks.") and base.split(".")[1].isdigit():
# Keep the full path for Qwen models
model_key = base + suffix_to_add
else:
# Remove common prefixes for other models
prefixes_to_remove = ["diffusion_model.", "model.", "unet."]
for prefix in prefixes_to_remove:
if base.startswith(prefix):
base = base[len(prefix) :]
break
model_key = base + suffix_to_add
# Apply key mapping rules if provided (for converted models)
if key_mapping_rules:
for pattern, replacement in key_mapping_rules:
model_key = re.sub(pattern, replacement, model_key)
return model_key
# Collect all LoRA pairs and diffs
for key in lora_weights.keys():
# Skip alpha parameters
if key.endswith(".alpha"):
continue
# Pattern 1: .lora_down.weight / .lora_up.weight
if key.endswith(".lora_down.weight"):
base = key[: -len(".lora_down.weight")]
up_key = base + ".lora_up.weight"
if up_key in lora_weights:
model_key = get_model_key(key, ".lora_down.weight", ".weight")
if model_key:
lora_pairs[model_key] = (key, up_key)
# Pattern 2: .lora_A.weight / .lora_B.weight
elif key.endswith(".lora_A.weight"):
base = key[: -len(".lora_A.weight")]
b_key = base + ".lora_B.weight"
if b_key in lora_weights:
model_key = get_model_key(key, ".lora_A.weight", ".weight")
if model_key:
lora_pairs[model_key] = (key, b_key)
# Pattern 3: diff weights (direct addition)
elif key.endswith(".diff"):
model_key = get_model_key(key, ".diff", ".weight")
if model_key:
lora_diffs[model_key] = key
elif key.endswith(".diff_b"):
model_key = get_model_key(key, ".diff_b", ".bias")
if model_key:
lora_diffs[model_key] = key
elif key.endswith(".diff_m"):
model_key = get_model_key(key, ".diff_m", ".modulation")
if model_key:
lora_diffs[model_key] = key
applied_count = 0
unused_lora_keys = set()
# Apply LoRA weights by iterating through model weights
for name, param in weight_dict.items():
# Apply LoRA pairs (matmul pattern)
if name in lora_pairs:
name_lora_down, name_lora_up = lora_pairs[name]
lora_down = lora_weights[name_lora_down].to(param.device, param.dtype)
lora_up = lora_weights[name_lora_up].to(param.device, param.dtype)
# Get LoRA-specific alpha if available, otherwise use global alpha
base_key = name_lora_down[: -len(".lora_down.weight")] if name_lora_down.endswith(".lora_down.weight") else name_lora_down[: -len(".lora_A.weight")]
lora_alpha = lora_alphas.get(base_key, alpha)
# Calculate rank from dimensions
rank = lora_down.shape[0] # rank is the output dimension of down projection
try:
# Standard LoRA formula: W' = W + (alpha/rank) * BA
# where B = up (rank x out_features), A = down (rank x in_features)
# Note: PyTorch linear layers store weight as (out_features, in_features)
if len(lora_down.shape) == 2 and len(lora_up.shape) == 2:
# For linear layers: down is (rank, in_features), up is (out_features, rank)
lora_delta = torch.mm(lora_up, lora_down) * (lora_alpha / rank)
else:
# For other shapes, try element-wise multiplication or skip
logger.warning(f"Unexpected LoRA shape for {name}: down={lora_down.shape}, up={lora_up.shape}")
continue
param.data += lora_delta
applied_count += 1
logger.debug(f"Applied LoRA to {name} with alpha={lora_alpha}, rank={rank}")
except Exception as e:
logger.warning(f"Failed to apply LoRA pair for {name}: {e}")
logger.warning(f" Shapes - param: {param.shape}, down: {lora_down.shape}, up: {lora_up.shape}")
# Apply diff weights (direct addition)
elif name in lora_diffs:
name_diff = lora_diffs[name]
lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
try:
param.data += lora_diff * alpha
applied_count += 1
logger.debug(f"Applied LoRA diff to {name}")
except Exception as e:
logger.warning(f"Failed to apply LoRA diff for {name}: {e}")
# Check for unused LoRA weights (potential key mismatch issues)
used_lora_keys = set()
for down_key, up_key in lora_pairs.values():
used_lora_keys.add(down_key)
used_lora_keys.add(up_key)
for diff_key in lora_diffs.values():
used_lora_keys.add(diff_key)
def load_loras(lora_path, weight_dict, alpha, key_mapping_rules=None, strength=1.0):
"""
Load and apply LoRA weights to model weights using the LoRALoader class.
all_lora_keys = set(k for k in lora_weights.keys() if not k.endswith(".alpha"))
unused_lora_keys = all_lora_keys - used_lora_keys
Args:
lora_path: Path to LoRA safetensors file
weight_dict: Model weights dictionary (will be modified in place)
alpha: Global alpha scaling factor
key_mapping_rules: Optional list of (pattern, replacement) regex rules for key mapping
strength: Additional strength factor for LoRA deltas
"""
logger.info(f"Loading LoRA from: {lora_path} with alpha={alpha}, strength={strength}")
if unused_lora_keys:
logger.warning(f"Found {len(unused_lora_keys)} unused LoRA weights - this may indicate key mismatch:")
for key in list(unused_lora_keys)[:10]: # Show first 10
logger.warning(f" Unused: {key}")
if len(unused_lora_keys) > 10:
logger.warning(f" ... and {len(unused_lora_keys) - 10} more")
# Load LoRA weights from safetensors file
with safe_open(lora_path, framework="pt") as f:
lora_weights = {k: f.get_tensor(k) for k in f.keys()}
logger.info(f"Applied {applied_count} LoRA weight adjustments out of {len(lora_pairs) + len(lora_diffs)} possible")
# Create LoRA loader with key mapping rules
lora_loader = LoRALoader(key_mapping_rules=key_mapping_rules)
if applied_count == 0 and (lora_pairs or lora_diffs):
logger.error("No LoRA weights were applied! Check for key name mismatches.")
logger.info("Model weight keys sample: " + str(list(weight_dict.keys())[:5]))
logger.info("LoRA pairs keys sample: " + str(list(lora_pairs.keys())[:5]))
logger.info("LoRA diff keys sample: " + str(list(lora_diffs.keys())[:5]))
# Apply LoRA weights to model
lora_loader.apply_lora(
weight_dict=weight_dict,
lora_weights=lora_weights,
alpha=alpha,
strength=strength,
)
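# Note: load_loras is now a thin wrapper; format detection and weight application live in LoRALoader (lora_loader.py)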
def convert_weights(args):
......@@ -701,10 +571,18 @@ def convert_weights(args):
# Apply LoRA AFTER key conversion to ensure proper key matching
if args.lora_path is not None:
# Handle alpha list - if single alpha, replicate for all LoRAs
if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
args.lora_alpha = args.lora_alpha * len(args.lora_path)
elif len(args.lora_alpha) != len(args.lora_path):
raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
if args.lora_alpha is not None:
if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
args.lora_alpha = args.lora_alpha * len(args.lora_path)
elif len(args.lora_alpha) != len(args.lora_path):
raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
# Normalize strength list
if args.lora_strength is not None:
if len(args.lora_strength) == 1 and len(args.lora_path) > 1:
args.lora_strength = args.lora_strength * len(args.lora_path)
elif len(args.lora_strength) != len(args.lora_path):
raise ValueError(f"Number of strength ({len(args.lora_strength)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
# Determine if we should apply key mapping rules to LoRA keys
key_mapping_rules = None
......@@ -721,9 +599,11 @@ def convert_weights(args):
key_mapping_rules = get_key_mapping_rules(args.direction, args.model_type)
logger.info("Auto mode: will try with key conversion first")
for path, alpha in zip(args.lora_path, args.lora_alpha):
for idx, path in enumerate(args.lora_path):
# Pass key mapping rules to handle converted keys properly
load_loras(path, converted_weights, alpha, key_mapping_rules)
strength = args.lora_strength[idx] if args.lora_strength is not None else 1.0
alpha = args.lora_alpha[idx] if args.lora_alpha is not None else None
load_loras(path, converted_weights, alpha, key_mapping_rules, strength=strength)
if args.quantized:
if args.full_quantized and args.comfyui_mode:
......@@ -924,8 +804,14 @@ def main():
"--lora_alpha",
type=float,
nargs="*",
default=[1.0],
help="Alpha for LoRA weight scaling",
default=None,
help="Alpha for LoRA weight scaling, Default non scaling. ",
)
parser.add_argument(
"--lora_strength",
type=float,
nargs="*",
help="Additional strength factor(s) for LoRA deltas; default 1.0",
)
parser.add_argument("--copy_no_weight_files", action="store_true")
parser.add_argument("--single_file", action="store_true", help="Save as a single safetensors file instead of chunking (warning: requires loading entire model in memory)")
......
"""
LoRA (Low-Rank Adaptation) loader with support for multiple format patterns.
Supported formats:
- Standard: {key}.lora_up.weight and {key}.lora_down.weight
- Diffusers: {key}_lora.up.weight and {key}_lora.down.weight
- Diffusers v2: {key}.lora_B.weight and {key}.lora_A.weight (B=up, A=down)
- Diffusers v3: {key}.lora.up.weight and {key}.lora.down.weight
- Mochi: {key}.lora_B and {key}.lora_A (no .weight suffix)
- Transformers: {key}.lora_linear_layer.up.weight and {key}.lora_linear_layer.down.weight
- Qwen: {key}.lora_B.default.weight and {key}.lora_A.default.weight
"""
import re
from enum import Enum
from typing import Dict, List, Optional, Tuple
import torch
from loguru import logger
class LoRAFormat(Enum):
"""Enum for different LoRA format patterns."""
STANDARD = "standard"
DIFFUSERS = "diffusers"
DIFFUSERS_V2 = "diffusers_v2"
DIFFUSERS_V3 = "diffusers_v3"
MOCHI = "mochi"
TRANSFORMERS = "transformers"
QWEN = "qwen"
class LoRAPatternDefinition:
"""Defines a single LoRA format pattern and how to extract its components."""
def __init__(
self,
format_name: LoRAFormat,
up_suffix: str,
down_suffix: str,
has_weight_suffix: bool = True,
mid_suffix: Optional[str] = None,
):
"""
Args:
format_name: The LoRA format type
up_suffix: Suffix for the up (B) weight matrix (e.g., ".lora_up.weight")
down_suffix: Suffix for the down (A) weight matrix (e.g., ".lora_down.weight")
has_weight_suffix: Whether the format includes .weight suffix
mid_suffix: Optional suffix for mid weight (only used in standard format)
"""
self.format_name = format_name
self.up_suffix = up_suffix
self.down_suffix = down_suffix
self.has_weight_suffix = has_weight_suffix
self.mid_suffix = mid_suffix
def get_base_key(self, key: str, detected_suffix: str) -> Optional[str]:
"""Extract base key by removing the detected suffix."""
if key.endswith(detected_suffix):
return key[: -len(detected_suffix)]
return None
class LoRAPatternMatcher:
"""Detects and matches LoRA format patterns in state dicts."""
def __init__(self):
"""Initialize the pattern matcher with all supported formats."""
self.patterns: Dict[LoRAFormat, LoRAPatternDefinition] = {
LoRAFormat.STANDARD: LoRAPatternDefinition(
LoRAFormat.STANDARD,
up_suffix=".lora_up.weight",
down_suffix=".lora_down.weight",
mid_suffix=".lora_mid.weight",
),
LoRAFormat.DIFFUSERS: LoRAPatternDefinition(
LoRAFormat.DIFFUSERS,
up_suffix="_lora.up.weight",
down_suffix="_lora.down.weight",
),
LoRAFormat.DIFFUSERS_V2: LoRAPatternDefinition(
LoRAFormat.DIFFUSERS_V2,
up_suffix=".lora_B.weight",
down_suffix=".lora_A.weight",
),
LoRAFormat.DIFFUSERS_V3: LoRAPatternDefinition(
LoRAFormat.DIFFUSERS_V3,
up_suffix=".lora.up.weight",
down_suffix=".lora.down.weight",
),
LoRAFormat.MOCHI: LoRAPatternDefinition(
LoRAFormat.MOCHI,
up_suffix=".lora_B",
down_suffix=".lora_A",
has_weight_suffix=False,
),
LoRAFormat.TRANSFORMERS: LoRAPatternDefinition(
LoRAFormat.TRANSFORMERS,
up_suffix=".lora_linear_layer.up.weight",
down_suffix=".lora_linear_layer.down.weight",
),
LoRAFormat.QWEN: LoRAPatternDefinition(
LoRAFormat.QWEN,
up_suffix=".lora_B.default.weight",
down_suffix=".lora_A.default.weight",
),
}
def detect_format(self, key: str, lora_weights: Dict) -> Optional[Tuple[LoRAFormat, str]]:
"""
Detect the LoRA format of a given key.
Args:
key: The weight key to check
lora_weights: The full LoRA weights dictionary
Returns:
Tuple of (LoRAFormat, detected_suffix) if format detected, None otherwise
"""
for format_type, pattern in self.patterns.items():
if key.endswith(pattern.up_suffix):
return (format_type, pattern.up_suffix)
return None
def extract_lora_pair(
self,
key: str,
lora_weights: Dict,
lora_alphas: Dict,
) -> Optional[Dict]:
"""
Extract a complete LoRA pair (up and down weights) from the state dict.
Args:
key: The up weight key
lora_weights: The full LoRA weights dictionary
lora_alphas: Dictionary of alpha values by base key
Returns:
Dictionary with extracted LoRA information, or None if pair is incomplete
"""
format_detected = self.detect_format(key, lora_weights)
if format_detected is None:
return None
format_type, up_suffix = format_detected
pattern = self.patterns[format_type]
# Extract base key
base_key = pattern.get_base_key(key, up_suffix)
if base_key is None:
return None
# Check if down weight exists
down_key = base_key + pattern.down_suffix
if down_key not in lora_weights:
return None
# Check for mid weight (only for standard format)
mid_key = None
if pattern.mid_suffix:
mid_key = base_key + pattern.mid_suffix
if mid_key not in lora_weights:
mid_key = None
# Get alpha value
alpha = lora_alphas.get(base_key, None)
return {
"format": format_type,
"base_key": base_key,
"up_key": key,
"down_key": down_key,
"mid_key": mid_key,
"alpha": alpha,
}
class LoRALoader:
"""Loads and applies LoRA weights to model weights using pattern matching."""
def __init__(self, key_mapping_rules: Optional[List[Tuple[str, str]]] = None):
"""
Args:
key_mapping_rules: Optional list of (pattern, replacement) regex rules for key mapping
"""
self.pattern_matcher = LoRAPatternMatcher()
self.key_mapping_rules = key_mapping_rules or []
self._compile_rules()
def _compile_rules(self):
"""Pre-compile regex patterns for better performance."""
self.compiled_rules = [(re.compile(pattern), replacement) for pattern, replacement in self.key_mapping_rules]
def _apply_key_mapping(self, key: str) -> str:
"""Apply key mapping rules to a key."""
for pattern, replacement in self.compiled_rules:
key = pattern.sub(replacement, key)
return key
def _get_model_key(
self,
lora_key: str,
base_key: str,
suffix_to_remove: str,
suffix_to_add: str = ".weight",
) -> Optional[str]:
"""
Extract the model weight key from LoRA key with proper prefix handling.
Args:
lora_key: The original LoRA key
base_key: The base key after removing LoRA suffix
suffix_to_remove: The suffix that was removed
suffix_to_add: The suffix to add for model key
Returns:
The model key, or None if extraction fails
"""
# For Qwen models, keep transformer_blocks prefix
if base_key.startswith("transformer_blocks.") and len(base_key.split(".")) > 1:
if base_key.split(".")[1].isdigit():
# Keep the full path for Qwen models
model_key = base_key + suffix_to_add
else:
# Remove common prefixes for other models
model_key = self._remove_prefixes(base_key) + suffix_to_add
else:
# Remove common prefixes for other models
model_key = self._remove_prefixes(base_key) + suffix_to_add
# Apply key mapping rules if provided
if self.compiled_rules:
model_key = self._apply_key_mapping(model_key)
return model_key
@staticmethod
def _remove_prefixes(key: str) -> str:
"""Remove common model prefixes from a key."""
prefixes_to_remove = ["diffusion_model.", "model.", "unet."]
for prefix in prefixes_to_remove:
if key.startswith(prefix):
return key[len(prefix) :]
return key
def extract_lora_alphas(self, lora_weights: Dict) -> Dict:
"""Extract LoRA alpha values from the state dict."""
lora_alphas = {}
for key in lora_weights.keys():
if key.endswith(".alpha"):
base_key = key[:-6] # Remove .alpha
lora_alphas[base_key] = lora_weights[key].item()
return lora_alphas
def extract_lora_pairs(self, lora_weights: Dict) -> Dict[str, Dict]:
"""
Extract all LoRA pairs from the state dict, mapping to model keys.
Args:
lora_weights: The LoRA state dictionary
Returns:
Dictionary mapping model keys to LoRA pair information
"""
lora_alphas = self.extract_lora_alphas(lora_weights)
lora_pairs = {}
for key in lora_weights.keys():
# Skip alpha parameters
if key.endswith(".alpha"):
continue
# Try to extract LoRA pair
pair_info = self.pattern_matcher.extract_lora_pair(key, lora_weights, lora_alphas)
if pair_info is None:
continue
# Determine the suffix to remove and add based on format
format_type = pair_info["format"]
pattern = self.pattern_matcher.patterns[format_type]
# Get the model key
model_key = self._get_model_key(
pair_info["up_key"],
pair_info["base_key"],
pattern.up_suffix,
".weight",
)
if model_key is None:
logger.warning(f"Failed to extract model key from LoRA key: {key}")
continue
lora_pairs[model_key] = pair_info
return lora_pairs
def extract_lora_diffs(self, lora_weights: Dict) -> Dict[str, Dict]:
"""
Extract diff-style LoRA weights (direct addition, not matrix multiplication).
Args:
lora_weights: The LoRA state dictionary
Returns:
Dictionary mapping model keys to diff information
"""
lora_diffs = {}
# Define diff patterns: (suffix_to_check, suffix_to_remove, suffix_to_add)
diff_patterns = [
(".diff", ".diff", ".weight"),
(".diff_b", ".diff_b", ".bias"),
(".diff_m", ".diff_m", ".modulation"),
]
for key in lora_weights.keys():
for check_suffix, remove_suffix, add_suffix in diff_patterns:
if key.endswith(check_suffix):
base_key = key[: -len(remove_suffix)]
model_key = self._get_model_key(key, base_key, remove_suffix, add_suffix)
if model_key:
lora_diffs[model_key] = {
"diff_key": key,
"type": check_suffix,
}
break
return lora_diffs
def apply_lora(
self,
weight_dict: Dict[str, torch.Tensor],
lora_weights: Dict[str, torch.Tensor],
alpha: float = None,
strength: float = 1.0,
) -> int:
"""
Apply LoRA weights to model weights.
Args:
weight_dict: The model weights dictionary (will be modified in place)
lora_weights: The LoRA weights dictionary
alpha: Global alpha scaling factor
strength: Additional strength factor for LoRA deltas
Returns:
Number of LoRA weights successfully applied
"""
# Extract LoRA pairs, diffs, and alphas
lora_pairs = self.extract_lora_pairs(lora_weights)
lora_diffs = self.extract_lora_diffs(lora_weights)
applied_count = 0
used_lora_keys = set()
# Apply LoRA pairs (matrix multiplication)
for model_key, pair_info in lora_pairs.items():
if model_key not in weight_dict:
logger.debug(f"Model key not found: {model_key}")
continue
param = weight_dict[model_key]
up_key = pair_info["up_key"]
down_key = pair_info["down_key"]
# Track used keys
used_lora_keys.add(up_key)
used_lora_keys.add(down_key)
if pair_info["mid_key"]:
used_lora_keys.add(pair_info["mid_key"])
try:
lora_up = lora_weights[up_key].to(param.device, param.dtype)
lora_down = lora_weights[down_key].to(param.device, param.dtype)
# Get LoRA-specific alpha if available, otherwise use global alpha
# Apply LoRA: W' = W + (alpha/rank) * B @ A
# where B = up (out_features, rank), A = down (rank, in_features)
if pair_info["alpha"]:
lora_scale = pair_info["alpha"] / lora_down.shape[0]
elif alpha is not None:
lora_scale = alpha / lora_down.shape[0]
else:
lora_scale = 1
if len(lora_down.shape) == 2 and len(lora_up.shape) == 2:
lora_delta = torch.mm(lora_up, lora_down) * lora_scale
if strength is not None:
lora_delta = lora_delta * float(strength)
param.data += lora_delta
applied_count += 1
logger.debug(f"Applied LoRA to {model_key} with lora_scale={lora_scale}")
else:
logger.warning(f"Unexpected LoRA shape for {model_key}: down={lora_down.shape}, up={lora_up.shape}")
except Exception as e:
logger.warning(f"Failed to apply LoRA pair for {model_key}: {e}")
logger.warning(f" Shapes - param: {param.shape}, down: {lora_weights[down_key].shape}, up: {lora_weights[up_key].shape}")
# Apply diff weights (direct addition)
for model_key, diff_info in lora_diffs.items():
if model_key not in weight_dict:
logger.debug(f"Model key not found for diff: {model_key}")
continue
param = weight_dict[model_key]
diff_key = diff_info["diff_key"]
# Track used keys
used_lora_keys.add(diff_key)
try:
lora_diff = lora_weights[diff_key].to(param.device, param.dtype)
if alpha is not None:
param.data += lora_diff * alpha * (float(strength) if strength is not None else 1.0)
else:
param.data += lora_diff * (float(strength) if strength is not None else 1.0)
applied_count += 1
logger.debug(f"Applied LoRA diff to {model_key} (type: {diff_info['type']})")
except Exception as e:
logger.warning(f"Failed to apply LoRA diff for {model_key}: {e}")
# Warn about unused keys
all_lora_keys = set(k for k in lora_weights.keys() if not k.endswith(".alpha"))
unused_lora_keys = all_lora_keys - used_lora_keys
if unused_lora_keys:
logger.warning(f"Found {len(unused_lora_keys)} unused LoRA weights - this may indicate key mismatch:")
for key in list(unused_lora_keys)[:10]: # Show first 10
logger.warning(f" Unused: {key}")
if len(unused_lora_keys) > 10:
logger.warning(f" ... and {len(unused_lora_keys) - 10} more")
logger.info(f"Applied {applied_count} LoRA weight adjustments out of {len(lora_pairs) + len(lora_diffs)} possible")
if applied_count == 0 and (lora_pairs or lora_diffs):
logger.error("No LoRA weights were applied! Check for key name mismatches.")
logger.info("Model weight keys sample: " + str(list(weight_dict.keys())[:5]))
logger.info("LoRA pairs keys sample: " + str(list(lora_pairs.keys())[:5]))
logger.info("LoRA diffs keys sample: " + str(list(lora_diffs.keys())[:5]))
return applied_count
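# Example usage (illustrative; variable names are placeholders):
#   loader = LoRALoader(key_mapping_rules=None)
#   applied = loader.apply_lora(weight_dict=model_state_dict, lora_weights=lora_state_dict, alpha=None, strength=1.0)
#   `applied` is the number of model weights that received a LoRA update.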
import argparse
import sys
from pathlib import Path
import safetensors
import torch
from safetensors.torch import save_file
sys.path.append(str(Path(__file__).parent.parent.parent))
from lightx2v.utils.quant_utils import FloatQuantizer
model_path = "/data/nvme0/gushiqiao/models/Lightx2v_models/SekoTalk-Distill/audio_adapter_model.safetensors"
state_dict = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
def main():
# Get the directory containing this script
script_dir = Path(__file__).parent
project_root = script_dir.parent.parent
parser = argparse.ArgumentParser(description="Quantize audio adapter model to FP8")
parser.add_argument(
"--model_path",
type=str,
default=str(project_root / "models" / "SekoTalk-Distill" / "audio_adapter_model.safetensors"),
help="Path to input model file",
)
parser.add_argument(
"--output_path",
type=str,
default=str(project_root / "models" / "SekoTalk-Distill-fp8" / "audio_adapter_model_fp8.safetensors"),
help="Path to output quantized model file",
)
args = parser.parse_args()
model_path = Path(args.model_path)
output_path = Path(args.output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
state_dict = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
new_state_dict = {}
new_model_path = "/data/nvme0/gushiqiao/models/Lightx2v_models/seko-new/SekoTalk-Distill-fp8/audio_adapter_model_fp8.safetensors"
new_state_dict = {}
for key in state_dict.keys():
if key.startswith("ca") and ".to" in key and "weight" in key:
print(key, state_dict[key].dtype)
for key in state_dict.keys():
if key.startswith("ca") and ".to" in key and "weight" in key:
print(f"Converting {key} to FP8, dtype: {state_dict[key].dtype}")
weight = state_dict[key].to(torch.float32).cuda()
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
weight = weight.to(torch.float8_e4m3fn)
weight_scale = weight_scale.to(torch.float32)
weight = state_dict[key].to(torch.float32).cuda()
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
weight = weight.to(torch.float8_e4m3fn)
weight_scale = weight_scale.to(torch.float32)
new_state_dict[key] = weight.cpu()
new_state_dict[key + "_scale"] = weight_scale.cpu()
new_state_dict[key] = weight.cpu()
new_state_dict[key + "_scale"] = weight_scale.cpu()
else:
# Convert non-matching weights to BF16
print(f"Converting {key} to BF16, dtype: {state_dict[key].dtype}")
new_state_dict[key] = state_dict[key].to(torch.bfloat16)
save_file(new_state_dict, str(output_path))
print(f"Quantized model saved to: {output_path}")
for key in state_dict.keys():
if key not in new_state_dict.keys():
new_state_dict[key] = state_dict[key]
save_file(new_state_dict, new_model_path)
if __name__ == "__main__":
main()
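To sanity-check the output, the saved file can be reopened and each tensor's dtype inspected; quantized weights should be `float8_e4m3fn` and carry a companion `<key>_scale` tensor. A small check, assuming the output path used above:

```python
import safetensors
import torch

# Placeholder path; point this at the file written by the script above.
path = "/path/to/audio_adapter_model_fp8.safetensors"
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
    keys = list(f.keys())
    for key in keys:
        t = f.get_tensor(key)
        if t.dtype == torch.float8_e4m3fn:
            assert key + "_scale" in keys, f"missing scale tensor for {key}"
        print(key, tuple(t.shape), t.dtype)
```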
# Model Conversion Tool
This converter tool can convert model weights between different formats.
A powerful model weight conversion tool that supports format conversion, quantization, LoRA merging, and more.
## Feature 1: Convert Quantized Models
## Main Features
This tool supports converting **FP32/FP16/BF16** model weights to **INT8, FP8** types.
- **Format Conversion**: Support PyTorch (.pth) and SafeTensors (.safetensors) format conversion
- **Model Quantization**: Support INT8 and FP8 quantization to significantly reduce model size
- **Architecture Conversion**: Support conversion between LightX2V and Diffusers architectures
- **LoRA Merging**: Support loading and merging multiple LoRA formats
- **Multi-Model Support**: Support Wan DiT, Qwen Image DiT, T5, CLIP, etc.
- **Flexible Saving**: Support single file, block-based, and chunked saving methods
- **Parallel Processing**: Support parallel acceleration for large model conversion
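The format-conversion feature above boils down to loading one state dict and re-saving it in the other container. A minimal sketch of that round trip (paths are placeholders, not converter.py itself):

```python
import torch
from safetensors.torch import save_file

# Load a PyTorch checkpoint (placeholder path) and keep only the tensor entries.
state_dict = torch.load("/path/to/model.pth", map_location="cpu")
tensors = {k: v.contiguous() for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

# safetensors requires contiguous tensors with no shared storage.
save_file(tensors, "/path/to/model.safetensors")
```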
### Wan DIT
## Supported Model Types
- `wan_dit`: Wan DiT series models (default)
- `wan_animate_dit`: Wan Animate DiT models
- `qwen_image_dit`: Qwen Image DiT models
- `wan_t5`: Wan T5 text encoder
- `wan_clip`: Wan CLIP vision encoder
## Core Parameters
### Basic Parameters
- `-s, --source`: Input path (file or directory)
- `-o, --output`: Output directory path
- `-o_e, --output_ext`: Output format, `.pth` or `.safetensors` (default)
- `-o_n, --output_name`: Output file name (default: `converted`)
- `-t, --model_type`: Model type (default: `wan_dit`)
### Architecture Conversion Parameters
- `-d, --direction`: Conversion direction
- `None`: No architecture conversion (default)
- `forward`: LightX2V → Diffusers
- `backward`: Diffusers → LightX2V
### Quantization Parameters
- `--quantized`: Enable quantization
- `--bits`: Quantization bit width, currently only supports 8-bit
- `--linear_dtype`: Linear layer quantization type
- `torch.int8`: INT8 quantization
- `torch.float8_e4m3fn`: FP8 quantization
- `--non_linear_dtype`: Non-linear layer data type
- `torch.bfloat16`: BF16
- `torch.float16`: FP16
- `torch.float32`: FP32 (default)
- `--device`: Device for quantization, `cpu` (default) or `cuda`
- `--comfyui_mode`: ComfyUI compatible mode
- `--full_quantized`: Full quantization mode (effective in ComfyUI mode)
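Conceptually, the INT8 path follows the usual symmetric scheme: compute a per-channel scale from the maximum absolute value, then divide, round, and clamp. A rough sketch of that idea (range and dimension choices here are illustrative, not the tool's exact code):

```python
import torch

def quantize_int8(w: torch.Tensor):
    # Per-output-channel scale: map the max |w| of each row to the INT8 maximum (127).
    scales = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
    q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
    return q, scales  # dequantize with q.float() * scales
```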
### LoRA Parameters
- `--lora_path`: LoRA file path(s), supports multiple (separated by spaces)
- `--lora_strength`: LoRA strength coefficients, supports multiple (default: 1.0)
- `--lora_alpha`: LoRA alpha parameters, supports multiple (default: no extra scaling)
- `--lora_key_convert`: LoRA key conversion mode
- `auto`: Auto-detect (default)
- `same`: Use original key names
- `convert`: Apply same conversion as model
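Merging applies the standard low-rank update W' = W + strength * (alpha / rank) * up @ down to each matched weight. A minimal sketch of that update (function and argument names are illustrative):

```python
import torch

def merge_lora_delta(weight, lora_up, lora_down, alpha=None, strength=1.0):
    # lora_down: (rank, in_features); lora_up: (out_features, rank)
    rank = lora_down.shape[0]
    scale = (alpha / rank) if alpha is not None else 1.0
    weight += torch.mm(lora_up, lora_down) * scale * strength  # update the base weight in place
```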
### Saving Parameters
- `--single_file`: Save as single file (note: large models consume significant memory)
- `-b, --save_by_block`: Save by blocks (recommended for backward conversion)
- `-c, --chunk-size`: Chunk size (default: 100, 0 means no chunking)
- `--copy_no_weight_files`: Copy non-weight files from source directory
### Performance Parameters
- `--parallel`: Enable parallel processing (default: True)
- `--no-parallel`: Disable parallel processing
## Supported LoRA Formats
The tool automatically detects and supports the following LoRA formats:
1. **Standard**: `{key}.lora_up.weight` and `{key}.lora_down.weight`
2. **Diffusers**: `{key}_lora.up.weight` and `{key}_lora.down.weight`
3. **Diffusers V2**: `{key}.lora_B.weight` and `{key}.lora_A.weight`
4. **Diffusers V3**: `{key}.lora.up.weight` and `{key}.lora.down.weight`
5. **Mochi**: `{key}.lora_B` and `{key}.lora_A` (no .weight suffix)
6. **Transformers**: `{key}.lora_linear_layer.up.weight` and `{key}.lora_linear_layer.down.weight`
7. **Qwen**: `{key}.lora_B.default.weight` and `{key}.lora_A.default.weight`
Additionally supports diff formats:
- `.diff`: Weight diff
- `.diff_b`: Bias diff
- `.diff_m`: Modulation diff
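Internally, a matched LoRA key is reduced to the base model key by stripping a common prefix (`diffusion_model.`, `model.`, `unet.`) and swapping the LoRA suffix for `.weight`; Qwen-style keys under `transformer_blocks.<idx>.` keep their prefix. A small sketch of that mapping (the example key is hypothetical):

```python
def lora_key_to_model_key(key: str) -> str:
    base = key[: -len(".lora_down.weight")]  # assumes the Standard format suffix
    if not (base.startswith("transformer_blocks.") and base.split(".")[1].isdigit()):
        for prefix in ("diffusion_model.", "model.", "unet."):
            if base.startswith(prefix):
                base = base[len(prefix):]
                break
    return base + ".weight"

# e.g. "diffusion_model.blocks.0.self_attn.q.lora_down.weight" -> "blocks.0.self_attn.q.weight"
print(lora_key_to_model_key("diffusion_model.blocks.0.self_attn.q.lora_down.weight"))
```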
## Usage Examples
### 1. Model Quantization
#### 1.1 Wan DiT Quantization to INT8
**Multiple safetensors, saved by dit blocks**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_int8 \
--linear_dtype torch.int8 \
......@@ -20,91 +110,90 @@ python converter.py \
--save_by_block
```
**Single safetensor file**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_fp8 \
--linear_dtype torch.float8_e4m3fn \
--output_name wan2.1_i2v_480p_int8_lightx2v \
--linear_dtype torch.int8 \
--model_type wan_dit \
--quantized \
--save_by_block
--single_file
```
### Wan DiT + LoRA
#### 1.2 Wan DiT Quantization to FP8
**Multiple safetensors, saved by dit blocks**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_int8 \
--linear_dtype torch.int8 \
--output_name wan_fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
--lora_alpha 1.0 1.0 \
--quantized \
--save_by_block
```
### Hunyuan DIT
**Single safetensor file**
```bash
python converter.py \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .safetensors \
--output_name hunyuan_int8 \
--linear_dtype torch.int8 \
--model_type hunyuan_dit \
--quantized
```
```bash
python converter.py \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \
--linear_dtype torch.float8_e4m3fn \
--model_type hunyuan_dit \
--quantized
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--single_file
```
### QWen-Image DIT
**ComfyUI scaled_fp8 format**
```bash
python converter.py \
--source /path/to/Qwen-Image-Edit/transformer \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name qwen_int8 \
--linear_dtype torch.int8 \
--model_type qwen_image_dit \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--save_by_block
--single_file \
--comfyui_mode
```
**ComfyUI full FP8 format**
```bash
python converter.py \
--source /path/to/Qwen-Image-Edit/transformer \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name qwen_fp8 \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
--linear_dtype torch.float8_e4m3fn \
--model_type qwen_image_dit \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--save_by_block
--single_file \
--comfyui_mode \
--full_quantized
```
### Wan T5EncoderModel
> **Tip**: For other DIT models, simply switch the `--model_type` parameter
#### 1.3 T5 Encoder Quantization
**INT8 Quantization**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth\
--source /path/to/models_t5_umt5-xxl-enc-bf16.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.bfloat16 \
......@@ -112,11 +201,12 @@ python converter.py \
--quantized
```
**FP8 Quantization**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/fp8 \
--output_ext .pth\
--source /path/to/models_t5_umt5-xxl-enc-bf16.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
......@@ -124,51 +214,193 @@ python converter.py \
--quantized
```
#### 1.4 CLIP Encoder Quantization
**INT8 Quantization**
```bash
python converter.py \
--source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
```
**FP8 Quantization**
```bash
python converter.py \
--source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
```
### 2. LoRA Merging
#### 2.1 Merge Single LoRA
```bash
python converter.py \
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_model \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--single_file
```
### Wan CLIPModel
#### 2.2 Merge Multiple LoRAs
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /Path/To/output \
--output_ext .pth \
--output_name clip-int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_model \
--model_type wan_dit \
--lora_path /path/to/lora1.safetensors /path/to/lora2.safetensors \
--lora_strength 1.0 0.8 \
--single_file
```
#### 2.3 LoRA Merging with Quantization
**LoRA Merge → FP8 Quantization**
```bash
python converter.py \
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_quantized \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--single_file
```
**LoRA Merge → ComfyUI scaled_fp8**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output ./output \
--output_ext .pth \
--output_name clip-fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_quantized \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--single_file \
--comfyui_mode
```
**LoRA Merge → ComfyUI Full FP8**
```bash
python converter.py \
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_quantized \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--single_file \
--comfyui_mode \
--full_quantized
```
## Feature 2: Format Conversion Between Diffusers and Lightx2v
Supports mutual conversion between Diffusers architecture and LightX2V architecture
#### 2.4 LoRA Key Conversion Modes
### Lightx2v->Diffusers
**Auto-detect mode (recommended)**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
--output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
--direction forward \
--save_by_block
--source /path/to/model/ \
--output /path/to/output \
--lora_path /path/to/lora.safetensors \
--lora_key_convert auto \
--single_file
```
### Diffusers->Lightx2v
**Use original key names (LoRA already in target format)**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
--output /Path/To/Wan2.1-I2V-14B-480P \
--direction backward \
--save_by_block
--source /path/to/model/ \
--output /path/to/output \
--direction forward \
--lora_path /path/to/lora.safetensors \
--lora_key_convert same \
--single_file
```
**Apply conversion (LoRA in source format)**
```bash
python converter.py \
--source /path/to/model/ \
--output /path/to/output \
--direction forward \
--lora_path /path/to/lora.safetensors \
--lora_key_convert convert \
--single_file
```
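In `convert` mode the resolved LoRA keys go through the same `(pattern, replacement)` regex rules as the model keys, applied with `re.sub`. A small sketch with a made-up rule (the real rules come from `get_key_mapping_rules` and depend on model type and direction):

```python
import re

# Hypothetical rule for illustration only.
key_mapping_rules = [(r"^transformer_blocks\.(\d+)\.", r"blocks.\1.")]

def map_key(key: str, rules) -> str:
    for pattern, replacement in rules:
        key = re.sub(pattern, replacement, key)
    return key

print(map_key("transformer_blocks.3.attn.to_q.weight", key_mapping_rules))  # -> blocks.3.attn.to_q.weight
```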
### 3. Architecture Format Conversion
#### 3.1 LightX2V → Diffusers
```bash
python converter.py \
--source /path/to/Wan2.1-I2V-14B-480P \
--output /path/to/Wan2.1-I2V-14B-480P-Diffusers \
--output_ext .safetensors \
--model_type wan_dit \
--direction forward \
--chunk-size 100
```
#### 3.2 Diffusers → LightX2V
```bash
python converter.py \
--source /path/to/Wan2.1-I2V-14B-480P-Diffusers \
--output /path/to/Wan2.1-I2V-14B-480P \
--output_ext .safetensors \
--model_type wan_dit \
--direction backward \
--save_by_block
```
### 4. Format Conversion
#### 4.1 .pth → .safetensors
```bash
python converter.py \
--source /path/to/model.pth \
--output /path/to/output \
--output_ext .safetensors \
--output_name model
```
#### 4.2 Multiple .safetensors → Single File
```bash
python converter.py \
--source /path/to/model_directory/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_model \
--single_file
```
# Model Conversion Tool
This converter tool can convert model weights between different formats
A powerful model weight conversion tool that supports format conversion, quantization, LoRA merging, and more
## Feature 1: Convert Quantized Models
## Main Features
This tool supports converting **FP32/FP16/BF16** model weights to **INT8, FP8** types.
- **Format Conversion**: Support PyTorch (.pth) and SafeTensors (.safetensors) format conversion
- **Model Quantization**: Support INT8 and FP8 quantization to significantly reduce model size
- **Architecture Conversion**: Support conversion between LightX2V and Diffusers architectures
- **LoRA Merging**: Support loading and merging multiple LoRA formats
- **Multi-Model Support**: Support Wan DiT, Qwen Image DiT, T5, CLIP, etc.
- **Flexible Saving**: Support single file, block-based, and chunked saving methods
- **Parallel Processing**: Support parallel acceleration for large model conversion
### Wan DIT
## Supported Model Types
- `wan_dit`: Wan DiT series models (default)
- `wan_animate_dit`: Wan Animate DiT models
- `qwen_image_dit`: Qwen Image DiT models
- `wan_t5`: Wan T5 text encoder
- `wan_clip`: Wan CLIP vision encoder
## Core Parameters
### Basic Parameters
- `-s, --source`: Input path (file or directory)
- `-o, --output`: Output directory path
- `-o_e, --output_ext`: Output format, `.pth` or `.safetensors` (default)
- `-o_n, --output_name`: Output file name (default: `converted`)
- `-t, --model_type`: Model type (default: `wan_dit`)
### Architecture Conversion Parameters
- `-d, --direction`: Conversion direction
- `None`: No architecture conversion (default)
- `forward`: LightX2V → Diffusers
- `backward`: Diffusers → LightX2V
### Quantization Parameters
- `--quantized`: Enable quantization
- `--bits`: Quantization bit width, currently only supports 8-bit
- `--linear_dtype`: Linear layer quantization type
- `torch.int8`: INT8 quantization
- `torch.float8_e4m3fn`: FP8 quantization
- `--non_linear_dtype`: Non-linear layer data type
- `torch.bfloat16`: BF16
- `torch.float16`: FP16
- `torch.float32`: FP32 (default)
- `--device`: Device for quantization, `cpu` (default) or `cuda`
- `--comfyui_mode`: ComfyUI compatible mode
- `--full_quantized`: Full quantization mode (effective in ComfyUI mode)
### LoRA Parameters
- `--lora_path`: LoRA file path(s), supports multiple (separated by spaces)
- `--lora_strength`: LoRA strength coefficients, supports multiple (default: 1.0)
- `--lora_alpha`: LoRA alpha parameters, supports multiple
- `--lora_key_convert`: LoRA key conversion mode
- `auto`: Auto-detect (default)
- `same`: Use original key names
- `convert`: Apply same conversion as model
### Saving Parameters
- `--single_file`: Save as single file (note: large models consume significant memory)
- `-b, --save_by_block`: Save by blocks (recommended for backward conversion)
- `-c, --chunk-size`: Chunk size (default: 100, 0 means no chunking)
- `--copy_no_weight_files`: Copy non-weight files from source directory
### Performance Parameters
- `--parallel`: Enable parallel processing (default: True)
- `--no-parallel`: Disable parallel processing
## Supported LoRA Formats
The tool automatically detects and supports the following LoRA formats:
1. **Standard**: `{key}.lora_up.weight` and `{key}.lora_down.weight`
2. **Diffusers**: `{key}_lora.up.weight` and `{key}_lora.down.weight`
3. **Diffusers V2**: `{key}.lora_B.weight` and `{key}.lora_A.weight`
4. **Diffusers V3**: `{key}.lora.up.weight` and `{key}.lora.down.weight`
5. **Mochi**: `{key}.lora_B` and `{key}.lora_A` (no .weight suffix)
6. **Transformers**: `{key}.lora_linear_layer.up.weight` and `{key}.lora_linear_layer.down.weight`
7. **Qwen**: `{key}.lora_B.default.weight` and `{key}.lora_A.default.weight`
Additionally supports diff formats:
- `.diff`: Weight diff
- `.diff_b`: Bias diff
- `.diff_m`: Modulation diff
## Usage Examples
### 1. Model Quantization
#### 1.1 Wan DiT Quantization to INT8
**Multiple safetensors, saved by dit blocks**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_int8 \
--linear_dtype torch.int8 \
......@@ -20,91 +110,90 @@ python converter.py \
--save_by_block
```
**Single safetensor file**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_fp8 \
--linear_dtype torch.float8_e4m3fn \
--output_name wan2.1_i2v_480p_int8_lightx2v \
--linear_dtype torch.int8 \
--model_type wan_dit \
--quantized \
--save_by_block
--single_file
```
### Wan DiT + LoRA
#### 1.2 Wan DiT Quantization to FP8
**Multiple safetensors, saved by dit blocks**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name wan_int8 \
--linear_dtype torch.int8 \
--output_name wan_fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
--lora_alpha 1.0 1.0 \
--quantized \
--save_by_block
```
### Hunyuan DIT
**Single safetensor file**
```bash
python converter.py \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--output_ext .safetensors \
--output_name hunyuan_int8 \
--linear_dtype torch.int8 \
--model_type hunyuan_dit \
--quantized
```
```bash
python converter.py \
--source /Path/To/hunyuan/lightx2v_format/i2v/ \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name hunyuan_fp8 \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \
--linear_dtype torch.float8_e4m3fn \
--model_type hunyuan_dit \
--quantized
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--single_file
```
### QWen-Image DIT
**ComfyUI scaled_fp8 format**
```bash
python converter.py \
--source /path/to/Qwen-Image-Edit/transformer \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name qwen_int8 \
--linear_dtype torch.int8 \
--model_type qwen_image_dit \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--save_by_block
--single_file \
--comfyui_mode
```
**ComfyUI full FP8 format**
```bash
python converter.py \
--source /path/to/Qwen-Image-Edit/transformer \
--output /Path/To/output \
--source /path/to/Wan2.1-I2V-14B-480P/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name qwen_fp8 \
--output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
--linear_dtype torch.float8_e4m3fn \
--model_type qwen_image_dit \
--non_linear_dtype torch.bfloat16 \
--model_type wan_dit \
--quantized \
--save_by_block
--single_file \
--comfyui_mode \
--full_quantized
```
### Wan T5EncoderModel
> **Tip**: For other DiT models, simply switch the `--model_type` parameter
#### 1.3 T5 Encoder Quantization
**INT8 Quantization**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/output \
--output_ext .pth\
--source /path/to/models_t5_umt5-xxl-enc-bf16.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.bfloat16 \
......@@ -112,11 +201,12 @@ python converter.py \
--quantized
```
**FP8 Quantization**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth \
--output /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/fp8 \
--output_ext .pth\
--source /path/to/models_t5_umt5-xxl-enc-bf16.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_t5_umt5-xxl-enc-fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.bfloat16 \
......@@ -124,51 +214,193 @@ python converter.py \
--quantized
```
#### 1.4 CLIP Encoder Quantization
**INT8 Quantization**
```bash
python converter.py \
--source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
```
**FP8 Quantization**
```bash
python converter.py \
--source /path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /path/to/output \
--output_ext .pth \
--output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
```
### 2. LoRA Merging
#### 2.1 Merge Single LoRA
```bash
python converter.py \
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_model \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--single_file
```
### Wan CLIPModel
#### 2.2 Merge Multiple LoRAs
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output /Path/To/output \
--output_ext .pth \
--output_name clip-int8 \
--linear_dtype torch.int8 \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_model \
--model_type wan_dit \
--lora_path /path/to/lora1.safetensors /path/to/lora2.safetensors \
--lora_strength 1.0 0.8 \
--single_file
```
#### 2.3 LoRA Merging with Quantization
**LoRA Merge → FP8 Quantization**
```bash
python converter.py \
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_quantized \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--single_file
```
**LoRA Merge → ComfyUI scaled_fp8**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth \
--output ./output \
--output_ext .pth \
--output_name clip-fp8 \
--linear_dtype torch.float8_e4m3fn \
--non_linear_dtype torch.float16 \
--model_type wan_clip \
--quantized
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_quantized \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--single_file \
--comfyui_mode
```
**LoRA Merge → ComfyUI Full FP8**
```bash
python converter.py \
--source /path/to/base_model/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_quantized \
--model_type wan_dit \
--lora_path /path/to/lora.safetensors \
--lora_strength 1.0 \
--quantized \
--linear_dtype torch.float8_e4m3fn \
--single_file \
--comfyui_mode \
--full_quantized
```
## Feature 2: Format Conversion Between Diffusers and Lightx2v
Supports mutual conversion between the Diffusers and LightX2V architectures
#### 2.4 LoRA Key Conversion Modes
### Lightx2v->Diffusers
**Auto-detect mode (recommended)**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P \
--output /Path/To/Wan2.1-I2V-14B-480P-Diffusers \
--direction forward \
--save_by_block
--source /path/to/model/ \
--output /path/to/output \
--lora_path /path/to/lora.safetensors \
--lora_key_convert auto \
--single_file
```
### Diffusers->Lightx2v
**Use original key names (LoRA already in target format)**
```bash
python converter.py \
--source /Path/To/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers \
--output /Path/To/Wan2.1-I2V-14B-480P \
--direction backward \
--save_by_block
--source /path/to/model/ \
--output /path/to/output \
--direction forward \
--lora_path /path/to/lora.safetensors \
--lora_key_convert same \
--single_file
```
**Apply conversion (LoRA in source format)**
```bash
python converter.py \
--source /path/to/model/ \
--output /path/to/output \
--direction forward \
--lora_path /path/to/lora.safetensors \
--lora_key_convert convert \
--single_file
```
### 3. Architecture Format Conversion
#### 3.1 LightX2V → Diffusers
```bash
python converter.py \
--source /path/to/Wan2.1-I2V-14B-480P \
--output /path/to/Wan2.1-I2V-14B-480P-Diffusers \
--output_ext .safetensors \
--model_type wan_dit \
--direction forward \
--chunk-size 100
```
#### 3.2 Diffusers → LightX2V
```bash
python converter.py \
--source /path/to/Wan2.1-I2V-14B-480P-Diffusers \
--output /path/to/Wan2.1-I2V-14B-480P \
--output_ext .safetensors \
--model_type wan_dit \
--direction backward \
--save_by_block
```
### 4. Format Conversion
#### 4.1 .pth → .safetensors
```bash
python converter.py \
--source /path/to/model.pth \
--output /path/to/output \
--output_ext .safetensors \
--output_name model
```
#### 4.2 Multiple .safetensors → Single File
```bash
python converter.py \
--source /path/to/model_directory/ \
--output /path/to/output \
--output_ext .safetensors \
--output_name merged_model \
--single_file
```