"docs/source/vscode:/vscode.git/clone" did not exist on "954e18ab9713da83e1484f78a6f6e178b0d9fe2a"
Unverified Commit 9a765f9b authored by Gu Shiqiao, committed by GitHub

fix gradio (#587)

parent 8530a2fb
@@ -4,6 +4,9 @@ import glob
import importlib.util
import json
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
import random
from datetime import datetime
@@ -12,6 +15,15 @@ import psutil
import torch
from loguru import logger
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import get_default_config
try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
apply_rope_with_cos_sin_cache_inplace = None
logger.add(
    "inference_logs.log",
    rotation="100 MB",
@@ -24,38 +36,196 @@ logger.add(
MAX_NUMPY_SEED = 2**32 - 1
def scan_model_path_contents(model_path):
    """Scan model_path directory and return available files and subdirectories"""
    if not model_path or not os.path.exists(model_path):
        return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}

    dirs = []
    files = []
    safetensors_dirs = []
    pth_files = []

    try:
        for item in os.listdir(model_path):
            item_path = os.path.join(model_path, item)
            if os.path.isdir(item_path):
                dirs.append(item)
                # Check if directory contains safetensors files
                if glob.glob(os.path.join(item_path, "*.safetensors")):
                    safetensors_dirs.append(item)
            elif os.path.isfile(item_path):
                files.append(item)
                if item.endswith(".pth"):
                    pth_files.append(item)
    except Exception as e:
        logger.warning(f"Failed to scan directory: {e}")

    return {
        "dirs": sorted(dirs),
        "files": sorted(files),
        "safetensors_dirs": sorted(safetensors_dirs),
        "pth_files": sorted(pth_files),
    }


def get_dit_choices(model_path, model_type="wan2.1"):
    """Get Diffusion model options (filtered by model type)"""
    contents = scan_model_path_contents(model_path)
    excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
    fp8_supported = is_fp8_supported_gpu()

    if model_type == "wan2.1":
        # wan2.1: filter files/dirs containing wan2.1 or Wan2.1
        def is_valid(name):
            name_lower = name.lower()
            if "wan2.1" not in name_lower:
                return False
            if not fp8_supported and "fp8" in name_lower:
                return False
            return not any(kw in name_lower for kw in excluded_keywords)
    else:
        # wan2.2: filter files/dirs containing wan2.2 or Wan2.2
        def is_valid(name):
            name_lower = name.lower()
            if "wan2.2" not in name_lower:
                return False
            if not fp8_supported and "fp8" in name_lower:
                return False
            return not any(kw in name_lower for kw in excluded_keywords)

    # Filter matching directories and files
    dir_choices = [d for d in contents["dirs"] if is_valid(d)]
    file_choices = [f for f in contents["files"] if is_valid(f)]
    choices = dir_choices + file_choices
    return choices if choices else [""]
def get_high_noise_choices(model_path):
"""Get high noise model options (files/dirs containing high_noise)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "high_noise" in name_lower or "high-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_low_noise_choices(model_path):
"""Get low noise model options (files/dirs containing low_noise)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "low_noise" in name_lower or "low-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_t5_choices(model_path):
"""Get T5 model options (.pth or .safetensors files containing t5 keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_clip_choices(model_path):
"""Get CLIP model options (.pth or .safetensors files containing clip keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_vae_choices(model_path):
"""Get VAE model options (.pth or .safetensors files containing vae/VAE/tae keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def detect_quant_scheme(model_name):
"""Automatically detect quantization scheme from model name
- If model name contains "int8" → "int8"
- If model name contains "fp8" and device supports → "fp8"
- Otherwise return None (no quantization)
"""
if not model_name:
return None
name_lower = model_name.lower()
if "int8" in name_lower:
return "int8"
elif "fp8" in name_lower:
if is_fp8_supported_gpu():
return "fp8"
else:
# Device doesn't support fp8, return None (use default precision)
return None
return None
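# Illustrative usage only (not part of this commit); file names are hypothetical.
# The scheme is read straight from the selected checkpoint name, and "fp8" is only
# honored on GPUs where is_fp8_supported_gpu() returns True:
def _example_detect_quant_scheme():
    assert detect_quant_scheme("wan2.1_i2v_int8.safetensors") == "int8"
    assert detect_quant_scheme("wan2.1_i2v_bf16.safetensors") is None
    print(detect_quant_scheme("wan2.1_i2v_fp8.safetensors"))  # "fp8" or None, depending on the GPU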
def update_model_path_options(model_path, model_type="wan2.1"):
"""Update all model path selectors when model_path or model_type changes"""
dit_choices = get_dit_choices(model_path, model_type)
high_noise_choices = get_high_noise_choices(model_path)
low_noise_choices = get_low_noise_choices(model_path)
t5_choices = get_t5_choices(model_path)
clip_choices = get_clip_choices(model_path)
vae_choices = get_vae_choices(model_path)
return (
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
)
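# Illustrative wiring only (not part of this commit): this helper is intended to be
# attached to the change events of the model-path textbox and model-type radio defined
# later in main(), so that all six dropdowns refresh together, e.g.:
#
#     model_type_input.change(
#         fn=update_model_path_options,
#         inputs=[model_path_input, model_type_input],
#         outputs=[dit_path_input, high_noise_path_input, low_noise_path_input,
#                  t5_path_input, clip_path_input, vae_path_input],
#     )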
def generate_random_seed():
@@ -109,12 +279,18 @@ def get_available_attn_ops():
    else:
        available_ops.append(("flash_attn3", False))

    sage_installed = is_module_installed("sageattention")
    if sage_installed:
        available_ops.append(("sage_attn2", True))
    else:
        available_ops.append(("sage_attn2", False))

    sage3_installed = is_module_installed("sageattn3")
    if sage3_installed:
        available_ops.append(("sage_attn3", True))
    else:
        available_ops.append(("sage_attn3", False))

    torch_installed = is_module_installed("torch")
    if torch_installed:
        available_ops.append(("torch_sdpa", True))
@@ -150,6 +326,8 @@ def cleanup_memory():
    torch.cuda.synchronize()

    try:
        import psutil

        if hasattr(psutil, "virtual_memory"):
            if os.name == "posix":
                try:
@@ -163,7 +341,7 @@ def cleanup_memory():
def generate_unique_filename(output_dir):
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(output_dir, f"{timestamp}.mp4")
def is_fp8_supported_gpu():
@@ -231,13 +409,25 @@ def get_quantization_options(model_path):
    return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
def determine_model_cls(model_type, dit_name, high_noise_name):
"""Determine model_cls based on model type and file name"""
# Determine file name to check
if model_type == "wan2.1":
check_name = dit_name.lower() if dit_name else ""
is_distill = "4step" in check_name
return "wan2.1_distill" if is_distill else "wan2.1"
else:
# wan2.2
check_name = high_noise_name.lower() if high_noise_name else ""
is_distill = "4step" in check_name
return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
global_runner = None
current_config = None
cur_dit_path = None
cur_t5_path = None
cur_clip_path = None

available_quant_ops = get_available_quant_ops()
quant_op_choices = []
@@ -247,8 +437,29 @@ for op_name, is_installed in available_quant_ops:
    quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()

# Priority order
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
# Sort by priority, installed ones first, uninstalled ones last
attn_op_choices = []
attn_op_dict = dict(available_attn_ops)

# Add installed ones first (by priority)
for op_name in attn_priority:
    if op_name in attn_op_dict and attn_op_dict[op_name]:
        status_text = "✅ Installed"
        display_text = f"{op_name} ({status_text})"
        attn_op_choices.append((op_name, display_text))

# Add uninstalled ones (by priority)
for op_name in attn_priority:
    if op_name in attn_op_dict and not attn_op_dict[op_name]:
        status_text = "❌ Not Installed"
        display_text = f"{op_name} ({status_text})"
        attn_op_choices.append((op_name, display_text))

# Add other operators not in priority list (installed ones first)
other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]):  # Installed ones first
    status_text = "✅ Installed" if is_installed else "❌ Not Installed"
    display_text = f"{op_name} ({status_text})"
    attn_op_choices.append((op_name, display_text))
@@ -258,36 +469,36 @@ def run_inference(
    prompt,
    negative_prompt,
    save_result_path,
    infer_steps,
    num_frames,
    resolution,
    seed,
    sample_shift,
    enable_cfg,
    cfg_scale,
    fps,
    use_tiling_vae,
    lazy_load,
    cpu_offload,
    offload_granularity,
    t5_cpu_offload,
    clip_cpu_offload,
    vae_cpu_offload,
    unload_modules,
    attention_type,
    quant_op,
    rope_chunk,
    rope_chunk_size,
    clean_cuda_cache,
    model_path_input,
    model_type_input,
    task_type_input,
    dit_path_input,
    high_noise_path_input,
    low_noise_path_input,
    t5_path_input,
    clip_path_input,
    vae_path_input,
    image_path=None,
):
    cleanup_memory()
@@ -295,8 +506,23 @@ def run_inference(
    quant_op = quant_op.split("(")[0].strip()
    attention_type = attention_type.split("(")[0].strip()

    global global_runner, current_config, model_path, model_cls
    global cur_dit_path, cur_t5_path, cur_clip_path

    task = task_type_input
    model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
    logger.info(f"Auto-determined model_cls: {model_cls} (Model type: {model_type_input})")

    if model_type_input == "wan2.1":
        dit_quant_detected = detect_quant_scheme(dit_path_input)
    else:
        dit_quant_detected = detect_quant_scheme(high_noise_path_input)
    t5_quant_detected = detect_quant_scheme(t5_path_input)
    clip_quant_detected = detect_quant_scheme(clip_path_input)
    logger.info(f"Auto-detected quantization scheme - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")

    if model_path_input and model_path_input.strip():
        model_path = model_path_input.strip()

    if os.path.exists(os.path.join(model_path, "config.json")):
        with open(os.path.join(model_path, "config.json"), "r") as f:
@@ -304,157 +530,88 @@ def run_inference(
    else:
        model_config = {}

    save_result_path = generate_unique_filename(output_dir)

    is_dit_quant = dit_quant_detected != "bf16"
    is_t5_quant = t5_quant_detected != "bf16"
    is_clip_quant = clip_quant_detected != "fp16"

    dit_quantized_ckpt = None
    dit_original_ckpt = None
    high_noise_quantized_ckpt = None
    low_noise_quantized_ckpt = None
    high_noise_original_ckpt = None
    low_noise_original_ckpt = None

    if is_dit_quant:
        dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
        if "wan2.1" in model_cls:
            dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
        else:
            high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
            low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
    else:
        dit_quantized_ckpt = "Default"
        if "wan2.1" in model_cls:
            dit_original_ckpt = os.path.join(model_path, dit_path_input)
        else:
            high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
            low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)

    # Use frontend-selected T5 path
    if is_t5_quant:
        t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
        t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
        t5_original_ckpt = None
    else:
        t5_quantized_ckpt = None
        t5_quant_scheme = None
        t5_original_ckpt = os.path.join(model_path, t5_path_input)

    # Use frontend-selected CLIP path
    if is_clip_quant:
        clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
        clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
        clip_original_ckpt = None
    else:
        clip_quantized_ckpt = None
        clip_quant_scheme = None
        clip_original_ckpt = os.path.join(model_path, clip_path_input)
if model_type_input == "wan2.1":
current_dit_path = dit_path_input
else:
current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
current_t5_path = t5_path_input
current_clip_path = clip_path_input
needs_reinit = ( needs_reinit = (
lazy_load lazy_load
or unload_modules or unload_modules
or global_runner is None or global_runner is None
or current_config is None or current_config is None
or cur_dit_quant_scheme is None or cur_dit_path is None
or cur_dit_quant_scheme != dit_quant_scheme or cur_dit_path != current_dit_path
or cur_clip_quant_scheme is None or cur_t5_path is None
or cur_clip_quant_scheme != clip_quant_scheme or cur_t5_path != current_t5_path
or cur_t5_quant_scheme is None or cur_clip_path is None
or cur_t5_quant_scheme != t5_quant_scheme or cur_clip_path != current_clip_path
or cur_precision_mode is None
or cur_precision_mode != precision_mode
or cur_enable_teacache is None
or cur_enable_teacache != enable_teacache
) )
if torch_compile: if cfg_scale == 1:
os.environ["ENABLE_GRAPH_MODE"] = "true" enable_cfg = False
else:
os.environ["ENABLE_GRAPH_MODE"] = "false"
if precision_mode == "bf16":
os.environ["DTYPE"] = "BF16"
else: else:
os.environ.pop("DTYPE", None) enable_cfg = True
    vae_name_lower = vae_path_input.lower() if vae_path_input else ""
    use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
    use_lightvae = "lightvae" in vae_name_lower
    need_scaled = "lighttae" in vae_name_lower

    logger.info(f"VAE configuration - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
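    # Illustrative mapping only (not part of this commit); the lighttae name is hypothetical.
    # With the keyword checks above, selecting "Wan2.1_VAE.pth" leaves all three flags False,
    # "taew2_1.pth" sets use_tae=True, and a "lighttae_*.safetensors" checkpoint sets
    # use_tae=True and need_scaled=True.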
    config_graio = {
        "infer_steps": infer_steps,
        "target_video_length": num_frames,
        "target_width": int(resolution.split("x")[0]),
@@ -462,38 +619,11 @@ def run_inference(
        "self_attn_1_type": attention_type,
        "cross_attn_1_type": attention_type,
        "cross_attn_2_type": attention_type,
        "enable_cfg": enable_cfg,
        "sample_guide_scale": cfg_scale,
        "sample_shift": sample_shift,
        "fps": fps,
        "feature_caching": "NoCaching",
        "do_mm_calib": False,
        "parallel_attn_type": None,
        "parallel_vae": False,
@@ -504,14 +634,49 @@ def run_inference(
        "strength_model": 1.0,
        "use_prompt_enhancer": False,
        "text_len": 512,
        "denoising_step_list": [1000, 750, 500, 250],
        "cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
        "offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
        "t5_cpu_offload": t5_cpu_offload,
        "clip_cpu_offload": clip_cpu_offload,
        "vae_cpu_offload": vae_cpu_offload,
        "dit_quantized": is_dit_quant,
        "dit_quant_scheme": dit_quant_scheme,
        "dit_quantized_ckpt": dit_quantized_ckpt,
        "dit_original_ckpt": dit_original_ckpt,
        "high_noise_quantized_ckpt": high_noise_quantized_ckpt,
        "low_noise_quantized_ckpt": low_noise_quantized_ckpt,
        "high_noise_original_ckpt": high_noise_original_ckpt,
        "low_noise_original_ckpt": low_noise_original_ckpt,
        "t5_original_ckpt": t5_original_ckpt,
        "t5_quantized": is_t5_quant,
        "t5_quantized_ckpt": t5_quantized_ckpt,
        "t5_quant_scheme": t5_quant_scheme,
        "clip_original_ckpt": clip_original_ckpt,
        "clip_quantized": is_clip_quant,
        "clip_quantized_ckpt": clip_quantized_ckpt,
        "clip_quant_scheme": clip_quant_scheme,
        "vae_path": os.path.join(model_path, vae_path_input),
        "use_tiling_vae": use_tiling_vae,
        "use_tae": use_tae,
        "use_lightvae": use_lightvae,
        "need_scaled": need_scaled,
        "lazy_load": lazy_load,
        "rope_chunk": rope_chunk,
        "rope_chunk_size": rope_chunk_size,
        "clean_cuda_cache": clean_cuda_cache,
        "unload_modules": unload_modules,
        "seq_parallel": False,
        "warm_up_cpu_buffers": False,
        "boundary_step_index": 2,
        "boundary": 0.900,
        "use_image_encoder": False if "wan2.2" in model_cls else True,
        "rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
    }
    args = argparse.Namespace(
        model_cls=model_cls,
        seed=seed,
        task=task,
        model_path=model_path,
        prompt_enhancer=None,
@@ -519,11 +684,13 @@ def run_inference(
        negative_prompt=negative_prompt,
        image_path=image_path,
        save_result_path=save_result_path,
        return_result_tensor=False,
    )

    config = get_default_config()
    config.update({k: v for k, v in vars(args).items()})
    config.update(model_config)
    config.update(config_graio)

    logger.info(f"Using model: {model_path}")
    logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
@@ -539,28 +706,19 @@ def run_inference(
        from lightx2v.infer import init_runner  # noqa

        runner = init_runner(config)

    input_info = set_input_info(args)

    current_config = config
    cur_dit_path = current_dit_path
    cur_t5_path = current_t5_path
    cur_clip_path = current_clip_path

    if not lazy_load:
        global_runner = runner
    else:
        runner.config = config

    runner.run_pipeline(input_info)

    cleanup_memory()
    return save_result_path
@@ -571,44 +729,28 @@ def handle_lazy_load_change(lazy_load_enabled):
    return gr.update(value=lazy_load_enabled)


def auto_configure(resolution):
    """Auto-configure inference options based on machine configuration and resolution"""
    default_config = {
        "lazy_load_val": False,
        "rope_chunk_val": False,
        "rope_chunk_size_val": 100,
        "clean_cuda_cache_val": False,
        "cpu_offload_val": False,
        "offload_granularity_val": "block",
        "t5_cpu_offload_val": False,
        "clip_cpu_offload_val": False,
        "vae_cpu_offload_val": False,
        "unload_modules_val": False,
        "attention_type_val": attn_op_choices[0][1],
        "quant_op_val": quant_op_choices[0][1],
        "use_tiling_vae_val": False,
    }

    gpu_memory = round(get_gpu_memory())
    cpu_memory = round(get_cpu_memory())

    attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]

    if is_ada_architecture_gpu():
        quant_op_priority = ["q8f", "vllm", "sgl"]
@@ -643,25 +785,15 @@ def auto_configure(enable_auto_config, resolution):
    else:
        res = "480p"

    if res == "720p":
        gpu_rules = [
            (80, {}),
            (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
            (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
            (
                24,
                {
                    "cpu_offload_val": True,
                    "use_tiling_vae_val": True,
                },
            ),
@@ -669,155 +801,68 @@ def auto_configure(enable_auto_config, resolution):
                16,
                {
                    "cpu_offload_val": True,
                    "use_tiling_vae_val": True,
                    "offload_granularity_val": "phase",
                    "rope_chunk_val": True,
                    "rope_chunk_size_val": 100,
                },
            ),
            (
                8,
                {
                    "cpu_offload_val": True,
                    "use_tiling_vae_val": True,
                    "offload_granularity_val": "phase",
                    "rope_chunk_val": True,
                    "rope_chunk_size_val": 100,
                    "clean_cuda_cache_val": True,
                },
            ),
        ]
    else:
        gpu_rules = [
            (80, {}),
            (40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
            (32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
            (
                24,
                {
                    "cpu_offload_val": True,
                    "use_tiling_vae_val": True,
                },
            ),
            (
                16,
                {
                    "cpu_offload_val": True,
                    "use_tiling_vae_val": True,
                    "offload_granularity_val": "phase",
                },
            ),
            (
                8,
                {
                    "cpu_offload_val": True,
                    "use_tiling_vae_val": True,
                    "offload_granularity_val": "phase",
                },
            ),
        ]

    cpu_rules = [
        (128, {}),
        (64, {}),
        (32, {"unload_modules_val": True}),
        (
            16,
            {
                "lazy_load_val": True,
                "unload_modules_val": True,
            },
        ),
    ]

    for threshold, updates in gpu_rules:
        if gpu_memory >= threshold:
            default_config.update(updates)
@@ -828,511 +873,551 @@ def auto_configure(enable_auto_config, resolution):
            default_config.update(updates)
            break

    return (
        gr.update(value=default_config["lazy_load_val"]),
        gr.update(value=default_config["rope_chunk_val"]),
        gr.update(value=default_config["rope_chunk_size_val"]),
        gr.update(value=default_config["clean_cuda_cache_val"]),
        gr.update(value=default_config["cpu_offload_val"]),
        gr.update(value=default_config["offload_granularity_val"]),
        gr.update(value=default_config["t5_cpu_offload_val"]),
        gr.update(value=default_config["clip_cpu_offload_val"]),
        gr.update(value=default_config["vae_cpu_offload_val"]),
        gr.update(value=default_config["unload_modules_val"]),
        gr.update(value=default_config["attention_type_val"]),
        gr.update(value=default_config["quant_op_val"]),
        gr.update(value=default_config["use_tiling_vae_val"]),
    )
def main(): def main():
with gr.Blocks( with gr.Blocks(
title="Lightx2v (Lightweight Video Inference and Generation Engine)", title="Lightx2v (Lightweight Video Inference and Generation Engine)",
css=""" css="""
.main-content { max-width: 1400px; margin: auto; } .main-content { max-width: 1600px; margin: auto; padding: 20px; }
.output-video { max-height: 650px; }
.warning { color: #ff6b6b; font-weight: bold; } .warning { color: #ff6b6b; font-weight: bold; }
.advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
.tab-button { font-size: 16px; padding: 10px 20px; } /* Model configuration area styles */
.auto-config-title { .model-config {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4); margin-bottom: 20px !important;
background-clip: text; border: 1px solid #e0e0e0;
-webkit-background-clip: text; border-radius: 12px;
color: transparent; padding: 15px;
text-align: center; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
margin: 0 !important;
padding: 8px;
border: 2px solid #4ecdc4;
border-radius: 8px;
background-color: #f0f8ff;
} }
.auto-config-checkbox {
border: 2px solid #ff6b6b !important; /* Input parameters area styles */
border-radius: 8px !important; .input-params {
padding: 10px !important; margin-bottom: 20px !important;
background: linear-gradient(135deg, #fff5f5, #f0fff0) !important; border: 1px solid #e0e0e0;
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important; border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
} }
.auto-config-checkbox label {
font-size: 16px !important; /* Output video area styles */
font-weight: bold !important; .output-video {
color: #2c3e50 !important; border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
min-height: 400px;
} }
""",
) as demo:
gr.Markdown(f"# 🎬 {model_cls} Video Generator")
gr.Markdown(f"### Using Model: {model_path}")
with gr.Tabs() as tabs:
with gr.Tab("Basic Settings", id=1):
with gr.Row():
with gr.Column(scale=4):
with gr.Group():
gr.Markdown("## 📥 Input Parameters")
if task == "i2v":
with gr.Row():
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
visible=True,
)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Prompt",
lines=3,
placeholder="Describe the video content...",
max_lines=5,
)
with gr.Column():
negative_prompt = gr.Textbox(
label="Negative Prompt",
lines=3,
placeholder="What you don't want to appear in the video...",
max_lines=5,
value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
)
with gr.Column():
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1280x544 (21:9, 720p)", "1280x544"),
("544x1280 (9:21, 720p)", "544x1280"),
("1104x832 (4:3, 720p)", "1104x832"),
("832x1104 (3:4, 720p)", "832x1104"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (9:16, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value="832x480",
label="Maximum Resolution",
)
with gr.Column():
with gr.Group():
gr.Markdown("### 🚀 **Smart Configuration Recommendation**", elem_classes=["auto-config-title"])
enable_auto_config = gr.Checkbox(
label="🎯 **Auto-configure Inference Options**",
value=False,
info="💡 **Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.**",
elem_classes=["auto-config-checkbox"],
)
with gr.Column(scale=9):
seed = gr.Slider(
label="Random Seed",
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=generate_random_seed(),
)
with gr.Column(scale=1):
randomize_btn = gr.Button("🎲 Randomize", variant="secondary")
randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
with gr.Column():
# Set default inference steps based on model class
if model_cls == "wan2.1_distill":
infer_steps = gr.Slider(
label="Inference Steps",
minimum=4,
maximum=4,
step=1,
value=4,
interactive=False,
info="Inference steps fixed at 4 for optimal performance for distill model.",
)
elif model_cls == "wan2.1":
if task == "i2v":
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=40,
info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
)
elif task == "t2v":
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=50,
info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
)
# Set default CFG based on model class
default_enable_cfg = False if model_cls == "wan2.1_distill" else True
enable_cfg = gr.Checkbox(
label="Enable Classifier-Free Guidance",
value=default_enable_cfg,
info="Enable classifier-free guidance to control prompt strength",
)
cfg_scale = gr.Slider(
label="CFG Scale Factor",
minimum=1,
maximum=10,
step=1,
value=5,
info="Controls the influence strength of the prompt. Higher values give more influence to the prompt.",
)
sample_shift = gr.Slider(
label="Distribution Shift",
value=5,
minimum=0,
maximum=10,
step=1,
info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.",
)
fps = gr.Slider( /* Generate button styles */
label="Frames Per Second (FPS)", .generate-btn {
minimum=8, width: 100%;
maximum=30, margin-top: 20px;
step=1, padding: 15px 30px !important;
value=16, font-size: 18px !important;
info="Frames per second of the video. Higher FPS results in smoother videos.", font-weight: bold !important;
) background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
num_frames = gr.Slider( border: none !important;
label="Total Frames", border-radius: 10px !important;
minimum=16, box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
maximum=120, transition: all 0.3s ease !important;
step=1, }
value=81, .generate-btn:hover {
info="Total number of frames in the video. More frames result in longer videos.", transform: translateY(-2px);
) box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
save_result_path = gr.Textbox( /* Accordion header styles */
label="Output Video Path", .model-config .gr-accordion-header,
value=generate_unique_filename(output_dir), .input-params .gr-accordion-header,
info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.", .output-video .gr-accordion-header {
) font-size: 20px !important;
with gr.Column(scale=6): font-weight: bold !important;
gr.Markdown("## 📤 Generated Video") padding: 15px !important;
output_video = gr.Video( }
label="Result",
height=624,
width=360,
autoplay=True,
elem_classes=["output-video"],
)
infer_btn = gr.Button("Generate Video", variant="primary", size="lg") /* Optimize spacing */
.gr-row {
margin-bottom: 15px;
}
with gr.Tab("⚙️ Advanced Options", id=2): /* Video player styles */
with gr.Group(elem_classes="advanced-options"): .output-video video {
gr.Markdown("### GPU Memory Optimization") border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
""",
) as demo:
gr.Markdown(f"# 🎬 LightX2V Video Generator")
# Main layout: left and right columns
with gr.Row():
# Left: configuration and input area
with gr.Column(scale=5):
# Model configuration area
with gr.Accordion("🗂️ Model Configuration", open=True, elem_classes=["model-config"]):
# FP8 support notice
if not is_fp8_supported_gpu():
gr.Markdown("⚠️ **Your device does not support FP8 inference**. Models containing FP8 have been automatically hidden.")
# Hidden state components
model_path_input = gr.Textbox(value=model_path, visible=False)
# Model type + Task type
with gr.Row(): with gr.Row():
rotary_chunk = gr.Checkbox( model_type_input = gr.Radio(
label="Chunked Rotary Position Embedding", label="Model Type",
value=False, choices=["wan2.1", "wan2.2"],
info="When enabled, processes rotary position embeddings in chunks to save GPU memory.", value="wan2.1",
info="wan2.2 requires separate high noise and low noise models",
)
task_type_input = gr.Radio(
label="Task Type",
choices=["i2v", "t2v"],
value="i2v",
info="i2v: Image-to-video, t2v: Text-to-video",
) )
rotary_chunk_size = gr.Slider( # wan2.1: Diffusion model (single row)
label="Rotary Embedding Chunk Size", with gr.Row() as wan21_row:
value=100, dit_path_input = gr.Dropdown(
minimum=100, label="🎨 Diffusion Model",
maximum=10000, choices=get_dit_choices(model_path, "wan2.1"),
step=100, value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
info="Controls the chunk size for applying rotary embeddings. Larger values may improve performance but increase memory usage. Only effective if 'rotary_chunk' is checked.", allow_custom_value=True,
visible=True,
) )
unload_modules = gr.Checkbox( # wan2.2 specific: high noise model + low noise model (hidden by default)
label="Unload Modules", with gr.Row(visible=False) as wan22_row:
value=False, high_noise_path_input = gr.Dropdown(
info="Unload modules (T5, CLIP, DIT, etc.) after inference to reduce GPU/CPU memory usage", label="🔊 High Noise Model",
choices=get_high_noise_choices(model_path),
value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
allow_custom_value=True,
) )
clean_cuda_cache = gr.Checkbox( low_noise_path_input = gr.Dropdown(
label="Clean CUDA Memory Cache", label="🔇 Low Noise Model",
value=False, choices=get_low_noise_choices(model_path),
info="When enabled, frees up GPU memory promptly but slows down inference.", value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
allow_custom_value=True,
) )
gr.Markdown("### Asynchronous Offloading") # Text encoder (single row)
with gr.Row(): with gr.Row():
cpu_offload = gr.Checkbox( t5_path_input = gr.Dropdown(
label="CPU Offloading", label="📝 Text Encoder",
value=False, choices=get_t5_choices(model_path),
info="Offload parts of the model computation from GPU to CPU to reduce GPU memory usage", value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
) allow_custom_value=True,
lazy_load = gr.Checkbox(
label="Enable Lazy Loading",
value=False,
info="Lazy load model components during inference. Requires CPU loading and DIT quantization.",
) )
offload_granularity = gr.Dropdown( # Image encoder + VAE decoder
label="Dit Offload Granularity", with gr.Row():
choices=["block", "phase"], clip_path_input = gr.Dropdown(
value="phase", label="🖼️ Image Encoder",
info="Sets Dit model offloading granularity: blocks or computational phases", choices=get_clip_choices(model_path),
) value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
offload_ratio = gr.Slider( allow_custom_value=True,
label="Offload ratio for Dit model",
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
info="Controls how much of the Dit model is offloaded to the CPU",
) )
t5_cpu_offload = gr.Checkbox( vae_path_input = gr.Dropdown(
label="T5 CPU Offloading", label="🎞️ VAE Decoder",
value=False, choices=get_vae_choices(model_path),
info="Offload the T5 Encoder model to CPU to reduce GPU memory usage", value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
allow_custom_value=True,
) )
t5_offload_granularity = gr.Dropdown( # Attention operator and quantization matrix multiplication operator
label="T5 Encoder Offload Granularity",
choices=["model", "block"],
value="model",
info="Controls the granularity when offloading the T5 Encoder model to CPU",
)
gr.Markdown("### Low-Precision Quantization")
with gr.Row(): with gr.Row():
torch_compile = gr.Checkbox(
label="Torch Compile",
value=False,
info="Use torch.compile to accelerate the inference process",
)
attention_type = gr.Dropdown( attention_type = gr.Dropdown(
label="Attention Operator", label="Attention Operator",
choices=[op[1] for op in attn_op_choices], choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1], value=attn_op_choices[0][1] if attn_op_choices else "",
info="Use appropriate attention operators to accelerate inference", info="Use appropriate attention operators to accelerate inference",
) )
quant_op = gr.Dropdown( quant_op = gr.Dropdown(
label="Quantization Matmul Operator", label="Quantization Matmul Operator",
choices=[op[1] for op in quant_op_choices], choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1], value=quant_op_choices[0][1],
info="Select the quantization matrix multiplication operator to accelerate inference", info="Select quantization matrix multiplication operator to accelerate inference",
interactive=True, interactive=True,
) )
# Get dynamic quantization options
quant_options = get_quantization_options(model_path) # Determine if model is distill version
def is_distill_model(model_type, dit_path, high_noise_path):
dit_quant_scheme = gr.Dropdown( """Determine if model is distill version based on model type and path"""
label="Dit", if model_type == "wan2.1":
choices=quant_options["dit_choices"], check_name = dit_path.lower() if dit_path else ""
value=quant_options["dit_default"], else:
info="Quantization precision for the Dit model", check_name = high_noise_path.lower() if high_noise_path else ""
) return "4step" in check_name
t5_quant_scheme = gr.Dropdown(
label="T5 Encoder", # Model type change event
choices=quant_options["t5_choices"], def on_model_type_change(model_type, model_path_val):
value=quant_options["t5_default"], if model_type == "wan2.2":
info="Quantization precision for the T5 Encoder model", return gr.update(visible=False), gr.update(visible=True), gr.update()
) else:
clip_quant_scheme = gr.Dropdown( # Update wan2.1 Diffusion model options
label="Clip Encoder", dit_choices = get_dit_choices(model_path_val, "wan2.1")
choices=quant_options["clip_choices"], return (
value=quant_options["clip_default"], gr.update(visible=True),
info="Quantization precision for the Clip Encoder", gr.update(visible=False),
) gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
precision_mode = gr.Dropdown( )
label="Precision Mode for Sensitive Layers",
choices=["fp32", "bf16"], model_type_input.change(
value="fp32", fn=on_model_type_change,
info="Select the numerical precision for critical model components like normalization and embedding layers. FP32 offers higher accuracy, while BF16 improves performance on compatible hardware.", inputs=[model_type_input, model_path_input],
outputs=[wan21_row, wan22_row, dit_path_input],
)
# Input parameters area
with gr.Accordion("📥 Input Parameters", open=True, elem_classes=["input-params"]):
# Image input (shown for i2v)
with gr.Row(visible=True) as image_input_row:
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
) )
gr.Markdown("### Variational Autoencoder (VAE)") # Task type change event
def on_task_type_change(task_type):
return gr.update(visible=(task_type == "i2v"))
task_type_input.change(
fn=on_task_type_change,
inputs=[task_type_input],
outputs=[image_input_row],
)
with gr.Row(): with gr.Row():
use_tae = gr.Checkbox( with gr.Column():
label="Use Tiny VAE", prompt = gr.Textbox(
value=False, label="Prompt",
info="Use a lightweight VAE model to accelerate the decoding process", lines=3,
placeholder="Describe the video content...",
max_lines=5,
)
with gr.Column():
negative_prompt = gr.Textbox(
label="Negative Prompt",
lines=3,
placeholder="What you don't want to appear in the video...",
max_lines=5,
value="Camera shake, bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
)
with gr.Column():
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1280x544 (21:9, 720p)", "1280x544"),
("544x1280 (9:21, 720p)", "544x1280"),
("1104x832 (4:3, 720p)", "1104x832"),
("832x1104 (3:4, 720p)", "832x1104"),
("960x960 (1:1, 720p)", "960x960"),
# 480p
("960x544 (16:9, 540p)", "960x544"),
("544x960 (9:16, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
("832x624 (4:3, 480p)", "832x624"),
("624x832 (3:4, 480p)", "624x832"),
("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"),
],
value="832x480",
label="Maximum Resolution",
)
with gr.Column(scale=9):
seed = gr.Slider(
label="Random Seed",
minimum=0,
maximum=MAX_NUMPY_SEED,
step=1,
value=generate_random_seed(),
)
with gr.Column():
default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
if default_is_distill:
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=4,
info="Distill model inference steps default to 4.",
)
else:
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=40,
info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
)
# Dynamically update inference steps when model path changes
def update_infer_steps(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
return gr.update(minimum=1, maximum=100, value=4, interactive=True)
else:
return gr.update(minimum=1, maximum=100, value=40, interactive=True)
# Listen to model path changes
dit_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
# Set default CFG based on model class
# CFG scale factor: default to 1 for distill, otherwise 5
default_cfg_scale = 1 if default_is_distill else 5
# enable_cfg is not exposed to frontend, automatically set based on cfg_scale
# If cfg_scale == 1, then enable_cfg = False, otherwise enable_cfg = True
default_enable_cfg = False if default_cfg_scale == 1 else True
enable_cfg = gr.Checkbox(
label="Enable Classifier-Free Guidance",
value=default_enable_cfg,
visible=False, # Hidden, not exposed to frontend
)
with gr.Row():
sample_shift = gr.Slider(
label="Distribution Shift",
value=5,
minimum=0,
maximum=10,
step=1,
info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.",
) )
use_tiling_vae = gr.Checkbox( cfg_scale = gr.Slider(
label="VAE Tiling Inference", label="CFG Scale Factor",
value=False, minimum=1,
info="Use VAE tiling inference to reduce GPU memory usage", maximum=10,
step=1,
value=default_cfg_scale,
info="Controls the influence strength of the prompt. Higher values give more influence to the prompt. When value is 1, CFG is automatically disabled.",
) )
gr.Markdown("### Feature Caching") # Update enable_cfg based on cfg_scale
def update_enable_cfg(cfg_scale_val):
"""Automatically set enable_cfg based on cfg_scale value"""
if cfg_scale_val == 1:
return gr.update(value=False)
else:
return gr.update(value=True)
# Dynamically update CFG scale factor and enable_cfg when model path changes
def update_cfg_scale(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
new_cfg_scale = 1
else:
new_cfg_scale = 5
new_enable_cfg = False if new_cfg_scale == 1 else True
return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
dit_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
cfg_scale.change(
fn=update_enable_cfg,
inputs=[cfg_scale],
outputs=[enable_cfg],
)
with gr.Row(): with gr.Row():
enable_teacache = gr.Checkbox( fps = gr.Slider(
label="Tea Cache", label="Frames Per Second (FPS)",
value=False, minimum=8,
info="Cache features during inference to reduce the number of inference steps", maximum=30,
) step=1,
teacache_thresh = gr.Slider( value=16,
label="Tea Cache Threshold", info="Frames per second of the video. Higher FPS results in smoother videos.",
value=0.26,
minimum=0,
maximum=1,
info="Higher acceleration may result in lower quality —— Setting to 0.1 provides ~2.0x acceleration, setting to 0.2 provides ~3.0x acceleration",
) )
use_ret_steps = gr.Checkbox( num_frames = gr.Slider(
label="Cache Only Key Steps", label="Total Frames",
value=False, minimum=16,
info="When checked, cache is written only at key steps where the scheduler returns results; when unchecked, cache is written at all steps to ensure the highest quality", maximum=120,
step=1,
value=81,
info="Total number of frames in the video. More frames result in longer videos.",
) )
enable_auto_config.change( save_result_path = gr.Textbox(
fn=auto_configure, label="Output Video Path",
inputs=[enable_auto_config, resolution], value=generate_unique_filename(output_dir),
outputs=[ info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.",
torch_compile, visible=False, # Hide output path, auto-generated
lazy_load, )
rotary_chunk,
rotary_chunk_size, with gr.Column(scale=4):
clean_cuda_cache, with gr.Accordion("📤 Generated Video", open=True, elem_classes=["output-video"]):
cpu_offload, output_video = gr.Video(
offload_granularity, label="",
offload_ratio, height=600,
t5_cpu_offload, autoplay=True,
unload_modules, show_label=False,
t5_offload_granularity, )
attention_type,
quant_op, infer_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg", elem_classes=["generate-btn"])
dit_quant_scheme,
t5_quant_scheme, rope_chunk = gr.Checkbox(label="Chunked Rotary Position Embedding", value=False, visible=False)
clip_quant_scheme, rope_chunk_size = gr.Slider(label="Rotary Embedding Chunk Size", value=100, minimum=100, maximum=10000, step=100, visible=False)
precision_mode, unload_modules = gr.Checkbox(label="Unload Modules", value=False, visible=False)
use_tae, clean_cuda_cache = gr.Checkbox(label="Clean CUDA Memory Cache", value=False, visible=False)
use_tiling_vae, cpu_offload = gr.Checkbox(label="CPU Offloading", value=False, visible=False)
enable_teacache, lazy_load = gr.Checkbox(label="Enable Lazy Loading", value=False, visible=False)
teacache_thresh, offload_granularity = gr.Dropdown(label="Dit Offload Granularity", choices=["block", "phase"], value="phase", visible=False)
use_ret_steps, t5_cpu_offload = gr.Checkbox(label="T5 CPU Offloading", value=False, visible=False)
], clip_cpu_offload = gr.Checkbox(label="CLIP CPU Offloading", value=False, visible=False)
) vae_cpu_offload = gr.Checkbox(label="VAE CPU Offloading", value=False, visible=False)
use_tiling_vae = gr.Checkbox(label="VAE Tiling Inference", value=False, visible=False)
lazy_load.change(
fn=handle_lazy_load_change, resolution.change(
inputs=[lazy_load], fn=auto_configure,
outputs=[unload_modules], inputs=[resolution],
) outputs=[
if task == "i2v": lazy_load,
infer_btn.click( rope_chunk,
fn=run_inference, rope_chunk_size,
inputs=[ clean_cuda_cache,
prompt, cpu_offload,
negative_prompt, offload_granularity,
save_result_path, t5_cpu_offload,
torch_compile, clip_cpu_offload,
infer_steps, vae_cpu_offload,
num_frames, unload_modules,
resolution, attention_type,
seed, quant_op,
sample_shift, use_tiling_vae,
enable_teacache, ],
teacache_thresh, )
use_ret_steps,
enable_cfg, demo.load(
cfg_scale, fn=lambda res: auto_configure(res),
dit_quant_scheme, inputs=[resolution],
t5_quant_scheme, outputs=[
clip_quant_scheme, lazy_load,
fps, rope_chunk,
use_tae, rope_chunk_size,
use_tiling_vae, clean_cuda_cache,
lazy_load, cpu_offload,
precision_mode, offload_granularity,
cpu_offload, t5_cpu_offload,
offload_granularity, clip_cpu_offload,
offload_ratio, vae_cpu_offload,
t5_cpu_offload, unload_modules,
unload_modules, attention_type,
t5_offload_granularity, quant_op,
attention_type, use_tiling_vae,
quant_op, ],
rotary_chunk, )
rotary_chunk_size,
clean_cuda_cache, infer_btn.click(
image_path, fn=run_inference,
], inputs=[
outputs=output_video, prompt,
) negative_prompt,
else: save_result_path,
infer_btn.click( infer_steps,
fn=run_inference, num_frames,
inputs=[ resolution,
prompt, seed,
negative_prompt, sample_shift,
save_result_path, enable_cfg,
torch_compile, cfg_scale,
infer_steps, fps,
num_frames, use_tiling_vae,
resolution, lazy_load,
seed, cpu_offload,
sample_shift, offload_granularity,
enable_teacache, t5_cpu_offload,
teacache_thresh, clip_cpu_offload,
use_ret_steps, vae_cpu_offload,
enable_cfg, unload_modules,
cfg_scale, attention_type,
dit_quant_scheme, quant_op,
t5_quant_scheme, rope_chunk,
clip_quant_scheme, rope_chunk_size,
fps, clean_cuda_cache,
use_tae, model_path_input,
use_tiling_vae, model_type_input,
lazy_load, task_type_input,
precision_mode, dit_path_input,
cpu_offload, high_noise_path_input,
offload_granularity, low_noise_path_input,
offload_ratio, t5_path_input,
t5_cpu_offload, clip_path_input,
unload_modules, vae_path_input,
t5_offload_granularity, image_path,
attention_type, ],
quant_op, outputs=output_video,
rotary_chunk, )
rotary_chunk_size,
clean_cuda_cache,
],
outputs=output_video,
)
demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir]) demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir])
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Light Video Generation") parser = argparse.ArgumentParser(description="Lightweight Video Generation")
parser.add_argument("--model_path", type=str, required=True, help="Model folder path") parser.add_argument("--model_path", type=str, required=True, help="Model folder path")
parser.add_argument(
"--model_cls",
type=str,
choices=["wan2.1", "wan2.1_distill"],
default="wan2.1",
help="Model class to use (wan2.1: standard model, wan2.1_distill: distilled model for faster inference)",
)
parser.add_argument("--model_size", type=str, required=True, choices=["14b", "1.3b"], help="Model type to use")
parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="Specify the task type. 'i2v' for image-to-video translation, 't2v' for text-to-video generation.")
parser.add_argument("--server_port", type=int, default=7862, help="Server port") parser.add_argument("--server_port", type=int, default=7862, help="Server port")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server ip") parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="Output video save directory") parser.add_argument("--output_dir", type=str, default="./outputs", help="Output video save directory")
args = parser.parse_args() args = parser.parse_args()
global model_path, model_cls, model_size, output_dir global model_path, model_cls, output_dir
model_path = args.model_path model_path = args.model_path
model_cls = args.model_cls model_cls = "wan2.1"
model_size = args.model_size
task = args.task
output_dir = args.output_dir output_dir = args.output_dir
main() main()
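# Usage sketch (assumptions: the demo script is saved as gradio_demo.py and the checkpoints live
# under ./models; both names are illustrative). With the simplified parser above, only --model_path
# is still required: the model class is inferred from the selected checkpoint names, and the task
# type (i2v/t2v) is chosen in the UI instead of on the command line.
#
#   python gradio_demo.py --model_path ./models --server_port 7862 --server_name 0.0.0.0 --output_dir ./outputs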
# 监听模型路径变化
dit_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
# 根据模型类别设置默认CFG
# CFG缩放因子:distill 时默认为 1,否则默认为 5
default_cfg_scale = 1 if default_is_distill else 5
# enable_cfg 不暴露到前端,根据 cfg_scale 自动设置
# 如果 cfg_scale == 1,则 enable_cfg = False,否则 enable_cfg = True
default_enable_cfg = False if default_cfg_scale == 1 else True
enable_cfg = gr.Checkbox(
label="启用无分类器引导",
value=default_enable_cfg,
visible=False, # 隐藏,不暴露到前端
)
with gr.Row(): with gr.Row():
use_tae = gr.Checkbox( sample_shift = gr.Slider(
label="使用轻量级VAE", label="分布偏移",
value=False, value=5,
info="使用轻量级VAE模型加速解码过程", minimum=0,
maximum=10,
step=1,
info="控制样本分布偏移的程度。值越大表示偏移越明显。",
) )
use_tiling_vae = gr.Checkbox( cfg_scale = gr.Slider(
label="VAE分块推理", label="CFG缩放因子",
value=False, minimum=1,
info="使用VAE分块推理以减少GPU内存使用", maximum=10,
step=1,
value=default_cfg_scale,
info="控制提示词的影响强度。值越高,提示词的影响越大。当值为1时,自动禁用CFG。",
) )
gr.Markdown("### 特征缓存") # 根据 cfg_scale 更新 enable_cfg
def update_enable_cfg(cfg_scale_val):
"""根据 cfg_scale 的值自动设置 enable_cfg"""
if cfg_scale_val == 1:
return gr.update(value=False)
else:
return gr.update(value=True)
# 当模型路径改变时,动态更新 CFG 缩放因子和 enable_cfg
def update_cfg_scale(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
new_cfg_scale = 1
else:
new_cfg_scale = 5
new_enable_cfg = False if new_cfg_scale == 1 else True
return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
dit_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
cfg_scale.change(
fn=update_enable_cfg,
inputs=[cfg_scale],
outputs=[enable_cfg],
)
with gr.Row(): with gr.Row():
enable_teacache = gr.Checkbox( fps = gr.Slider(
label="Tea Cache", label="每秒帧数(FPS)",
value=False, minimum=8,
info="在推理过程中缓存特征以减少推理步数", maximum=30,
) step=1,
teacache_thresh = gr.Slider( value=16,
label="Tea Cache阈值", info="视频的每秒帧数。较高的FPS会产生更流畅的视频。",
value=0.26,
minimum=0,
maximum=1,
info="较高的加速可能导致质量下降 —— 设置为0.1提供约2.0倍加速,设置为0.2提供约3.0倍加速",
) )
use_ret_steps = gr.Checkbox( num_frames = gr.Slider(
label="仅缓存关键步骤", label="总帧数",
value=False, minimum=16,
info="勾选时,仅在调度器返回结果的关键步骤写入缓存;未勾选时,在所有步骤写入缓存以确保最高质量", maximum=120,
step=1,
value=81,
info="视频中的总帧数。更多帧数会产生更长的视频。",
) )
enable_auto_config.change( save_result_path = gr.Textbox(
fn=auto_configure, label="输出视频路径",
inputs=[enable_auto_config, resolution], value=generate_unique_filename(output_dir),
outputs=[ info="必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
torch_compile, visible=False, # 隐藏输出路径,自动生成
lazy_load, )
rotary_chunk,
rotary_chunk_size, with gr.Column(scale=4):
clean_cuda_cache, with gr.Accordion("📤 生成的视频", open=True, elem_classes=["output-video"]):
cpu_offload, output_video = gr.Video(
offload_granularity, label="",
offload_ratio, height=600,
t5_cpu_offload, autoplay=True,
unload_modules, show_label=False,
t5_offload_granularity, )
attention_type,
quant_op, infer_btn = gr.Button("🎬 生成视频", variant="primary", size="lg", elem_classes=["generate-btn"])
dit_quant_scheme,
t5_quant_scheme, rope_chunk = gr.Checkbox(label="分块旋转位置编码", value=False, visible=False)
clip_quant_scheme, rope_chunk_size = gr.Slider(label="旋转编码块大小", value=100, minimum=100, maximum=10000, step=100, visible=False)
precision_mode, unload_modules = gr.Checkbox(label="卸载模块", value=False, visible=False)
use_tae, clean_cuda_cache = gr.Checkbox(label="清理CUDA内存缓存", value=False, visible=False)
use_tiling_vae, cpu_offload = gr.Checkbox(label="CPU卸载", value=False, visible=False)
enable_teacache, lazy_load = gr.Checkbox(label="启用延迟加载", value=False, visible=False)
teacache_thresh, offload_granularity = gr.Dropdown(label="Dit卸载粒度", choices=["block", "phase"], value="phase", visible=False)
use_ret_steps, t5_cpu_offload = gr.Checkbox(label="T5 CPU卸载", value=False, visible=False)
], clip_cpu_offload = gr.Checkbox(label="CLIP CPU卸载", value=False, visible=False)
) vae_cpu_offload = gr.Checkbox(label="VAE CPU卸载", value=False, visible=False)
use_tiling_vae = gr.Checkbox(label="VAE分块推理", value=False, visible=False)
lazy_load.change(
fn=handle_lazy_load_change, resolution.change(
inputs=[lazy_load], fn=auto_configure,
outputs=[unload_modules], inputs=[resolution],
) outputs=[
if task == "i2v": lazy_load,
infer_btn.click( rope_chunk,
fn=run_inference, rope_chunk_size,
inputs=[ clean_cuda_cache,
prompt, cpu_offload,
negative_prompt, offload_granularity,
save_result_path, t5_cpu_offload,
torch_compile, clip_cpu_offload,
infer_steps, vae_cpu_offload,
num_frames, unload_modules,
resolution, attention_type,
seed, quant_op,
sample_shift, use_tiling_vae,
enable_teacache, ],
teacache_thresh, )
use_ret_steps,
enable_cfg, demo.load(
cfg_scale, fn=lambda res: auto_configure(res),
dit_quant_scheme, inputs=[resolution],
t5_quant_scheme, outputs=[
clip_quant_scheme, lazy_load,
fps, rope_chunk,
use_tae, rope_chunk_size,
use_tiling_vae, clean_cuda_cache,
lazy_load, cpu_offload,
precision_mode, offload_granularity,
cpu_offload, t5_cpu_offload,
offload_granularity, clip_cpu_offload,
offload_ratio, vae_cpu_offload,
t5_cpu_offload, unload_modules,
unload_modules, attention_type,
t5_offload_granularity, quant_op,
attention_type, use_tiling_vae,
quant_op, ],
rotary_chunk, )
rotary_chunk_size,
clean_cuda_cache, infer_btn.click(
image_path, fn=run_inference,
], inputs=[
outputs=output_video, prompt,
) negative_prompt,
else: save_result_path,
infer_btn.click( infer_steps,
fn=run_inference, num_frames,
inputs=[ resolution,
prompt, seed,
negative_prompt, sample_shift,
save_result_path, enable_cfg,
torch_compile, cfg_scale,
infer_steps, fps,
num_frames, use_tiling_vae,
resolution, lazy_load,
seed, cpu_offload,
sample_shift, offload_granularity,
enable_teacache, t5_cpu_offload,
teacache_thresh, clip_cpu_offload,
use_ret_steps, vae_cpu_offload,
enable_cfg, unload_modules,
cfg_scale, attention_type,
dit_quant_scheme, quant_op,
t5_quant_scheme, rope_chunk,
clip_quant_scheme, rope_chunk_size,
fps, clean_cuda_cache,
use_tae, model_path_input,
use_tiling_vae, model_type_input,
lazy_load, task_type_input,
precision_mode, dit_path_input,
cpu_offload, high_noise_path_input,
offload_granularity, low_noise_path_input,
offload_ratio, t5_path_input,
t5_cpu_offload, clip_path_input,
unload_modules, vae_path_input,
t5_offload_granularity, image_path,
attention_type, ],
quant_op, outputs=output_video,
rotary_chunk, )
rotary_chunk_size,
clean_cuda_cache,
],
outputs=output_video,
)
demo.launch(share=True, server_port=args.server_port, server_name=args.server_name, inbrowser=True, allowed_paths=[output_dir])
...@@ -1316,25 +1410,14 @@ def main(): ...@@ -1316,25 +1410,14 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="轻量级视频生成") parser = argparse.ArgumentParser(description="轻量级视频生成")
parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径") parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径")
parser.add_argument(
"--model_cls",
type=str,
choices=["wan2.1", "wan2.1_distill"],
default="wan2.1",
help="要使用的模型类别 (wan2.1: 标准模型, wan2.1_distill: 蒸馏模型,推理更快)",
)
parser.add_argument("--model_size", type=str, required=True, choices=["14b", "1.3b"], help="模型大小:14b 或 1.3b")
parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="指定任务类型。'i2v'用于图像到视频转换,'t2v'用于文本到视频生成。")
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口") parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP") parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录") parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录")
args = parser.parse_args() args = parser.parse_args()
global model_path, model_cls, model_size, output_dir global model_path, model_cls, output_dir
model_path = args.model_path model_path = args.model_path
model_cls = args.model_cls model_cls = "wan2.1"
model_size = args.model_size
task = args.task
output_dir = args.output_dir output_dir = args.output_dir
main() main()
...@@ -14,27 +14,15 @@ ...@@ -14,27 +14,15 @@
# Lightx2v project root directory path # Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v # Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/data/video_gen/LightX2V lightx2v_path=/path/to/LightX2V
# Model path configuration # Model path configuration
# Image-to-video model path (for i2v tasks)
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v # Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
i2v_model_path=/path/to/Wan2.1-I2V-14B-480P-Lightx2v model_path=/path/to/models
# Text-to-video model path (for t2v tasks)
# Example: /path/to/Wan2.1-T2V-1.3B
t2v_model_path=/path/to/Wan2.1-T2V-1.3B
# Model size configuration
# Default model size (14b, 1.3b)
model_size="14b"
# Model class configuration
# Default model class (wan2.1, wan2.1_distill)
model_cls="wan2.1"
# Server configuration # Server configuration
server_name="0.0.0.0" server_name="0.0.0.0"
server_port=8032 server_port=8033
# Output directory configuration # Output directory configuration
output_dir="./outputs" output_dir="./outputs"
...@@ -50,18 +38,12 @@ export PROFILING_DEBUG_LEVEL=2 ...@@ -50,18 +38,12 @@ export PROFILING_DEBUG_LEVEL=2
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# ==================== Parameter Parsing ==================== # ==================== Parameter Parsing ====================
# Default task type
task="i2v"
# Default interface language # Default interface language
lang="zh" lang="zh"
# 解析命令行参数 # 解析命令行参数
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--task)
task="$2"
shift 2
;;
--lang) --lang)
lang="$2" lang="$2"
shift 2 shift 2
...@@ -75,55 +57,32 @@ while [[ $# -gt 0 ]]; do ...@@ -75,55 +57,32 @@ while [[ $# -gt 0 ]]; do
export CUDA_VISIBLE_DEVICES=$gpu_id export CUDA_VISIBLE_DEVICES=$gpu_id
shift 2 shift 2
;; ;;
--model_size)
model_size="$2"
shift 2
;;
--model_cls)
model_cls="$2"
shift 2
;;
--output_dir) --output_dir)
output_dir="$2" output_dir="$2"
shift 2 shift 2
;; ;;
--model_path)
model_path="$2"
shift 2
;;
--help) --help)
echo "🎬 Lightx2v Gradio Demo Startup Script" echo "🎬 Lightx2v Gradio Demo Startup Script"
echo "==========================================" echo "=========================================="
echo "Usage: $0 [options]" echo "Usage: $0 [options]"
echo "" echo ""
echo "📋 Available options:" echo "📋 Available options:"
echo " --task i2v|t2v Task type (default: i2v)"
echo " i2v: Image-to-video generation"
echo " t2v: Text-to-video generation"
echo " --lang zh|en Interface language (default: zh)" echo " --lang zh|en Interface language (default: zh)"
echo " zh: Chinese interface" echo " zh: Chinese interface"
echo " en: English interface" echo " en: English interface"
echo " --port PORT Server port (default: 8032)" echo " --port PORT Server port (default: 8032)"
echo " --gpu GPU_ID GPU device ID (default: 0)" echo " --gpu GPU_ID GPU device ID (default: 0)"
echo " --model_size MODEL_SIZE" echo " --model_path PATH Model path (default: configured in script)"
echo " Model size (default: 14b)" echo " --output_dir DIR Output video save directory (default: ./outputs)"
echo " 14b: 14 billion parameters model" echo " --help Show this help message"
echo " 1.3b: 1.3 billion parameters model"
echo " --model_cls MODEL_CLASS"
echo " Model class (default: wan2.1)"
echo " wan2.1: Standard model variant"
echo " wan2.1_distill: Distilled model variant for faster inference"
echo " --output_dir OUTPUT_DIR"
echo " Output video save directory (default: ./saved_videos)"
echo " --help Show this help message"
echo ""
echo "🚀 Usage examples:"
echo " $0 # Default startup for image-to-video mode"
echo " $0 --task i2v --lang zh --port 8032 # Start with specified parameters"
echo " $0 --task t2v --lang en --port 7860 # Text-to-video with English interface"
echo " $0 --task i2v --gpu 1 --port 8032 # Use GPU 1"
echo " $0 --task t2v --model_size 1.3b # Use 1.3B model"
echo " $0 --task i2v --model_size 14b # Use 14B model"
echo " $0 --task i2v --model_cls wan2.1_distill # Use distilled model"
echo " $0 --task i2v --output_dir ./custom_output # Use custom output directory"
echo "" echo ""
echo "📝 Notes:" echo "📝 Notes:"
echo " - Task type (i2v/t2v) and model type are selected in the web UI"
echo " - Model class is auto-detected based on selected diffusion model"
echo " - Edit script to configure model paths before first use" echo " - Edit script to configure model paths before first use"
echo " - Ensure required Python dependencies are installed" echo " - Ensure required Python dependencies are installed"
echo " - Recommended to use GPU with 8GB+ VRAM" echo " - Recommended to use GPU with 8GB+ VRAM"
...@@ -139,37 +98,11 @@ while [[ $# -gt 0 ]]; do ...@@ -139,37 +98,11 @@ while [[ $# -gt 0 ]]; do
done done
# ==================== Parameter Validation ==================== # ==================== Parameter Validation ====================
if [[ "$task" != "i2v" && "$task" != "t2v" ]]; then
echo "Error: Task type must be 'i2v' or 't2v'"
exit 1
fi
if [[ "$lang" != "zh" && "$lang" != "en" ]]; then if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
echo "Error: Language must be 'zh' or 'en'" echo "Error: Language must be 'zh' or 'en'"
exit 1 exit 1
fi fi
# Validate model size
if [[ "$model_size" != "14b" && "$model_size" != "1.3b" ]]; then
echo "Error: Model size must be '14b' or '1.3b'"
exit 1
fi
# Validate model class
if [[ "$model_cls" != "wan2.1" && "$model_cls" != "wan2.1_distill" ]]; then
echo "Error: Model class must be 'wan2.1' or 'wan2.1_distill'"
exit 1
fi
# Select model path based on task type
if [[ "$task" == "i2v" ]]; then
model_path=$i2v_model_path
echo "🎬 Starting Image-to-Video mode"
else
model_path=$t2v_model_path
echo "🎬 Starting Text-to-Video mode"
fi
# Check if model path exists # Check if model path exists
if [[ ! -d "$model_path" ]]; then if [[ ! -d "$model_path" ]]; then
echo "❌ Error: Model path does not exist" echo "❌ Error: Model path does not exist"
...@@ -208,13 +141,11 @@ echo "🚀 Lightx2v Gradio Demo Starting..." ...@@ -208,13 +141,11 @@ echo "🚀 Lightx2v Gradio Demo Starting..."
echo "==========================================" echo "=========================================="
echo "📁 Project path: $lightx2v_path" echo "📁 Project path: $lightx2v_path"
echo "🤖 Model path: $model_path" echo "🤖 Model path: $model_path"
echo "🎯 Task type: $task"
echo "🤖 Model size: $model_size"
echo "🤖 Model class: $model_cls"
echo "🌏 Interface language: $lang" echo "🌏 Interface language: $lang"
echo "🖥️ GPU device: $gpu_id" echo "🖥️ GPU device: $gpu_id"
echo "🌐 Server address: $server_name:$server_port" echo "🌐 Server address: $server_name:$server_port"
echo "📁 Output directory: $output_dir" echo "📁 Output directory: $output_dir"
echo "📝 Note: Task type and model class are selected in web UI"
echo "==========================================" echo "=========================================="
# Display system resource information # Display system resource information
...@@ -239,11 +170,8 @@ echo "==========================================" ...@@ -239,11 +170,8 @@ echo "=========================================="
# Start Python demo # Start Python demo
python $demo_file \ python $demo_file \
--model_path "$model_path" \ --model_path "$model_path" \
--model_cls "$model_cls" \
--task "$task" \
--server_name "$server_name" \ --server_name "$server_name" \
--server_port "$server_port" \ --server_port "$server_port" \
--model_size "$model_size" \
--output_dir "$output_dir" --output_dir "$output_dir"
# Display final system resource usage # Display final system resource usage
......
...@@ -16,21 +16,9 @@ REM Example: D:\LightX2V ...@@ -16,21 +16,9 @@ REM Example: D:\LightX2V
set lightx2v_path=/path/to/LightX2V set lightx2v_path=/path/to/LightX2V
REM Model path configuration REM Model path configuration
REM Image-to-video model path (for i2v tasks) REM Model root directory path
REM Example: D:\models\Wan2.1-I2V-14B-480P-Lightx2v REM Example: D:\models\LightX2V
set i2v_model_path=/path/to/Wan2.1-I2V-14B-480P-Lightx2v set model_path=/path/to/LightX2V
REM Text-to-video model path (for t2v tasks)
REM Example: D:\models\Wan2.1-T2V-1.3B
set t2v_model_path=/path/to/Wan2.1-T2V-1.3B
REM Model size configuration
REM Default model size (14b, 1.3b)
set model_size=14b
REM Model class configuration
REM Default model class (wan2.1, wan2.1_distill)
set model_cls=wan2.1
REM Server configuration REM Server configuration
set server_name=127.0.0.1 set server_name=127.0.0.1
...@@ -49,20 +37,12 @@ set PROFILING_DEBUG_LEVEL=2 ...@@ -49,20 +37,12 @@ set PROFILING_DEBUG_LEVEL=2
set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
REM ==================== Parameter Parsing ==================== REM ==================== Parameter Parsing ====================
REM Default task type
set task=i2v
REM Default interface language REM Default interface language
set lang=zh set lang=zh
REM Parse command line arguments REM Parse command line arguments
:parse_args :parse_args
if "%1"=="" goto :end_parse if "%1"=="" goto :end_parse
if "%1"=="--task" (
set task=%2
shift
shift
goto :parse_args
)
if "%1"=="--lang" ( if "%1"=="--lang" (
set lang=%2 set lang=%2
shift shift
...@@ -82,18 +62,6 @@ if "%1"=="--gpu" ( ...@@ -82,18 +62,6 @@ if "%1"=="--gpu" (
shift shift
goto :parse_args goto :parse_args
) )
if "%1"=="--model_size" (
set model_size=%2
shift
shift
goto :parse_args
)
if "%1"=="--model_cls" (
set model_cls=%2
shift
shift
goto :parse_args
)
if "%1"=="--output_dir" ( if "%1"=="--output_dir" (
set output_dir=%2 set output_dir=%2
shift shift
...@@ -106,38 +74,24 @@ if "%1"=="--help" ( ...@@ -106,38 +74,24 @@ if "%1"=="--help" (
echo Usage: %0 [options] echo Usage: %0 [options]
echo. echo.
echo 📋 Available options: echo 📋 Available options:
echo --task i2v^|t2v Task type (default: i2v)
echo i2v: Image-to-video generation
echo t2v: Text-to-video generation
echo --lang zh^|en Interface language (default: zh) echo --lang zh^|en Interface language (default: zh)
echo zh: Chinese interface echo zh: Chinese interface
echo en: English interface echo en: English interface
echo --port PORT Server port (default: 8032) echo --port PORT Server port (default: 8032)
echo --gpu GPU_ID GPU device ID (default: 0) echo --gpu GPU_ID GPU device ID (default: 0)
echo --model_size MODEL_SIZE
echo Model size (default: 14b)
echo 14b: 14B parameter model
echo 1.3b: 1.3B parameter model
echo --model_cls MODEL_CLASS
echo Model class (default: wan2.1)
echo wan2.1: Standard model variant
echo wan2.1_distill: Distilled model variant for faster inference
echo --output_dir OUTPUT_DIR echo --output_dir OUTPUT_DIR
echo Output video save directory (default: ./saved_videos) echo Output video save directory (default: ./outputs)
echo --help Show this help message echo --help Show this help message
echo. echo.
echo 🚀 Usage examples: echo 🚀 Usage examples:
echo %0 # Default startup for image-to-video mode echo %0 # Default startup
echo %0 --task i2v --lang zh --port 8032 # Start with specified parameters echo %0 --lang zh --port 8032 # Start with specified parameters
echo %0 --task t2v --lang en --port 7860 # Text-to-video with English interface echo %0 --lang en --port 7860 # English interface
echo %0 --task i2v --gpu 1 --port 8032 # Use GPU 1 echo %0 --gpu 1 --port 8032 # Use GPU 1
echo %0 --task t2v --model_size 1.3b # Use 1.3B model echo %0 --output_dir ./custom_output # Use custom output directory
echo %0 --task i2v --model_size 14b # Use 14B model
echo %0 --task i2v --model_cls wan2.1_distill # Use distilled model
echo %0 --task i2v --output_dir ./custom_output # Use custom output directory
echo. echo.
echo 📝 Notes: echo 📝 Notes:
echo - Edit script to configure model paths before first use echo - Edit script to configure model path before first use
echo - Ensure required Python dependencies are installed echo - Ensure required Python dependencies are installed
echo - Recommended to use GPU with 8GB+ VRAM echo - Recommended to use GPU with 8GB+ VRAM
echo - 🚨 Strongly recommend storing models on SSD for better performance echo - 🚨 Strongly recommend storing models on SSD for better performance
...@@ -152,13 +106,6 @@ exit /b 1 ...@@ -152,13 +106,6 @@ exit /b 1
:end_parse :end_parse
REM ==================== Parameter Validation ==================== REM ==================== Parameter Validation ====================
if "%task%"=="i2v" goto :valid_task
if "%task%"=="t2v" goto :valid_task
echo Error: Task type must be 'i2v' or 't2v'
pause
exit /b 1
:valid_task
if "%lang%"=="zh" goto :valid_lang if "%lang%"=="zh" goto :valid_lang
if "%lang%"=="en" goto :valid_lang if "%lang%"=="en" goto :valid_lang
echo Error: Language must be 'zh' or 'en' echo Error: Language must be 'zh' or 'en'
...@@ -166,29 +113,6 @@ pause ...@@ -166,29 +113,6 @@ pause
exit /b 1 exit /b 1
:valid_lang :valid_lang
if "%model_size%"=="14b" goto :valid_size
if "%model_size%"=="1.3b" goto :valid_size
echo Error: Model size must be '14b' or '1.3b'
pause
exit /b 1
:valid_size
if "%model_cls%"=="wan2.1" goto :valid_cls
if "%model_cls%"=="wan2.1_distill" goto :valid_cls
echo Error: Model class must be 'wan2.1' or 'wan2.1_distill'
pause
exit /b 1
:valid_cls
REM Select model path based on task type
if "%task%"=="i2v" (
set model_path=%i2v_model_path%
echo 🎬 Starting Image-to-Video mode
) else (
set model_path=%t2v_model_path%
echo 🎬 Starting Text-to-Video mode
)
REM Check if model path exists REM Check if model path exists
if not exist "%model_path%" ( if not exist "%model_path%" (
...@@ -230,9 +154,6 @@ echo 🚀 LightX2V Gradio Starting... ...@@ -230,9 +154,6 @@ echo 🚀 LightX2V Gradio Starting...
echo ========================================== echo ==========================================
echo 📁 Project path: %lightx2v_path% echo 📁 Project path: %lightx2v_path%
echo 🤖 Model path: %model_path% echo 🤖 Model path: %model_path%
echo 🎯 Task type: %task%
echo 🤖 Model size: %model_size%
echo 🤖 Model class: %model_cls%
echo 🌏 Interface language: %lang% echo 🌏 Interface language: %lang%
echo 🖥️ GPU device: %gpu_id% echo 🖥️ GPU device: %gpu_id%
echo 🌐 Server address: %server_name%:%server_port% echo 🌐 Server address: %server_name%:%server_port%
...@@ -262,11 +183,8 @@ echo ========================================== ...@@ -262,11 +183,8 @@ echo ==========================================
REM Start Python demo REM Start Python demo
python %demo_file% ^ python %demo_file% ^
--model_path "%model_path%" ^ --model_path "%model_path%" ^
--model_cls %model_cls% ^
--task %task% ^
--server_name %server_name% ^ --server_name %server_name% ^
--server_port %server_port% ^ --server_port %server_port% ^
--model_size %model_size% ^
--output_dir "%output_dir%" --output_dir "%output_dir%"
REM Display final system resource usage REM Display final system resource usage
......
...@@ -38,51 +38,52 @@ Follow the [Quick Start Guide](../getting_started/quickstart.md) to install the ...@@ -38,51 +38,52 @@ Follow the [Quick Start Guide](../getting_started/quickstart.md) to install the
-[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) -[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
-[q8-kernel](https://github.com/KONAKONA666/q8_kernels) (only supports ADA architecture GPUs) -[q8-kernel](https://github.com/KONAKONA666/q8_kernels) (only supports ADA architecture GPUs)
Install according to the project homepage tutorials for each operator as needed.
### 🤖 Supported Models ### 📥 Model Download
#### 🎬 Image-to-Video Models Refer to the [Model Structure Documentation](../getting_started/model_structure.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions.
| Model Name | Resolution | Parameters | Features | Recommended Use | #### wan2.1 Model Directory Structure
|------------|------------|------------|----------|-----------------|
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) | 480p | 14B | Standard version | Balance speed and quality |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) | 720p | 14B | HD version | Pursue high-quality output |
| ✅ [Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v) | 480p | 14B | Distilled optimized version | Faster inference speed |
| ✅ [Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v) | 720p | 14B | HD distilled version | High quality + fast inference |
#### 📝 Text-to-Video Models ```
models/
| Model Name | Parameters | Features | Recommended Use | ├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision
|------------|------------|----------|-----------------| ├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization
| ✅ [Wan2.1-T2V-1.3B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-1.3B-Lightx2v) | 1.3B | Lightweight | Fast prototyping and testing | ├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 quantization
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-Lightx2v) | 14B | Standard version | Balance speed and quality | ├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 quantization block storage directory
| ✅ [Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v) | 14B | Distilled optimized version | High quality + fast inference | ├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 quantization block storage directory
├── Other weights (e.g., t2v)
**💡 Model Selection Recommendations**: ├── t5/clip/xlm-roberta-large/google # text and image encoder
- **First-time use**: Recommend choosing distilled versions (`wan2.1_distill`) ├── vae/lightvae/lighttae # vae
- **Pursuing quality**: Choose 720p resolution or 14B parameter models └── config.json # Model configuration file
- **Pursuing speed**: Choose 480p resolution or 1.3B parameter models, prioritize distilled versions ```
- **Resource-constrained**: Prioritize distilled versions and lower resolutions
- **Real-time applications**: Strongly recommend using distilled models (`wan2.1_distill`)
**🎯 Model Category Description**:
- **`wan2.1`**: Standard model, provides the best video generation quality, suitable for scenarios with extremely high quality requirements
- **`wan2.1_distill`**: Distilled model, optimized through knowledge distillation technology, significantly improves inference speed, maintains good quality while greatly reducing computation time, suitable for most application scenarios
**📥 Model Download**:
Refer to the [Model Structure Documentation](./model_structure.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions.
**Download Options**:
- **Complete Model**: When downloading complete models with both quantized and non-quantized versions, you can freely choose the quantization precision for DIT/T5/CLIP in the advanced options of the `Gradio` Web frontend. #### wan2.2 Model Directory Structure
- **Non-quantized Version Only**: When downloading only non-quantized versions, in the `Gradio` Web frontend, the quantization precision for `DIT/T5/CLIP` can only be set to bf16/fp16. If you need to use quantized versions of models, please manually download quantized weights to the `i2v_model_path` or `t2v_model_path` directory where Gradio is started. ```
models/
├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise original precision
├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 quantization
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 quantization
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 quantization block storage directory
├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise original precision
├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 quantization
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 quantization
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 quantization block storage directory
├── t5/clip/xlm-roberta-large/google # text and image encoder
├── vae/lightvae/lighttae # vae
└── config.json # Model configuration file
```
- **Quantized Version Only**: When downloading only quantized versions, in the `Gradio` Web frontend, the quantization precision for `DIT/T5/CLIP` can only be set to fp8 or int8 (depending on the weights you downloaded). If you need to use non-quantized versions of models, please manually download non-quantized weights to the `i2v_model_path` or `t2v_model_path` directory where Gradio is started. **📝 Download Instructions**:
- **Note**: Whether you download complete models or partial models, the values for `i2v_model_path` and `t2v_model_path` parameters should be the first-level directory paths. For example: `Wan2.1-I2V-14B-480P-Lightx2v/`, not `Wan2.1-I2V-14B-480P-Lightx2v/int8`. - Model weights can be downloaded from HuggingFace:
- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
- Text and Image Encoders can be downloaded from [Encoders](https://huggingface.co/lightx2v/Encoderss)
- VAE can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
- The `xxx_split` directories (e.g., `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`) store the weights as multiple per-block safetensors files; they are intended for devices with limited memory (roughly 16GB or less), so download them only if your hardware calls for it.
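
With this layout, the split directories can be told apart from the single-file checkpoints by a plain directory scan. Below is a minimal sketch assuming the trees shown above; the helper name and patterns are illustrative, not part of LightX2V's API:

```python
import glob
import os

def list_dit_weights(model_root):
    """Hypothetical helper: separate single-file weights from block-split weight directories."""
    single_files, split_dirs = [], []
    for name in sorted(os.listdir(model_root)):
        path = os.path.join(model_root, name)
        if name.endswith(".safetensors") and os.path.isfile(path):
            single_files.append(name)   # e.g. wan2.1_i2v_720p_lightx2v_4step.safetensors
        elif os.path.isdir(path) and glob.glob(os.path.join(path, "*.safetensors")):
            split_dirs.append(name)     # e.g. wan2.1_i2v_720p_int8_lightx2v_4step_split
    return single_files, split_dirs

# Usage (path is a placeholder): files, splits = list_dit_weights("/path/to/models")
```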
### Startup Methods ### Startup Methods
...@@ -96,8 +97,7 @@ vim run_gradio.sh ...@@ -96,8 +97,7 @@ vim run_gradio.sh
# Configuration items that need to be modified: # Configuration items that need to be modified:
# - lightx2v_path: Lightx2v project root directory path # - lightx2v_path: Lightx2v project root directory path
# - i2v_model_path: Image-to-video model path # - model_path: Model root directory path (contains all model files)
# - t2v_model_path: Text-to-video model path
# 💾 Important note: Recommend pointing model paths to SSD storage locations # 💾 Important note: Recommend pointing model paths to SSD storage locations
# Example: /mnt/ssd/models/ or /data/ssd/models/ # Example: /mnt/ssd/models/ or /data/ssd/models/
...@@ -105,11 +105,9 @@ vim run_gradio.sh ...@@ -105,11 +105,9 @@ vim run_gradio.sh
# 2. Run the startup script # 2. Run the startup script
bash run_gradio.sh bash run_gradio.sh
# 3. Or start with parameters (recommended using distilled models) # 3. Or start with parameters
bash run_gradio.sh --task i2v --lang en --model_cls wan2.1 --model_size 14b --port 8032 bash run_gradio.sh --lang en --port 8032
bash run_gradio.sh --task t2v --lang en --model_cls wan2.1 --model_size 1.3b --port 8032 bash run_gradio.sh --lang zh --port 7862
bash run_gradio.sh --task i2v --lang en --model_cls wan2.1_distill --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang en --model_cls wan2.1_distill --model_size 1.3b --port 8032
``` ```
**Windows Environment:** **Windows Environment:**
...@@ -120,8 +118,7 @@ notepad run_gradio_win.bat ...@@ -120,8 +118,7 @@ notepad run_gradio_win.bat
# Configuration items that need to be modified: # Configuration items that need to be modified:
# - lightx2v_path: Lightx2v project root directory path # - lightx2v_path: Lightx2v project root directory path
# - i2v_model_path: Image-to-video model path # - model_path: Model root directory path (contains all model files)
# - t2v_model_path: Text-to-video model path
# 💾 Important note: Recommend pointing model paths to SSD storage locations # 💾 Important note: Recommend pointing model paths to SSD storage locations
# Example: D:\models\ or E:\models\ # Example: D:\models\ or E:\models\
...@@ -129,201 +126,101 @@ notepad run_gradio_win.bat ...@@ -129,201 +126,101 @@ notepad run_gradio_win.bat
# 2. Run the startup script # 2. Run the startup script
run_gradio_win.bat run_gradio_win.bat
# 3. Or start with parameters (recommended using distilled models) # 3. Or start with parameters
run_gradio_win.bat --task i2v --lang en --model_cls wan2.1 --model_size 14b --port 8032 run_gradio_win.bat --lang en --port 8032
run_gradio_win.bat --task t2v --lang en --model_cls wan2.1 --model_size 1.3b --port 8032 run_gradio_win.bat --lang zh --port 7862
run_gradio_win.bat --task i2v --lang en --model_cls wan2.1_distill --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang en --model_cls wan2.1_distill --model_size 1.3b --port 8032
``` ```
#### Method 2: Direct Command Line Startup #### Method 2: Direct Command Line Startup
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
**Linux Environment:** **Linux Environment:**
**Image-to-Video Mode:** **English Interface Version:**
```bash ```bash
python gradio_demo.py \ python gradio_demo.py \
--model_path /path/to/Wan2.1-I2V-14B-480P-Lightx2v \ --model_path /path/to/models \
--model_cls wan2.1 \
--model_size 14b \
--task i2v \
--server_name 0.0.0.0 \ --server_name 0.0.0.0 \
--server_port 7862 --server_port 7862
``` ```
**English Interface Version:** **Chinese Interface Version:**
```bash ```bash
python gradio_demo.py \ python gradio_demo_zh.py \
--model_path /path/to/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v \ --model_path /path/to/models \
--model_cls wan2.1_distill \
--model_size 14b \
--task t2v \
--server_name 0.0.0.0 \ --server_name 0.0.0.0 \
--server_port 7862 --server_port 7862
``` ```
**Windows Environment:** **Windows Environment:**
**Image-to-Video Mode:** **English Interface Version:**
```cmd ```cmd
python gradio_demo.py ^ python gradio_demo.py ^
--model_path D:\models\Wan2.1-I2V-14B-480P-Lightx2v ^ --model_path D:\models ^
--model_cls wan2.1 ^
--model_size 14b ^
--task i2v ^
--server_name 127.0.0.1 ^ --server_name 127.0.0.1 ^
--server_port 7862 --server_port 7862
``` ```
**English Interface Version:** **Chinese Interface Version:**
```cmd ```cmd
python gradio_demo.py ^ python gradio_demo_zh.py ^
--model_path D:\models\Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v ^ --model_path D:\models ^
--model_cls wan2.1_distill ^
--model_size 14b ^
--task t2v ^
--server_name 127.0.0.1 ^ --server_name 127.0.0.1 ^
--server_port 7862 --server_port 7862
``` ```
**💡 Tip**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface.
## 📋 Command Line Parameters ## 📋 Command Line Parameters
| Parameter | Type | Required | Default | Description | | Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------| |-----------|------|----------|---------|-------------|
| `--model_path` | str | ✅ | - | Model folder path | | `--model_path` | str | ✅ | - | Model root directory path (directory containing all model files) |
| `--model_cls` | str | ❌ | wan2.1 | Model class: `wan2.1` (standard model) or `wan2.1_distill` (distilled model, faster inference) |
| `--model_size` | str | ✅ | - | Model size: `14b (image-to-video or text-to-video)` or `1.3b (text-to-video)` |
| `--task` | str | ✅ | - | Task type: `i2v` (image-to-video) or `t2v` (text-to-video) |
| `--server_port` | int | ❌ | 7862 | Server port | | `--server_port` | int | ❌ | 7862 | Server port |
| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address | | `--server_name` | str | ❌ | 0.0.0.0 | Server IP address |
| `--output_dir` | str | ❌ | ./outputs | Output video save directory |
**💡 Note**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface.
## 🎯 Features ## 🎯 Features
### Basic Settings ### Model Configuration
- **Model Type**: Supports wan2.1 and wan2.2 model architectures
- **Task Type**: Supports Image-to-Video (i2v) and Text-to-Video (t2v) generation modes
- **Model Selection**: Frontend automatically identifies and filters available model files, supports automatic quantization precision detection
- **Encoder Configuration**: Supports selection of T5 text encoder, CLIP image encoder, and VAE decoder
- **Operator Selection**: Supports multiple attention operators and quantization matrix multiplication operators, system automatically sorts by installation status
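
The operator sorting mentioned above can be approximated by probing which packages are importable and keeping a fixed priority order. A minimal sketch; the package names and ordering are assumptions, not the demo's exact list:

```python
import importlib.util

# Illustrative priority order, best first; "torch" stands in for the built-in fallback.
ATTENTION_BACKENDS = ["flash_attn", "sageattention", "torch"]

def available_attention_ops():
    """Keep only backends whose packages can be located, preserving the priority order."""
    found = [name for name in ATTENTION_BACKENDS if importlib.util.find_spec(name) is not None]
    return found or ["torch"]
```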
### Input Parameters
#### Input Parameters
- **Prompt**: Describe the expected video content - **Prompt**: Describe the expected video content
- **Negative Prompt**: Specify elements you don't want to appear - **Negative Prompt**: Specify elements you don't want to appear
- **Input Image**: Upload input image required in i2v mode
- **Resolution**: Supports multiple preset resolutions (480p/540p/720p) - **Resolution**: Supports multiple preset resolutions (480p/540p/720p)
- **Random Seed**: Controls the randomness of generation results - **Random Seed**: Controls the randomness of generation results
- **Inference Steps**: Affects the balance between generation quality and speed - **Inference Steps**: Affects the balance between generation quality and speed (defaults to 4 steps for distilled models)
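
Distilled checkpoints are recognized by the `4step` marker in the selected weight's filename, which is what drives the 4-step default. A condensed sketch of that check (the function names are illustrative):

```python
def is_distill_weight(filename: str) -> bool:
    """Distilled ("4step") checkpoints are identified by their filename."""
    return "4step" in filename.lower()

def default_infer_steps(filename: str) -> int:
    # 4 steps for distilled weights, 40 for standard ones.
    return 4 if is_distill_weight(filename) else 40

assert default_infer_steps("wan2.1_i2v_720p_lightx2v_4step.safetensors") == 4
```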
### Video Parameters
#### Video Parameters
- **FPS**: Frames per second - **FPS**: Frames per second
- **Total Frames**: Video length - **Total Frames**: Video length
- **CFG Scale Factor**: Controls prompt influence strength (1-10) - **CFG Scale Factor**: Controls prompt influence strength (1-10, defaults to 1 for distilled models)
- **Distribution Shift**: Controls generation style deviation degree (0-10) - **Distribution Shift**: Controls generation style deviation degree (0-10)
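
The CFG toggle itself is not exposed in the interface; it is derived from the scale, and a scale of 1 (the distilled default) disables guidance. A minimal sketch of that mapping:

```python
def default_cfg_scale(is_distill: bool) -> int:
    # Distilled models skip classifier-free guidance by default.
    return 1 if is_distill else 5

def cfg_enabled(cfg_scale: float) -> bool:
    # A scale of 1 means "no guidance", so the hidden enable_cfg flag is turned off.
    return cfg_scale != 1
```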
### Advanced Optimization Options ## 🔧 Auto-Configuration Feature
#### GPU Memory Optimization
- **Chunked Rotary Position Embedding**: Saves GPU memory
- **Rotary Embedding Chunk Size**: Controls chunk granularity
- **Clean CUDA Cache**: Promptly frees GPU memory
#### Asynchronous Offloading
- **CPU Offloading**: Transfers partial computation to CPU
- **Lazy Loading**: Loads model components on-demand, significantly reduces system memory consumption
- **Offload Granularity Control**: Fine-grained control of offloading strategies
#### Low-Precision Quantization
- **Attention Operators**: Flash Attention, Sage Attention, etc.
- **Quantization Operators**: vLLM, SGL, Q8F, etc.
- **Precision Modes**: FP8, INT8, BF16, etc.
#### VAE Optimization
- **Lightweight VAE**: Accelerates decoding process
- **VAE Tiling Inference**: Reduces memory usage
#### Feature Caching The system automatically configures optimal inference options based on your hardware configuration (GPU VRAM and CPU memory) without manual adjustment. The best configuration is automatically applied on startup, including:
- **Tea Cache**: Caches intermediate features to accelerate generation
- **Cache Threshold**: Controls cache trigger conditions
- **Key Step Caching**: Writes cache only at key steps
## 🔧 Auto-Configuration Feature - **GPU Memory Optimization**: Automatically enables CPU offloading, VAE tiling inference, etc. based on VRAM size
- **CPU Memory Optimization**: Automatically enables lazy loading, module unloading, etc. based on system memory
- **Operator Selection**: Automatically selects the best installed operators (sorted by priority)
- **Quantization Configuration**: Automatically detects and applies quantization precision based on model file names
After enabling "Auto-configure Inference Options", the system will automatically optimize parameters based on your hardware configuration:
### GPU Memory Rules
- **80GB+**: Default configuration, no optimization needed
- **48GB**: Enable CPU offloading, offload ratio 50%
- **40GB**: Enable CPU offloading, offload ratio 80%
- **32GB**: Enable CPU offloading, offload ratio 100%
- **24GB**: Enable BF16 precision, VAE tiling
- **16GB**: Enable chunked offloading, rotary embedding chunking
- **12GB**: Enable cache cleaning, lightweight VAE
- **8GB**: Enable quantization, lazy loading
### CPU Memory Rules
- **128GB+**: Default configuration
- **64GB**: Enable DIT quantization
- **32GB**: Enable lazy loading
- **16GB**: Enable full model quantization
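
A rough sketch of how threshold rules like these can be turned into settings is shown below; the option keys are illustrative and do not match the demo's configuration names exactly:

```python
def auto_offload_config(vram_gb: float) -> dict:
    """Map available VRAM to offload settings following the thresholds above (illustrative)."""
    if vram_gb >= 80:
        return {"cpu_offload": False, "offload_ratio": 0.0}
    if vram_gb >= 48:
        return {"cpu_offload": True, "offload_ratio": 0.5}
    if vram_gb >= 40:
        return {"cpu_offload": True, "offload_ratio": 0.8}
    return {"cpu_offload": True, "offload_ratio": 1.0}

# vram_gb can be read from torch.cuda.get_device_properties(0).total_memory / 1024**3.
```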
## ⚠️ Important Notes
### 🚀 Low-Resource Device Optimization Recommendations
**💡 For devices with insufficient VRAM or performance constraints**:
- **🎯 Model Selection**: Prioritize using distilled version models (`wan2.1_distill`)
- **⚡ Inference Steps**: Recommend setting to 4 steps
- **🔧 CFG Settings**: Recommend disabling CFG option to improve generation speed
- **🔄 Auto-Configuration**: Enable "Auto-configure Inference Options"
- **💾 Storage Optimization**: Ensure models are stored on SSD for optimal loading performance
## 🎨 Interface Description
### Basic Settings Tab
- **Input Parameters**: Prompts, resolution, and other basic settings
- **Video Parameters**: FPS, frame count, CFG, and other video generation parameters
- **Output Settings**: Video save path configuration
### Advanced Options Tab
- **GPU Memory Optimization**: Memory management related options
- **Asynchronous Offloading**: CPU offloading and lazy loading
- **Low-Precision Quantization**: Various quantization optimization options
- **VAE Optimization**: Variational Autoencoder optimization
- **Feature Caching**: Cache strategy configuration
## 🔍 Troubleshooting
### Common Issues
**💡 Tip**: Generally, after enabling "Auto-configure Inference Options", the system will automatically optimize parameter settings based on your hardware configuration, and performance issues usually won't occur. If you encounter problems, please refer to the following solutions:
1. **Gradio Webpage Opens Blank**
- Try upgrading gradio: `pip install --upgrade gradio`
2. **CUDA Memory Insufficient**
- Enable CPU offloading
- Reduce resolution
- Enable quantization options
3. **System Memory Insufficient**
- Enable CPU offloading
- Enable lazy loading option
- Enable quantization options
4. **Slow Generation Speed**
- Reduce inference steps
- Enable auto-configuration
- Use lightweight models
- Enable Tea Cache
- Use quantization operators
- 💾 **Check if models are stored on SSD**
5. **Slow Model Loading**
- 💾 **Migrate models to SSD storage**
- Enable lazy loading option
- Check disk I/O performance
- Consider using NVMe SSD
6. **Poor Video Quality**
- Increase inference steps
- Increase CFG scale factor
- Use 14B models
- Optimize prompts
### Log Viewing ### Log Viewing
......
...@@ -38,51 +38,53 @@ LightX2V/app/ ...@@ -38,51 +38,53 @@ LightX2V/app/
-[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) -[sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
-[q8-kernel](https://github.com/KONAKONA666/q8_kernels) (仅支持ADA架构的GPU) -[q8-kernel](https://github.com/KONAKONA666/q8_kernels) (仅支持ADA架构的GPU)
可根据需要,按照各算子的项目主页教程进行安装
### 🤖 支持的模型 ### 📥 模型下载
#### 🎬 图像到视频模型 (Image-to-Video) 可参考[模型结构文档](../getting_started/model_structure.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。
| 模型名称 | 分辨率 | 参数量 | 特点 | 推荐场景 | #### wan2.1 模型目录结构
|----------|--------|--------|------|----------|
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) | 480p | 14B | 标准版本 | 平衡速度和质量 |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) | 720p | 14B | 高清版本 | 追求高质量输出 |
| ✅ [Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v) | 480p | 14B | 蒸馏优化版 | 更快的推理速度 |
| ✅ [Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v) | 720p | 14B | 高清蒸馏版 | 高质量+快速推理 |
#### 📝 文本到视频模型 (Text-to-Video) ```
models/
| 模型名称 | 参数量 | 特点 | 推荐场景 | ├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度
|----------|--------|------|----------| ├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化
| ✅ [Wan2.1-T2V-1.3B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-1.3B-Lightx2v) | 1.3B | 轻量级 | 快速原型测试 | ├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 量化
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-Lightx2v) | 14B | 标准版本 | 平衡速度和质量 | ├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 量化分block存储目录
| ✅ [Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v) | 14B | 蒸馏优化版 | 高质量+快速推理 | ├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 量化分block存储目录
├── 其他权重(例如t2v)
**💡 模型选择建议**: ├── t5/clip/xlm-roberta-large/google # text和image encoder
- **首次使用**: 建议选择蒸馏版本 (`wan2.1_distill`) ├── vae/lightvae/lighttae # vae
- **追求质量**: 选择720p分辨率或14B参数模型 └── config.json # 模型配置文件
- **追求速度**: 选择480p分辨率或1.3B参数模型,优先使用蒸馏版本 ```
- **资源受限**: 优先选择蒸馏版本和较低分辨率
- **实时应用**: 强烈推荐使用蒸馏模型 (`wan2.1_distill`)
**🎯 模型类别说明**:
- **`wan2.1`**: 标准模型,提供最佳的视频生成质量,适合对质量要求极高的场景
- **`wan2.1_distill`**: 蒸馏模型,通过知识蒸馏技术优化,推理速度显著提升,在保持良好质量的同时大幅减少计算时间,适合大多数应用场景
**📥 下载模型**:
可参考[模型结构文档](./model_structure.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。
**下载选项说明** #### wan2.2 模型目录结构
- **完整模型**:下载包含量化和非量化版本的完整模型时,在`Gradio` Web前端的高级选项中可以自由选择DIT/T5/CLIP的量化精度。 ```
models/
├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise 原始精度
├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 量化
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 量化
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 量化分block存储目录
├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise 原始精度
├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 量化
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 量化
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 量化分block存储目录
├── t5/clip/xlm-roberta-large/google # text和image encoder
├── vae/lightvae/lighttae # vae
└── config.json # 模型配置文件
```
- **仅非量化版本**:仅下载非量化版本时,在`Gradio` Web前端中,`DIT/T5/CLIP`的量化精度只能选择bf16/fp16。如需使用量化版本的模型,请手动下载量化权重到Gradio启动的`i2v_model_path`或者`t2v_model_path`目录下。 **📝 下载说明**
- **仅量化版本**:仅下载量化版本时,在`Gradio` Web前端中,`DIT/T5/CLIP`的量化精度只能选择fp8或int8(取决于您下载的权重)。如需使用非量化版本的模型,请手动下载非量化权重到Gradio启动的`i2v_model_path`或者`t2v_model_path`目录下。 - 模型权重可从 HuggingFace 下载:
- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
- Text 和 Image Encoder 可从 [Encoders](https://huggingface.co/lightx2v/Encoderss) 下载
- VAE 可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载
- 对于 `xxx_split` 目录(例如 `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`),即按照 block 存储的多个 safetensors,适用于内存不足的设备。例如内存 16GB 以内,请根据自身情况下载
- **注意**:无论是下载了完整模型还是部分模型,`i2v_model_path``t2v_model_path` 参数的值都应该是一级目录的路径。例如:`Wan2.1-I2V-14B-480P-Lightx2v/`,而不是 `Wan2.1-I2V-14B-480P-Lightx2v/int8`
### 启动方式 ### 启动方式
...@@ -96,8 +98,7 @@ vim run_gradio.sh ...@@ -96,8 +98,7 @@ vim run_gradio.sh
# 需要修改的配置项: # 需要修改的配置项:
# - lightx2v_path: Lightx2v项目根目录路径 # - lightx2v_path: Lightx2v项目根目录路径
# - i2v_model_path: 图像到视频模型路径 # - model_path: 模型根目录路径(包含所有模型文件)
# - t2v_model_path: 文本到视频模型路径
# 💾 重要提示:建议将模型路径指向SSD存储位置 # 💾 重要提示:建议将模型路径指向SSD存储位置
# 例如:/mnt/ssd/models/ 或 /data/ssd/models/ # 例如:/mnt/ssd/models/ 或 /data/ssd/models/
...@@ -105,11 +106,9 @@ vim run_gradio.sh ...@@ -105,11 +106,9 @@ vim run_gradio.sh
# 2. 运行启动脚本 # 2. 运行启动脚本
bash run_gradio.sh bash run_gradio.sh
# 3. 或使用参数启动(推荐使用蒸馏模型) # 3. 或使用参数启动
bash run_gradio.sh --task i2v --lang zh --model_cls wan2.1 --model_size 14b --port 8032 bash run_gradio.sh --lang zh --port 8032
bash run_gradio.sh --task t2v --lang zh --model_cls wan2.1 --model_size 1.3b --port 8032 bash run_gradio.sh --lang en --port 7862
bash run_gradio.sh --task i2v --lang zh --model_cls wan2.1_distill --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang zh --model_cls wan2.1_distill --model_size 1.3b --port 8032
``` ```
**Windows 环境:** **Windows 环境:**
...@@ -120,8 +119,7 @@ notepad run_gradio_win.bat ...@@ -120,8 +119,7 @@ notepad run_gradio_win.bat
# 需要修改的配置项: # 需要修改的配置项:
# - lightx2v_path: Lightx2v项目根目录路径 # - lightx2v_path: Lightx2v项目根目录路径
# - i2v_model_path: 图像到视频模型路径 # - model_path: 模型根目录路径(包含所有模型文件)
# - t2v_model_path: 文本到视频模型路径
# 💾 重要提示:建议将模型路径指向SSD存储位置 # 💾 重要提示:建议将模型路径指向SSD存储位置
# 例如:D:\models\ 或 E:\models\ # 例如:D:\models\ 或 E:\models\
...@@ -129,24 +127,23 @@ notepad run_gradio_win.bat ...@@ -129,24 +127,23 @@ notepad run_gradio_win.bat
# 2. 运行启动脚本 # 2. 运行启动脚本
run_gradio_win.bat run_gradio_win.bat
# 3. 或使用参数启动(推荐使用蒸馏模型) # 3. 或使用参数启动
run_gradio_win.bat --task i2v --lang zh --model_cls wan2.1 --model_size 14b --port 8032 run_gradio_win.bat --lang zh --port 8032
run_gradio_win.bat --task t2v --lang zh --model_cls wan2.1 --model_size 1.3b --port 8032 run_gradio_win.bat --lang en --port 7862
run_gradio_win.bat --task i2v --lang zh --model_cls wan2.1_distill --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang zh --model_cls wan2.1_distill --model_size 1.3b --port 8032
``` ```
#### 方式二:直接命令行启动 #### 方式二:直接命令行启动
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
**Linux 环境:** **Linux 环境:**
**图像到视频模式:** **中文界面版本:**
```bash ```bash
python gradio_demo_zh.py \ python gradio_demo_zh.py \
--model_path /path/to/Wan2.1-I2V-14B-480P-Lightx2v \ --model_path /path/to/models \
--model_cls wan2.1 \
--model_size 14b \
--task i2v \
--server_name 0.0.0.0 \ --server_name 0.0.0.0 \
--server_port 7862 --server_port 7862
``` ```
...@@ -154,176 +151,77 @@ python gradio_demo_zh.py \ ...@@ -154,176 +151,77 @@ python gradio_demo_zh.py \
**英文界面版本:** **英文界面版本:**
```bash ```bash
python gradio_demo.py \ python gradio_demo.py \
--model_path /path/to/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v \ --model_path /path/to/models \
--model_cls wan2.1_distill \
--model_size 14b \
--task t2v \
--server_name 0.0.0.0 \ --server_name 0.0.0.0 \
--server_port 7862 --server_port 7862
``` ```
**Windows 环境:** **Windows 环境:**
**图像到视频模式:** **中文界面版本:**
```cmd ```cmd
python gradio_demo_zh.py ^ python gradio_demo_zh.py ^
--model_path D:\models\Wan2.1-I2V-14B-480P-Lightx2v ^ --model_path D:\models ^
--model_cls wan2.1 ^
--model_size 14b ^
--task i2v ^
--server_name 127.0.0.1 ^ --server_name 127.0.0.1 ^
--server_port 7862 --server_port 7862
``` ```
**英文界面版本:** **英文界面版本:**
```cmd ```cmd
python gradio_demo_zh.py ^ python gradio_demo.py ^
--model_path D:\models\Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v ^ --model_path D:\models ^
--model_cls wan2.1_distill ^
--model_size 14b ^
--task i2v ^
--server_name 127.0.0.1 ^ --server_name 127.0.0.1 ^
--server_port 7862 --server_port 7862
``` ```
**💡 提示**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。
## 📋 命令行参数 ## 📋 命令行参数
| 参数 | 类型 | 必需 | 默认值 | 说明 | | 参数 | 类型 | 必需 | 默认值 | 说明 |
|------|------|------|--------|------| |------|------|------|--------|------|
| `--model_path` | str | ✅ | - | 模型文件夹路径 | | `--model_path` | str | ✅ | - | 模型根目录路径(包含所有模型文件的目录) |
| `--model_cls` | str | ❌ | wan2.1 | 模型类别:`wan2.1`(标准模型)或 `wan2.1_distill`(蒸馏模型,推理更快) |
| `--model_size` | str | ✅ | - | 模型大小:`14b``1.3b)` |
| `--task` | str | ✅ | - | 任务类型:`i2v`(图像到视频)或 `t2v`(文本到视频) |
| `--server_port` | int | ❌ | 7862 | 服务器端口 | | `--server_port` | int | ❌ | 7862 | 服务器端口 |
| `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 | | `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 |
| `--output_dir` | str | ❌ | ./outputs | 输出视频保存目录 |
**💡 说明**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。
## 🎯 Features ## 🎯 Features
### Basic Settings ### Model Configuration
- **Model type**: Supports both the wan2.1 and wan2.2 model architectures
- **Task type**: Supports image-to-video (i2v) and text-to-video (t2v) generation modes
- **Model selection**: The frontend automatically detects and filters the available model files, including automatic detection of quantization precision
- **Encoder configuration**: Allows selecting the T5 text encoder, CLIP image encoder, and VAE decoder
- **Operator selection**: Supports multiple attention operators and quantized matmul operators; the system orders them automatically by installation status (see the sketch below)
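A minimal sketch of what ordering operators by installation status can look like; the candidate package names (`flash_attn`, `sageattention`) and the fallback label are assumptions for illustration, not the demo's actual probe list.

```python
# Illustrative sketch: list attention-operator choices with the installed ones
# first. The candidate package names are assumptions, not the real probe list.
import importlib.util


def order_attention_ops(candidates=("flash_attn", "sageattention", "torch_sdpa")):
    installed, missing = [], []
    for name in candidates:
        # "torch_sdpa" stands in for the always-available PyTorch fallback.
        if name == "torch_sdpa" or importlib.util.find_spec(name) is not None:
            installed.append(name)
        else:
            missing.append(name)
    # Installed operators come first so the UI defaults to something usable.
    return installed + missing


print(order_attention_ops())
```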
### Input Parameters
#### Input Parameters
- **Prompt**: Describes the desired video content - **Prompt**: Describes the desired video content
- **Negative Prompt**: Specifies elements that should not appear - **Negative Prompt**: Specifies elements that should not appear
- **Input image**: An input image must be uploaded in i2v mode
- **Resolution**: Supports several preset resolutions (480p/540p/720p) - **Resolution**: Supports several preset resolutions (480p/540p/720p)
- **Random seed**: Controls the randomness of the generated result - **Random seed**: Controls the randomness of the generated result
- **Inference steps**: Balances generation quality and speed - **Inference steps**: Balances generation quality and speed (defaults to 4 steps for distilled models)
### Video Parameters
#### Video Parameters
- **FPS**: Frames per second - **FPS**: Frames per second
- **Total frames**: Video length - **Total frames**: Video length
- **CFG scale**: Controls how strongly the prompt influences the result (1-10) - **CFG scale**: Controls how strongly the prompt influences the result (1-10; defaults to 1 for distilled models)
- **Distribution shift**: Controls how far the generation style deviates (0-10) - **Distribution shift**: Controls how far the generation style deviates (0-10)
### Advanced Optimization Options ## 🔧 Automatic Configuration
#### GPU Memory Optimization
- **Chunked rotary position embedding**: Saves GPU memory
- **Rotary embedding chunk size**: Controls chunking granularity
- **Clear CUDA cache**: Frees GPU memory promptly
#### Asynchronous Offloading
- **CPU offload**: Moves part of the computation to the CPU
- **Lazy loading**: Loads model components on demand, significantly reducing system memory usage
- **Offload granularity control**: Fine-grained control over the offloading strategy
#### Low-Precision Quantization
- **Attention operators**: Flash Attention, Sage Attention, etc.
- **Quantized operators**: vLLM, SGL, Q8F, etc.
- **Precision modes**: FP8, INT8, BF16, etc.
#### VAE Optimization
- **Lightweight VAE**: Speeds up decoding
- **VAE tiled inference**: Reduces memory usage
#### Feature Caching The system automatically configures the optimal inference options based on your hardware (GPU memory and CPU RAM), with no manual tuning needed. The best configuration is applied automatically at startup, including:
- **Tea Cache**: Caches intermediate features to speed up generation
- **Cache threshold**: Controls when caching is triggered
- **Key-step caching**: Writes to the cache only at key steps
## 🔧 Automatic Configuration - **GPU memory optimization**: Automatically enables CPU offload, VAE tiled inference, etc. based on the amount of GPU memory
- **CPU memory optimization**: Automatically enables lazy loading, module offloading, etc. based on system memory
- **Operator selection**: Automatically selects the best installed operators (ordered by priority)
- **Quantization configuration**: Automatically detects and applies the quantization precision from the model file names (see the sketch below)
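A minimal sketch of filename-based precision detection as described in the last bullet above; the keyword-to-precision mapping is an assumption for illustration, not the exact rule used by the demo.

```python
# Illustrative sketch: infer a quantization precision from a model file name.
# The keyword mapping below is an assumption, not the demo's exact rule set.
def detect_quant_precision(filename, default="bf16"):
    name = filename.lower()
    for keyword, precision in (("fp8", "fp8"), ("int8", "int8"), ("bf16", "bf16")):
        if keyword in name:
            return precision
    return default


assert detect_quant_precision("wan2.1_i2v_14b_fp8.safetensors") == "fp8"
assert detect_quant_precision("wan2.1_t2v_1.3b.safetensors") == "bf16"
```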
启用"自动配置推理选项"后,系统会根据您的硬件配置自动优化参数:
### GPU内存规则
- **80GB+**: 默认配置,无需优化
- **48GB**: 启用CPU卸载,卸载比例50%
- **40GB**: 启用CPU卸载,卸载比例80%
- **32GB**: 启用CPU卸载,卸载比例100%
- **24GB**: 启用BF16精度、VAE分块
- **16GB**: 启用分块卸载、旋转编码分块
- **12GB**: 启用清理缓存、轻量级VAE
- **8GB**: 启用量化、延迟加载
### CPU内存规则
- **128GB+**: 默认配置
- **64GB**: 启用DIT量化
- **32GB**: 启用延迟加载
- **16GB**: 启用全模型量化
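These rules amount to a threshold lookup on available GPU and CPU memory. A hedged sketch is shown below; the option keys are illustrative and the thresholds simply mirror the lists above.

```python
# Illustrative sketch of threshold-based auto-configuration. The thresholds
# mirror the rules listed above; the option keys are assumptions.
import psutil
import torch


def auto_config():
    config = {}
    gpu_gb = (torch.cuda.get_device_properties(0).total_memory / 1024**3
              if torch.cuda.is_available() else 0)
    cpu_gb = psutil.virtual_memory().total / 1024**3

    if gpu_gb < 80:
        config["cpu_offload"] = True
        config["offload_ratio"] = 0.5 if gpu_gb >= 48 else 0.8 if gpu_gb >= 40 else 1.0
    if gpu_gb <= 24:
        config["vae_tiling"] = True
    if gpu_gb <= 12:
        config["clean_cuda_cache"] = True
        config["lightweight_vae"] = True
    if gpu_gb <= 8:
        config["quantization"] = True
        config["lazy_load"] = True

    if cpu_gb <= 64:
        config["dit_quantized"] = True
    if cpu_gb <= 32:
        config["lazy_load"] = True
    if cpu_gb <= 16:
        config["full_model_quantized"] = True
    return config


print(auto_config())
```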
## ⚠️ Important Notes
### 🚀 Optimization Tips for Low-Resource Devices
**💡 For devices with limited GPU memory or constrained performance**:
- **🎯 Model choice**: Prefer the distilled models (`wan2.1_distill`)
- **⚡ Inference steps**: A setting of 4 steps is recommended
- **🔧 CFG setting**: Disabling CFG is recommended to speed up generation
- **🔄 Auto configuration**: Enable "Auto-configure inference options"
- **💾 Storage optimization**: Keep the models on an SSD for the best loading performance
## 🎨 Interface Overview
### Basic Settings Tab
- **Input parameters**: Basic settings such as the prompt and resolution
- **Video parameters**: Video generation parameters such as FPS, frame count, and CFG
- **Output settings**: Video save path configuration
### Advanced Options Tab
- **GPU memory optimization**: Memory management options
- **Asynchronous offloading**: CPU offload and lazy loading
- **Low-precision quantization**: Various quantization options
- **VAE optimization**: Variational autoencoder optimizations
- **Feature caching**: Cache strategy configuration
## 🔍 Troubleshooting
### Common Issues
**💡 Tip**: In general, once "Auto-configure inference options" is enabled, the system tunes the parameters for your hardware and performance problems are rare. If you do run into issues, try the following solutions:
1. **The Gradio page opens blank**
   - Try upgrading gradio: `pip install --upgrade gradio`
2. **Out of CUDA memory**
   - Enable CPU offload
   - Lower the resolution
   - Enable quantization options
3. **Out of system memory**
   - Enable CPU offload
   - Enable lazy loading
   - Enable quantization options
4. **Slow generation**
   - Reduce the number of inference steps
   - Enable auto configuration
   - Use a lightweight model
   - Enable Tea Cache
   - Use quantized operators
   - 💾 **Check whether the models are stored on an SSD**
5. **Slow model loading**
   - 💾 **Move the models to SSD storage**
   - Enable lazy loading
   - Check disk I/O performance
   - Consider using an NVMe SSD
6. **Poor video quality**
   - Increase the number of inference steps
   - Raise the CFG scale
   - Use the 14B model
   - Refine the prompt
### Viewing Logs ### Viewing Logs
......
import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import torch import torch
...@@ -115,8 +116,6 @@ class WeightAsyncStreamManager(object): ...@@ -115,8 +116,6 @@ class WeightAsyncStreamManager(object):
self.prefetch_futures.append(future) self.prefetch_futures.append(future)
def swap_cpu_buffers(self): def swap_cpu_buffers(self):
import time
wait_start = time.time() wait_start = time.time()
already_done = all(f.done() for f in self.prefetch_futures) already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures: for f in self.prefetch_futures:
...@@ -125,25 +124,11 @@ class WeightAsyncStreamManager(object): ...@@ -125,25 +124,11 @@ class WeightAsyncStreamManager(object):
logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}") logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]] self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def shutdown(self, wait=True): def __del__(self):
"""Shutdown the thread pool executor and wait for all pending tasks to complete."""
if hasattr(self, "executor") and self.executor is not None: if hasattr(self, "executor") and self.executor is not None:
# Wait for all pending futures to complete before shutting down for f in self.prefetch_futures:
if hasattr(self, "prefetch_futures"): if not f.done():
for f in self.prefetch_futures: f.result()
try: self.executor.shutdown(wait=False)
if not f.done():
f.result()
except Exception:
pass
self.executor.shutdown(wait=wait)
self.executor = None self.executor = None
logger.debug("ThreadPoolExecutor shut down successfully.") logger.debug("ThreadPoolExecutor shut down successfully.")
def __del__(self):
"""Cleanup method to ensure executor is shut down when object is destroyed."""
try:
if hasattr(self, "executor") and self.executor is not None:
self.executor.shutdown(wait=False)
except Exception:
pass
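For orientation, here is a self-contained sketch of the prefetch/double-buffer pattern the diff above modifies: background tasks fill the idle CPU buffer, the swap waits on outstanding futures, and cleanup drains them before releasing the executor. The class below is a simplified stand-in, not the project's actual WeightAsyncStreamManager.

```python
# Simplified illustration of the prefetch/double-buffer pattern shown above.
# Names are stand-ins; the real manager does considerably more work.
from concurrent.futures import ThreadPoolExecutor


class DoubleBufferPrefetcher:
    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.cpu_buffers = [{}, {}]   # [active buffer, prefetch target]
        self.prefetch_futures = []

    def prefetch_block(self, block_idx, load_fn):
        # Fill the idle buffer in the background while the active one is in use.
        future = self.executor.submit(load_fn, self.cpu_buffers[1], block_idx)
        self.prefetch_futures.append(future)

    def swap_cpu_buffers(self):
        # Block until the idle buffer is fully populated, then swap roles.
        for f in self.prefetch_futures:
            f.result()
        self.prefetch_futures = []
        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]

    def __del__(self):
        # Drain pending work before releasing the executor, as in the diff above.
        if getattr(self, "executor", None) is not None:
            for f in self.prefetch_futures:
                if not f.done():
                    f.result()
            self.executor.shutdown(wait=False)
            self.executor = None
```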
...@@ -178,7 +178,7 @@ class WanModel(CompiledMethodsMixin): ...@@ -178,7 +178,7 @@ class WanModel(CompiledMethodsMixin):
if os.path.exists(non_block_file): if os.path.exists(non_block_file):
safetensors_files = [non_block_file] safetensors_files = [non_block_file]
else: else:
raise ValueError(f"Non-block file not found in {safetensors_path}") raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path. Lazy load mode only supports loading chunked model weights.")
weight_dict = {} weight_dict = {}
for file_path in safetensors_files: for file_path in safetensors_files:
...@@ -221,7 +221,7 @@ class WanModel(CompiledMethodsMixin): ...@@ -221,7 +221,7 @@ class WanModel(CompiledMethodsMixin):
if os.path.exists(non_block_file): if os.path.exists(non_block_file):
safetensors_files = [non_block_file] safetensors_files = [non_block_file]
else: else:
raise ValueError(f"Non-block file not found in {safetensors_path}, Please check the lazy load model path") raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path. Lazy load mode only supports loading chunked model weights.")
weight_dict = {} weight_dict = {}
for safetensor_path in safetensors_files: for safetensor_path in safetensors_files:
......
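The clarified error messages above fire when lazy loading cannot locate chunked weights; the sketch below illustrates the kind of resolution logic they guard, with the file-name patterns chosen only for illustration.

```python
# Illustrative sketch of resolving safetensors files for lazy loading.
# The "block_*" / "non_block" naming is an assumption used for this example.
import glob
import os


def resolve_safetensors_files(safetensors_path):
    # Lazy load expects the weights to be split into per-block chunks.
    block_files = sorted(glob.glob(os.path.join(safetensors_path, "block_*.safetensors")))
    if block_files:
        return block_files
    # Fall back to a single non-block file if one exists.
    non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
    if os.path.exists(non_block_file):
        return [non_block_file]
    raise ValueError(
        f"Non-block file not found in {safetensors_path}. Please check the model path. "
        "Lazy load mode only supports loading chunked model weights."
    )
```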