Unverified Commit 9a765f9b authored by Gu Shiqiao, committed by GitHub

fix gradio (#587)

parent 8530a2fb
......@@ -4,6 +4,9 @@ import glob
import importlib.util
import json
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
import random
from datetime import datetime
......@@ -12,6 +15,15 @@ import psutil
import torch
from loguru import logger
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import get_default_config
try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
apply_rope_with_cos_sin_cache_inplace = None
logger.add(
"inference_logs.log",
rotation="100 MB",
......@@ -24,38 +36,196 @@ logger.add(
MAX_NUMPY_SEED = 2**32 - 1
def find_hf_model_path(model_path, subdir=["original", "fp8", "int8"]):
paths_to_check = [model_path]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.append(os.path.join(model_path, sub))
def scan_model_path_contents(model_path):
"""Scan model_path directory and return available files and subdirectories"""
if not model_path or not os.path.exists(model_path):
return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
dirs = []
files = []
safetensors_dirs = []
pth_files = []
try:
for item in os.listdir(model_path):
item_path = os.path.join(model_path, item)
if os.path.isdir(item_path):
dirs.append(item)
# Check if directory contains safetensors files
if glob.glob(os.path.join(item_path, "*.safetensors")):
safetensors_dirs.append(item)
elif os.path.isfile(item_path):
files.append(item)
if item.endswith(".pth"):
pth_files.append(item)
except Exception as e:
logger.warning(f"Failed to scan directory: {e}")
return {
"dirs": sorted(dirs),
"files": sorted(files),
"safetensors_dirs": sorted(safetensors_dirs),
"pth_files": sorted(pth_files),
}
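A minimal usage sketch of the scanner above, assuming a hypothetical local model directory such as ./models/Wan2.1-T2V-14B; the returned lists are exactly what the dropdown helpers below consume.
contents = scan_model_path_contents("./models/Wan2.1-T2V-14B")
contents["safetensors_dirs"]  # e.g. ["fp8", "int8", "original"] if those subdirs hold .safetensors shards
contents["pth_files"]         # e.g. ["Wan2.1_VAE.pth", "models_t5_umt5-xxl-enc-bf16.pth"]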
def get_dit_choices(model_path, model_type="wan2.1"):
"""Get Diffusion model options (filtered by model type)"""
contents = scan_model_path_contents(model_path)
excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
fp8_supported = is_fp8_supported_gpu()
if model_type == "wan2.1":
# wan2.1: filter files/dirs containing wan2.1 or Wan2.1
def is_valid(name):
name_lower = name.lower()
if "wan2.1" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
else:
paths_to_check.append(os.path.join(model_path, subdir))
# wan2.2: filter files/dirs containing wan2.2 or Wan2.2
def is_valid(name):
name_lower = name.lower()
if "wan2.2" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
for path in paths_to_check:
safetensors_pattern = os.path.join(path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if safetensors_files:
logger.info(f"Found Hugging Face model files in: {path}")
return path
raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
# Filter matching directories and files
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
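A compact restatement of the wan2.1 filter above, run against hypothetical file names, shows which entries survive:
excluded = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
def keeps_wan21(name, fp8_ok=True):
    n = name.lower()
    return "wan2.1" in n and (fp8_ok or "fp8" not in n) and not any(kw in n for kw in excluded)
assert keeps_wan21("Wan2.1-T2V-14B")                                 # DiT directory is kept
assert not keeps_wan21("Wan2.1_VAE.pth")                             # dropped: excluded keyword "vae"
assert not keeps_wan21("wan2.1_i2v_fp8.safetensors", fp8_ok=False)   # fp8 entries hidden on unsupported GPUs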
def find_torch_model_path(model_path, filename=None, subdir=["original", "fp8", "int8"]):
paths_to_check = [
os.path.join(model_path, filename),
]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.append(os.path.join(model_path, sub, filename))
def get_high_noise_choices(model_path):
"""Get high noise model options (files/dirs containing high_noise)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "high_noise" in name_lower or "high-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_low_noise_choices(model_path):
"""Get low noise model options (files/dirs containing low_noise)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "low_noise" in name_lower or "low-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_t5_choices(model_path):
"""Get T5 model options (.pth or .safetensors files containing t5 keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_clip_choices(model_path):
"""Get CLIP model options (.pth or .safetensors files containing clip keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_vae_choices(model_path):
"""Get VAE model options (.pth or .safetensors files containing vae/VAE/tae keyword)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories containing safetensors
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def detect_quant_scheme(model_name):
"""Automatically detect quantization scheme from model name
- If model name contains "int8" → "int8"
- If model name contains "fp8" and device supports → "fp8"
- Otherwise return None (no quantization)
"""
if not model_name:
return None
name_lower = model_name.lower()
if "int8" in name_lower:
return "int8"
elif "fp8" in name_lower:
if is_fp8_supported_gpu():
return "fp8"
else:
paths_to_check.append(os.path.join(model_path, subdir, filename))
print(paths_to_check)
for path in paths_to_check:
if os.path.exists(path):
logger.info(f"Found PyTorch model checkpoint: {path}")
return path
raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
# Device doesn't support fp8, return None (use default precision)
return None
return None
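For illustration, the detection rules above resolve hypothetical names as follows (the fp8 case additionally depends on is_fp8_supported_gpu() at runtime):
detect_quant_scheme("wan2.1_i2v_720p_int8.safetensors")  # -> "int8"
detect_quant_scheme("wan2.1_i2v_720p_fp8.safetensors")   # -> "fp8" on FP8-capable GPUs, otherwise None
detect_quant_scheme("Wan2.1-T2V-14B")                    # -> None (no quantization suffix in the name)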
def update_model_path_options(model_path, model_type="wan2.1"):
"""Update all model path selectors when model_path or model_type changes"""
dit_choices = get_dit_choices(model_path, model_type)
high_noise_choices = get_high_noise_choices(model_path)
low_noise_choices = get_low_noise_choices(model_path)
t5_choices = get_t5_choices(model_path)
clip_choices = get_clip_choices(model_path)
vae_choices = get_vae_choices(model_path)
return (
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
)
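A sketch of how this helper might be wired to a Gradio change event; the component names model_path_box, model_type_radio, dit_dd, high_dd, low_dd, t5_dd, clip_dd and vae_dd are hypothetical placeholders, not components defined in this file:
# model_path_box.change(
#     fn=update_model_path_options,
#     inputs=[model_path_box, model_type_radio],
#     outputs=[dit_dd, high_dd, low_dd, t5_dd, clip_dd, vae_dd],
# )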
def generate_random_seed():
......@@ -109,12 +279,18 @@ def get_available_attn_ops():
else:
available_ops.append(("flash_attn3", False))
q8f_installed = is_module_installed("sageattention")
if q8f_installed:
sage_installed = is_module_installed("sageattention")
if sage_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
sage3_installed = is_module_installed("sageattn3")
if sage3_installed:
available_ops.append(("sage_attn3", True))
else:
available_ops.append(("sage_attn3", False))
torch_installed = is_module_installed("torch")
if torch_installed:
available_ops.append(("torch_sdpa", True))
......@@ -150,6 +326,8 @@ def cleanup_memory():
torch.cuda.synchronize()
try:
import psutil
if hasattr(psutil, "virtual_memory"):
if os.name == "posix":
try:
......@@ -163,7 +341,7 @@ def cleanup_memory():
def generate_unique_filename(output_dir):
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return os.path.join(output_dir, f"{model_cls}_{timestamp}.mp4")
return os.path.join(output_dir, f"{timestamp}.mp4")
def is_fp8_supported_gpu():
......@@ -231,13 +409,25 @@ def get_quantization_options(model_path):
return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
def determine_model_cls(model_type, dit_name, high_noise_name):
"""Determine model_cls based on model type and file name"""
# Determine file name to check
if model_type == "wan2.1":
check_name = dit_name.lower() if dit_name else ""
is_distill = "4step" in check_name
return "wan2.1_distill" if is_distill else "wan2.1"
else:
# wan2.2
check_name = high_noise_name.lower() if high_noise_name else ""
is_distill = "4step" in check_name
return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
global_runner = None
current_config = None
cur_dit_quant_scheme = None
cur_clip_quant_scheme = None
cur_t5_quant_scheme = None
cur_precision_mode = None
cur_enable_teacache = None
cur_dit_path = None
cur_t5_path = None
cur_clip_path = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
......@@ -247,8 +437,29 @@ for op_name, is_installed in available_quant_ops:
quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()
# Priority order
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
# Sort by priority, installed ones first, uninstalled ones last
attn_op_choices = []
for op_name, is_installed in available_attn_ops:
attn_op_dict = dict(available_attn_ops)
# Add installed ones first (by priority)
for op_name in attn_priority:
if op_name in attn_op_dict and attn_op_dict[op_name]:
status_text = "✅ Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# Add uninstalled ones (by priority)
for op_name in attn_priority:
if op_name in attn_op_dict and not attn_op_dict[op_name]:
status_text = "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# Add other operators not in priority list (installed ones first)
other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]): # Installed ones first
status_text = "✅ Installed" if is_installed else "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
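On a machine where, hypothetically, only flash_attn2 and torch_sdpa are importable, the ordering above yields installed operators first, then the remaining ones in priority order:
# attn_op_choices ==
# [("flash_attn2", "flash_attn2 (✅ Installed)"),
#  ("torch_sdpa", "torch_sdpa (✅ Installed)"),
#  ("sage_attn3", "sage_attn3 (❌ Not Installed)"),
#  ("sage_attn2", "sage_attn2 (❌ Not Installed)"),
#  ("flash_attn3", "flash_attn3 (❌ Not Installed)")]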
......@@ -258,36 +469,36 @@ def run_inference(
prompt,
negative_prompt,
save_result_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path=None,
):
cleanup_memory()
......@@ -295,8 +506,23 @@ def run_inference(
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path, task
global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
global global_runner, current_config, model_path, model_cls
global cur_dit_path, cur_t5_path, cur_clip_path
task = task_type_input
model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
logger.info(f"Auto-determined model_cls: {model_cls} (Model type: {model_type_input})")
if model_type_input == "wan2.1":
dit_quant_detected = detect_quant_scheme(dit_path_input)
else:
dit_quant_detected = detect_quant_scheme(high_noise_path_input)
t5_quant_detected = detect_quant_scheme(t5_path_input)
clip_quant_detected = detect_quant_scheme(clip_path_input)
logger.info(f"Auto-detected quantization scheme - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")
if model_path_input and model_path_input.strip():
model_path = model_path_input.strip()
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
......@@ -304,157 +530,88 @@ def run_inference(
else:
model_config = {}
if task == "t2v":
if model_size == "1.3b":
# 1.3B
coefficient = [
[
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
[
2.39676752e03,
-1.31110545e03,
2.01331979e02,
-8.29855975e00,
1.37887774e-01,
],
]
save_result_path = generate_unique_filename(output_dir)
is_dit_quant = dit_quant_detected != "bf16"
is_t5_quant = t5_quant_detected != "bf16"
is_clip_quant = clip_quant_detected != "fp16"
dit_quantized_ckpt = None
dit_original_ckpt = None
high_noise_quantized_ckpt = None
low_noise_quantized_ckpt = None
high_noise_original_ckpt = None
low_noise_original_ckpt = None
if is_dit_quant:
dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
if "wan2.1" in model_cls:
dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
else:
# 14B
coefficient = [
[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
]
elif task == "i2v":
if resolution in [
"1280x720",
"720x1280",
"1280x544",
"544x1280",
"1104x832",
"832x1104",
"960x960",
]:
# 720p
coefficient = [
[
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
],
[-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
]
high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
else:
# 480p
coefficient = [
[
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
],
[
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
],
]
save_result_path = generate_unique_filename(output_dir)
dit_quantized_ckpt = "Default"
if "wan2.1" in model_cls:
dit_original_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)
is_dit_quant = dit_quant_scheme != "bf16"
is_t5_quant = t5_quant_scheme != "bf16"
# Use frontend-selected T5 path
if is_t5_quant:
t5_model_name = f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth"
t5_quant_ckpt = find_torch_model_path(model_path, t5_model_name, t5_quant_scheme)
t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
t5_original_ckpt = None
else:
t5_quant_ckpt = None
t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
t5_original_ckpt = find_torch_model_path(model_path, t5_model_name, "original")
t5_quantized_ckpt = None
t5_quant_scheme = None
t5_original_ckpt = os.path.join(model_path, t5_path_input)
is_clip_quant = clip_quant_scheme != "fp16"
# Use frontend-selected CLIP path
if is_clip_quant:
clip_model_name = f"clip-{clip_quant_scheme}.pth"
clip_quant_ckpt = find_torch_model_path(model_path, clip_model_name, clip_quant_scheme)
clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
clip_original_ckpt = None
else:
clip_quant_ckpt = None
clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
clip_original_ckpt = find_torch_model_path(model_path, clip_model_name, "original")
clip_quantized_ckpt = None
clip_quant_scheme = None
clip_original_ckpt = os.path.join(model_path, clip_path_input)
if model_type_input == "wan2.1":
current_dit_path = dit_path_input
else:
current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
current_t5_path = t5_path_input
current_clip_path = clip_path_input
needs_reinit = (
lazy_load
or unload_modules
or global_runner is None
or current_config is None
or cur_dit_quant_scheme is None
or cur_dit_quant_scheme != dit_quant_scheme
or cur_clip_quant_scheme is None
or cur_clip_quant_scheme != clip_quant_scheme
or cur_t5_quant_scheme is None
or cur_t5_quant_scheme != t5_quant_scheme
or cur_precision_mode is None
or cur_precision_mode != precision_mode
or cur_enable_teacache is None
or cur_enable_teacache != enable_teacache
or cur_dit_path is None
or cur_dit_path != current_dit_path
or cur_t5_path is None
or cur_t5_path != current_t5_path
or cur_clip_path is None
or cur_clip_path != current_clip_path
)
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
if cfg_scale == 1:
enable_cfg = False
else:
os.environ["ENABLE_GRAPH_MODE"] = "false"
if precision_mode == "bf16":
os.environ["DTYPE"] = "BF16"
else:
os.environ.pop("DTYPE", None)
enable_cfg = True
if is_dit_quant:
if quant_op == "vllm":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Vllm"
elif quant_op == "sgl":
if dit_quant_scheme == "int8":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl-ActVllm"
else:
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
elif quant_op == "q8f":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
t5_quant_scheme = f"{t5_quant_scheme}-q8f"
clip_quant_scheme = f"{clip_quant_scheme}-q8f"
dit_quantized_ckpt = find_hf_model_path(model_path, dit_quant_scheme)
if os.path.exists(os.path.join(dit_quantized_ckpt, "config.json")):
with open(os.path.join(dit_quantized_ckpt, "config.json"), "r") as f:
quant_model_config = json.load(f)
else:
quant_model_config = {}
else:
mm_type = "Default"
dit_quantized_ckpt = None
quant_model_config = {}
vae_name_lower = vae_path_input.lower() if vae_path_input else ""
use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
use_lightvae = "lightvae" in vae_name_lower
need_scaled = "lighttae" in vae_name_lower
config = {
logger.info(f"VAE configuration - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
config_gradio = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
"target_width": int(resolution.split("x")[0]),
......@@ -462,38 +619,11 @@ def run_inference(
"self_attn_1_type": attention_type,
"cross_attn_1_type": attention_type,
"cross_attn_2_type": attention_type,
"seed": seed,
"enable_cfg": enable_cfg,
"sample_guide_scale": cfg_scale,
"sample_shift": sample_shift,
"cpu_offload": cpu_offload,
"offload_granularity": offload_granularity,
"offload_ratio": offload_ratio,
"t5_offload_granularity": t5_offload_granularity,
"dit_quantized_ckpt": dit_quantized_ckpt,
"mm_config": {
"mm_type": mm_type,
},
"fps": fps,
"feature_caching": "Tea" if enable_teacache else "NoCaching",
"coefficients": coefficient[0] if use_ret_steps else coefficient[1],
"use_ret_steps": use_ret_steps,
"teacache_thresh": teacache_thresh,
"t5_cpu_offload": t5_cpu_offload,
"unload_modules": unload_modules,
"t5_original_ckpt": t5_original_ckpt,
"t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quant_ckpt,
"t5_quant_scheme": t5_quant_scheme,
"clip_original_ckpt": clip_original_ckpt,
"clip_quantized": is_clip_quant,
"clip_quantized_ckpt": clip_quant_ckpt,
"clip_quant_scheme": clip_quant_scheme,
"vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
"use_tiling_vae": use_tiling_vae,
"use_tae": use_tae,
"tae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
"lazy_load": lazy_load,
"feature_caching": "NoCaching",
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
......@@ -504,14 +634,49 @@ def run_inference(
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
"rotary_chunk": rotary_chunk,
"rotary_chunk_size": rotary_chunk_size,
"clean_cuda_cache": clean_cuda_cache,
"denoising_step_list": [1000, 750, 500, 250],
"cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
"offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
"t5_cpu_offload": t5_cpu_offload,
"clip_cpu_offload": clip_cpu_offload,
"vae_cpu_offload": vae_cpu_offload,
"dit_quantized": is_dit_quant,
"dit_quant_scheme": dit_quant_scheme,
"dit_quantized_ckpt": dit_quantized_ckpt,
"dit_original_ckpt": dit_original_ckpt,
"high_noise_quantized_ckpt": high_noise_quantized_ckpt,
"low_noise_quantized_ckpt": low_noise_quantized_ckpt,
"high_noise_original_ckpt": high_noise_original_ckpt,
"low_noise_original_ckpt": low_noise_original_ckpt,
"t5_original_ckpt": t5_original_ckpt,
"t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quantized_ckpt,
"t5_quant_scheme": t5_quant_scheme,
"clip_original_ckpt": clip_original_ckpt,
"clip_quantized": is_clip_quant,
"clip_quantized_ckpt": clip_quantized_ckpt,
"clip_quant_scheme": clip_quant_scheme,
"vae_path": os.path.join(model_path, vae_path_input),
"use_tiling_vae": use_tiling_vae,
"use_tae": use_tae,
"use_lightvae": use_lightvae,
"need_scaled": need_scaled,
"lazy_load": lazy_load,
"rope_chunk": rope_chunk,
"rope_chunk_size": rope_chunk_size,
"clean_cuda_cache": clean_cuda_cache,
"unload_modules": unload_modules,
"seq_parallel": False,
"warm_up_cpu_buffers": False,
"boundary_step_index": 2,
"boundary": 0.900,
"use_image_encoder": False if "wan2.2" in model_cls else True,
"rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
}
args = argparse.Namespace(
model_cls=model_cls,
seed=seed,
task=task,
model_path=model_path,
prompt_enhancer=None,
......@@ -519,11 +684,13 @@ def run_inference(
negative_prompt=negative_prompt,
image_path=image_path,
save_result_path=save_result_path,
return_result_tensor=False,
)
config = get_default_config()
config.update({k: v for k, v in vars(args).items()})
config.update(model_config)
config.update(quant_model_config)
config.update(config_gradio)
logger.info(f"Using model: {model_path}")
logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
......@@ -539,28 +706,19 @@ def run_inference(
from lightx2v.infer import init_runner # noqa
runner = init_runner(config)
input_info = set_input_info(args)
current_config = config
cur_dit_quant_scheme = dit_quant_scheme
cur_clip_quant_scheme = clip_quant_scheme
cur_t5_quant_scheme = t5_quant_scheme
cur_precision_mode = precision_mode
cur_enable_teacache = enable_teacache
cur_dit_path = current_dit_path
cur_t5_path = current_t5_path
cur_clip_path = current_clip_path
if not lazy_load:
global_runner = runner
else:
runner.config = config
runner.run_pipeline()
del config, args, model_config, quant_model_config
if "dit_quantized_ckpt" in locals():
del dit_quantized_ckpt
if "t5_quant_ckpt" in locals():
del t5_quant_ckpt
if "clip_quant_ckpt" in locals():
del clip_quant_ckpt
runner.run_pipeline(input_info)
cleanup_memory()
return save_result_path
......@@ -571,44 +729,28 @@ def handle_lazy_load_change(lazy_load_enabled):
return gr.update(value=lazy_load_enabled)
def auto_configure(enable_auto_config, resolution):
def auto_configure(resolution):
"""Auto-configure inference options based on machine configuration and resolution"""
default_config = {
"torch_compile_val": False,
"lazy_load_val": False,
"rotary_chunk_val": False,
"rotary_chunk_size_val": 100,
"rope_chunk_val": False,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": False,
"cpu_offload_val": False,
"offload_granularity_val": "block",
"offload_ratio_val": 1,
"t5_cpu_offload_val": False,
"clip_cpu_offload_val": False,
"vae_cpu_offload_val": False,
"unload_modules_val": False,
"t5_offload_granularity_val": "model",
"attention_type_val": attn_op_choices[0][1],
"quant_op_val": quant_op_choices[0][1],
"dit_quant_scheme_val": "bf16",
"t5_quant_scheme_val": "bf16",
"clip_quant_scheme_val": "fp16",
"precision_mode_val": "fp32",
"use_tae_val": False,
"use_tiling_vae_val": False,
"enable_teacache_val": False,
"teacache_thresh_val": 0.26,
"use_ret_steps_val": False,
}
if not enable_auto_config:
return tuple(gr.update(value=default_config[key]) for key in default_config)
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
if is_fp8_supported_gpu():
quant_type = "fp8"
else:
quant_type = "int8"
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
......@@ -643,25 +785,15 @@ def auto_configure(enable_auto_config, resolution):
else:
res = "480p"
if model_size == "14b":
is_14b = True
else:
is_14b = False
if res == "720p" and is_14b:
if res == "720p":
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "offload_ratio_val": 1, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
},
),
......@@ -669,151 +801,64 @@ def auto_configure(enable_auto_config, resolution):
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
},
),
(
12,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"use_tae_val": True,
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tae_val": True,
},
),
]
elif is_14b:
else:
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.2, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
(24, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
16,
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "block",
},
),
(
8,
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
"use_tae_val": True,
}
if res == "540p"
else {
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tae_val": True,
}
),
},
),
]
else:
gpu_rules = [
(24, {}),
(
8,
{
"t5_cpu_offload_val": True,
"t5_offload_granularity_val": "block",
"t5_quant_scheme_val": quant_type,
"cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
]
if is_14b:
cpu_rules = [
(128, {}),
(64, {"dit_quant_scheme_val": quant_type}),
(32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
(
16,
{
"dit_quant_scheme_val": quant_type,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
},
),
]
else:
cpu_rules = [
(64, {}),
(32, {"unload_modules_val": True}),
(
16,
{
"t5_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tae_val": True,
},
),
]
......@@ -828,62 +873,238 @@ def auto_configure(enable_auto_config, resolution):
default_config.update(updates)
break
return tuple(gr.update(value=default_config[key]) for key in default_config)
return (
gr.update(value=default_config["lazy_load_val"]),
gr.update(value=default_config["rope_chunk_val"]),
gr.update(value=default_config["rope_chunk_size_val"]),
gr.update(value=default_config["clean_cuda_cache_val"]),
gr.update(value=default_config["cpu_offload_val"]),
gr.update(value=default_config["offload_granularity_val"]),
gr.update(value=default_config["t5_cpu_offload_val"]),
gr.update(value=default_config["clip_cpu_offload_val"]),
gr.update(value=default_config["vae_cpu_offload_val"]),
gr.update(value=default_config["unload_modules_val"]),
gr.update(value=default_config["attention_type_val"]),
gr.update(value=default_config["quant_op_val"]),
gr.update(value=default_config["use_tiling_vae_val"]),
)
def main():
with gr.Blocks(
title="Lightx2v (Lightweight Video Inference and Generation Engine)",
css="""
.main-content { max-width: 1400px; margin: auto; }
.output-video { max-height: 650px; }
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
.advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
.tab-button { font-size: 16px; padding: 10px 20px; }
.auto-config-title {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
background-clip: text;
-webkit-background-clip: text;
color: transparent;
text-align: center;
margin: 0 !important;
padding: 8px;
border: 2px solid #4ecdc4;
border-radius: 8px;
background-color: #f0f8ff;
/* Model configuration area styles */
.model-config {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
}
.auto-config-checkbox {
border: 2px solid #ff6b6b !important;
border-radius: 8px !important;
padding: 10px !important;
background: linear-gradient(135deg, #fff5f5, #f0fff0) !important;
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important;
/* Input parameters area styles */
.input-params {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
}
.auto-config-checkbox label {
font-size: 16px !important;
/* Output video area styles */
.output-video {
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
min-height: 400px;
}
/* Generate button styles */
.generate-btn {
width: 100%;
margin-top: 20px;
padding: 15px 30px !important;
font-size: 18px !important;
font-weight: bold !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
border-radius: 10px !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
transition: all 0.3s ease !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
/* Accordion header styles */
.model-config .gr-accordion-header,
.input-params .gr-accordion-header,
.output-video .gr-accordion-header {
font-size: 20px !important;
font-weight: bold !important;
color: #2c3e50 !important;
padding: 15px !important;
}
/* Optimize spacing */
.gr-row {
margin-bottom: 15px;
}
/* Video player styles */
.output-video video {
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
""",
) as demo:
gr.Markdown(f"# 🎬 {model_cls} Video Generator")
gr.Markdown(f"### Using Model: {model_path}")
gr.Markdown(f"# 🎬 LightX2V Video Generator")
with gr.Tabs() as tabs:
with gr.Tab("Basic Settings", id=1):
# Main layout: left and right columns
with gr.Row():
with gr.Column(scale=4):
with gr.Group():
gr.Markdown("## 📥 Input Parameters")
# Left: configuration and input area
with gr.Column(scale=5):
# Model configuration area
with gr.Accordion("🗂️ Model Configuration", open=True, elem_classes=["model-config"]):
# FP8 support notice
if not is_fp8_supported_gpu():
gr.Markdown("⚠️ **Your device does not support FP8 inference**. Models containing FP8 have been automatically hidden.")
# Hidden state components
model_path_input = gr.Textbox(value=model_path, visible=False)
# Model type + Task type
with gr.Row():
model_type_input = gr.Radio(
label="Model Type",
choices=["wan2.1", "wan2.2"],
value="wan2.1",
info="wan2.2 requires separate high noise and low noise models",
)
task_type_input = gr.Radio(
label="Task Type",
choices=["i2v", "t2v"],
value="i2v",
info="i2v: Image-to-video, t2v: Text-to-video",
)
# wan2.1: Diffusion model (single row)
with gr.Row() as wan21_row:
dit_path_input = gr.Dropdown(
label="🎨 Diffusion Model",
choices=get_dit_choices(model_path, "wan2.1"),
value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
allow_custom_value=True,
visible=True,
)
# wan2.2 specific: high noise model + low noise model (hidden by default)
with gr.Row(visible=False) as wan22_row:
high_noise_path_input = gr.Dropdown(
label="🔊 High Noise Model",
choices=get_high_noise_choices(model_path),
value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
allow_custom_value=True,
)
low_noise_path_input = gr.Dropdown(
label="🔇 Low Noise Model",
choices=get_low_noise_choices(model_path),
value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
allow_custom_value=True,
)
# Text encoder (single row)
with gr.Row():
t5_path_input = gr.Dropdown(
label="📝 Text Encoder",
choices=get_t5_choices(model_path),
value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
allow_custom_value=True,
)
if task == "i2v":
# Image encoder + VAE decoder
with gr.Row():
clip_path_input = gr.Dropdown(
label="🖼️ Image Encoder",
choices=get_clip_choices(model_path),
value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
allow_custom_value=True,
)
vae_path_input = gr.Dropdown(
label="🎞️ VAE Decoder",
choices=get_vae_choices(model_path),
value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
allow_custom_value=True,
)
# Attention operator and quantization matrix multiplication operator
with gr.Row():
attention_type = gr.Dropdown(
label="⚡ Attention Operator",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1] if attn_op_choices else "",
info="Use appropriate attention operators to accelerate inference",
)
quant_op = gr.Dropdown(
label="Quantization Matmul Operator",
choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1],
info="Select quantization matrix multiplication operator to accelerate inference",
interactive=True,
)
# Determine if model is distill version
def is_distill_model(model_type, dit_path, high_noise_path):
"""Determine if model is distill version based on model type and path"""
if model_type == "wan2.1":
check_name = dit_path.lower() if dit_path else ""
else:
check_name = high_noise_path.lower() if high_noise_path else ""
return "4step" in check_name
# Model type change event
def on_model_type_change(model_type, model_path_val):
if model_type == "wan2.2":
return gr.update(visible=False), gr.update(visible=True), gr.update()
else:
# Update wan2.1 Diffusion model options
dit_choices = get_dit_choices(model_path_val, "wan2.1")
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
)
model_type_input.change(
fn=on_model_type_change,
inputs=[model_type_input, model_path_input],
outputs=[wan21_row, wan22_row, dit_path_input],
)
# Input parameters area
with gr.Accordion("📥 Input Parameters", open=True, elem_classes=["input-params"]):
# Image input (shown for i2v)
with gr.Row(visible=True) as image_input_row:
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
visible=True,
)
# Task type change event
def on_task_type_change(task_type):
return gr.update(visible=(task_type == "i2v"))
task_type_input.change(
fn=on_task_type_change,
inputs=[task_type_input],
outputs=[image_input_row],
)
with gr.Row():
......@@ -900,7 +1121,7 @@ def main():
lines=3,
placeholder="What you don't want to appear in the video...",
max_lines=5,
value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
value="Camera shake, bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
)
with gr.Column():
resolution = gr.Dropdown(
......@@ -927,15 +1148,6 @@ def main():
label="Maximum Resolution",
)
with gr.Column():
with gr.Group():
gr.Markdown("### 🚀 **Smart Configuration Recommendation**", elem_classes=["auto-config-title"])
enable_auto_config = gr.Checkbox(
label="🎯 **Auto-configure Inference Options**",
value=False,
info="💡 **Automatically optimize GPU settings to match the current resolution. After changing the resolution, please re-check this option to prevent potential performance degradation or runtime errors.**",
elem_classes=["auto-config-checkbox"],
)
with gr.Column(scale=9):
seed = gr.Slider(
label="Random Seed",
......@@ -944,25 +1156,21 @@ def main():
step=1,
value=generate_random_seed(),
)
with gr.Column(scale=1):
randomize_btn = gr.Button("🎲 Randomize", variant="secondary")
randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
with gr.Column():
# Set default inference steps based on model class
if model_cls == "wan2.1_distill":
default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
if default_is_distill:
infer_steps = gr.Slider(
label="Inference Steps",
minimum=4,
maximum=4,
minimum=1,
maximum=100,
step=1,
value=4,
interactive=False,
info="Inference steps fixed at 4 for optimal performance for distill model.",
info="Distill model inference steps default to 4.",
)
elif model_cls == "wan2.1":
if task == "i2v":
else:
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
......@@ -971,31 +1179,45 @@ def main():
value=40,
info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
)
elif task == "t2v":
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=50,
info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
# Dynamically update inference steps when model path changes
def update_infer_steps(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
return gr.update(minimum=1, maximum=100, value=4, interactive=True)
else:
return gr.update(minimum=1, maximum=100, value=40, interactive=True)
# Listen to model path changes
dit_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
# Set default CFG based on model class
default_enable_cfg = False if model_cls == "wan2.1_distill" else True
# CFG scale factor: default to 1 for distill, otherwise 5
default_cfg_scale = 1 if default_is_distill else 5
# enable_cfg is not exposed to frontend, automatically set based on cfg_scale
# If cfg_scale == 1, then enable_cfg = False, otherwise enable_cfg = True
default_enable_cfg = False if default_cfg_scale == 1 else True
enable_cfg = gr.Checkbox(
label="Enable Classifier-Free Guidance",
value=default_enable_cfg,
info="Enable classifier-free guidance to control prompt strength",
)
cfg_scale = gr.Slider(
label="CFG Scale Factor",
minimum=1,
maximum=10,
step=1,
value=5,
info="Controls the influence strength of the prompt. Higher values give more influence to the prompt.",
visible=False, # Hidden, not exposed to frontend
)
with gr.Row():
sample_shift = gr.Slider(
label="Distribution Shift",
value=5,
......@@ -1004,7 +1226,56 @@ def main():
step=1,
info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.",
)
cfg_scale = gr.Slider(
label="CFG Scale Factor",
minimum=1,
maximum=10,
step=1,
value=default_cfg_scale,
info="Controls the influence strength of the prompt. Higher values give more influence to the prompt. When value is 1, CFG is automatically disabled.",
)
# Update enable_cfg based on cfg_scale
def update_enable_cfg(cfg_scale_val):
"""Automatically set enable_cfg based on cfg_scale value"""
if cfg_scale_val == 1:
return gr.update(value=False)
else:
return gr.update(value=True)
# Dynamically update CFG scale factor and enable_cfg when model path changes
def update_cfg_scale(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
new_cfg_scale = 1
else:
new_cfg_scale = 5
new_enable_cfg = False if new_cfg_scale == 1 else True
return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
dit_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
cfg_scale.change(
fn=update_enable_cfg,
inputs=[cfg_scale],
outputs=[enable_cfg],
)
with gr.Row():
fps = gr.Slider(
label="Frames Per Second (FPS)",
minimum=8,
......@@ -1026,284 +1297,109 @@ def main():
label="Output Video Path",
value=generate_unique_filename(output_dir),
info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.",
visible=False, # Hide output path, auto-generated
)
with gr.Column(scale=6):
gr.Markdown("## 📤 Generated Video")
with gr.Column(scale=4):
with gr.Accordion("📤 Generated Video", open=True, elem_classes=["output-video"]):
output_video = gr.Video(
label="Result",
height=624,
width=360,
label="",
height=600,
autoplay=True,
elem_classes=["output-video"],
)
infer_btn = gr.Button("Generate Video", variant="primary", size="lg")
with gr.Tab("⚙️ Advanced Options", id=2):
with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### GPU Memory Optimization")
with gr.Row():
rotary_chunk = gr.Checkbox(
label="Chunked Rotary Position Embedding",
value=False,
info="When enabled, processes rotary position embeddings in chunks to save GPU memory.",
)
rotary_chunk_size = gr.Slider(
label="Rotary Embedding Chunk Size",
value=100,
minimum=100,
maximum=10000,
step=100,
info="Controls the chunk size for applying rotary embeddings. Larger values may improve performance but increase memory usage. Only effective if 'rotary_chunk' is checked.",
)
unload_modules = gr.Checkbox(
label="Unload Modules",
value=False,
info="Unload modules (T5, CLIP, DIT, etc.) after inference to reduce GPU/CPU memory usage",
)
clean_cuda_cache = gr.Checkbox(
label="Clean CUDA Memory Cache",
value=False,
info="When enabled, frees up GPU memory promptly but slows down inference.",
)
gr.Markdown("### Asynchronous Offloading")
with gr.Row():
cpu_offload = gr.Checkbox(
label="CPU Offloading",
value=False,
info="Offload parts of the model computation from GPU to CPU to reduce GPU memory usage",
)
lazy_load = gr.Checkbox(
label="Enable Lazy Loading",
value=False,
info="Lazy load model components during inference. Requires CPU loading and DIT quantization.",
)
offload_granularity = gr.Dropdown(
label="Dit Offload Granularity",
choices=["block", "phase"],
value="phase",
info="Sets Dit model offloading granularity: blocks or computational phases",
)
offload_ratio = gr.Slider(
label="Offload ratio for Dit model",
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
info="Controls how much of the Dit model is offloaded to the CPU",
)
t5_cpu_offload = gr.Checkbox(
label="T5 CPU Offloading",
value=False,
info="Offload the T5 Encoder model to CPU to reduce GPU memory usage",
)
t5_offload_granularity = gr.Dropdown(
label="T5 Encoder Offload Granularity",
choices=["model", "block"],
value="model",
info="Controls the granularity when offloading the T5 Encoder model to CPU",
)
gr.Markdown("### Low-Precision Quantization")
with gr.Row():
torch_compile = gr.Checkbox(
label="Torch Compile",
value=False,
info="Use torch.compile to accelerate the inference process",
)
attention_type = gr.Dropdown(
label="Attention Operator",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1],
info="Use appropriate attention operators to accelerate inference",
)
quant_op = gr.Dropdown(
label="Quantization Matmul Operator",
choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1],
info="Select the quantization matrix multiplication operator to accelerate inference",
interactive=True,
)
# Get dynamic quantization options
quant_options = get_quantization_options(model_path)
dit_quant_scheme = gr.Dropdown(
label="Dit",
choices=quant_options["dit_choices"],
value=quant_options["dit_default"],
info="Quantization precision for the Dit model",
)
t5_quant_scheme = gr.Dropdown(
label="T5 Encoder",
choices=quant_options["t5_choices"],
value=quant_options["t5_default"],
info="Quantization precision for the T5 Encoder model",
)
clip_quant_scheme = gr.Dropdown(
label="Clip Encoder",
choices=quant_options["clip_choices"],
value=quant_options["clip_default"],
info="Quantization precision for the Clip Encoder",
)
precision_mode = gr.Dropdown(
label="Precision Mode for Sensitive Layers",
choices=["fp32", "bf16"],
value="fp32",
info="Select the numerical precision for critical model components like normalization and embedding layers. FP32 offers higher accuracy, while BF16 improves performance on compatible hardware.",
show_label=False,
)
gr.Markdown("### Variational Autoencoder (VAE)")
with gr.Row():
use_tae = gr.Checkbox(
label="Use Tiny VAE",
value=False,
info="Use a lightweight VAE model to accelerate the decoding process",
)
use_tiling_vae = gr.Checkbox(
label="VAE Tiling Inference",
value=False,
info="Use VAE tiling inference to reduce GPU memory usage",
)
gr.Markdown("### Feature Caching")
with gr.Row():
enable_teacache = gr.Checkbox(
label="Tea Cache",
value=False,
info="Cache features during inference to reduce the number of inference steps",
)
teacache_thresh = gr.Slider(
label="Tea Cache Threshold",
value=0.26,
minimum=0,
maximum=1,
info="Higher acceleration may result in lower quality —— Setting to 0.1 provides ~2.0x acceleration, setting to 0.2 provides ~3.0x acceleration",
)
use_ret_steps = gr.Checkbox(
label="Cache Only Key Steps",
value=False,
info="When checked, cache is written only at key steps where the scheduler returns results; when unchecked, cache is written at all steps to ensure the highest quality",
)
enable_auto_config.change(
infer_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg", elem_classes=["generate-btn"])
rope_chunk = gr.Checkbox(label="Chunked Rotary Position Embedding", value=False, visible=False)
rope_chunk_size = gr.Slider(label="Rotary Embedding Chunk Size", value=100, minimum=100, maximum=10000, step=100, visible=False)
unload_modules = gr.Checkbox(label="Unload Modules", value=False, visible=False)
clean_cuda_cache = gr.Checkbox(label="Clean CUDA Memory Cache", value=False, visible=False)
cpu_offload = gr.Checkbox(label="CPU Offloading", value=False, visible=False)
lazy_load = gr.Checkbox(label="Enable Lazy Loading", value=False, visible=False)
offload_granularity = gr.Dropdown(label="Dit Offload Granularity", choices=["block", "phase"], value="phase", visible=False)
t5_cpu_offload = gr.Checkbox(label="T5 CPU Offloading", value=False, visible=False)
clip_cpu_offload = gr.Checkbox(label="CLIP CPU Offloading", value=False, visible=False)
vae_cpu_offload = gr.Checkbox(label="VAE CPU Offloading", value=False, visible=False)
use_tiling_vae = gr.Checkbox(label="VAE Tiling Inference", value=False, visible=False)
resolution.change(
fn=auto_configure,
inputs=[enable_auto_config, resolution],
inputs=[resolution],
outputs=[
torch_compile,
lazy_load,
rotary_chunk,
rotary_chunk_size,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
precision_mode,
use_tae,
use_tiling_vae,
enable_teacache,
teacache_thresh,
use_ret_steps,
],
)
lazy_load.change(
fn=handle_lazy_load_change,
inputs=[lazy_load],
outputs=[unload_modules],
)
if task == "i2v":
infer_btn.click(
fn=run_inference,
inputs=[
prompt,
negative_prompt,
save_result_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tae,
use_tiling_vae,
demo.load(
fn=lambda res: auto_configure(res),
inputs=[resolution],
outputs=[
lazy_load,
precision_mode,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
image_path,
use_tiling_vae,
],
outputs=output_video,
)
else:
infer_btn.click(
fn=run_inference,
inputs=[
prompt,
negative_prompt,
save_result_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path,
],
outputs=output_video,
)
......@@ -1312,27 +1408,16 @@ def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Light Video Generation")
parser = argparse.ArgumentParser(description="Lightweight Video Generation")
parser.add_argument("--model_path", type=str, required=True, help="Model folder path")
parser.add_argument(
"--model_cls",
type=str,
choices=["wan2.1", "wan2.1_distill"],
default="wan2.1",
help="Model class to use (wan2.1: standard model, wan2.1_distill: distilled model for faster inference)",
)
parser.add_argument("--model_size", type=str, required=True, choices=["14b", "1.3b"], help="Model type to use")
parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="Specify the task type. 'i2v' for image-to-video translation, 't2v' for text-to-video generation.")
parser.add_argument("--server_port", type=int, default=7862, help="Server port")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server ip")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="Output video save directory")
args = parser.parse_args()
global model_path, model_cls, model_size, output_dir
global model_path, model_cls, output_dir
model_path = args.model_path
model_cls = args.model_cls
model_size = args.model_size
task = args.task
model_cls = "wan2.1"
output_dir = args.output_dir
main()
......@@ -4,6 +4,9 @@ import glob
import importlib.util
import json
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
import random
from datetime import datetime
......@@ -12,6 +15,15 @@ import psutil
import torch
from loguru import logger
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import get_default_config
try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
except ImportError:
apply_rope_with_cos_sin_cache_inplace = None
logger.add(
"inference_logs.log",
rotation="100 MB",
......@@ -24,38 +36,196 @@ logger.add(
MAX_NUMPY_SEED = 2**32 - 1
def find_hf_model_path(model_path, subdir=["original", "fp8", "int8"]):
paths_to_check = [model_path]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.append(os.path.join(model_path, sub))
def scan_model_path_contents(model_path):
"""扫描 model_path 目录,返回可用的文件和子目录"""
if not model_path or not os.path.exists(model_path):
return {"dirs": [], "files": [], "safetensors_dirs": [], "pth_files": []}
dirs = []
files = []
safetensors_dirs = []
pth_files = []
try:
for item in os.listdir(model_path):
item_path = os.path.join(model_path, item)
if os.path.isdir(item_path):
dirs.append(item)
# Check whether the directory contains safetensors files
if glob.glob(os.path.join(item_path, "*.safetensors")):
safetensors_dirs.append(item)
elif os.path.isfile(item_path):
files.append(item)
if item.endswith(".pth"):
pth_files.append(item)
except Exception as e:
logger.warning(f"扫描目录失败: {e}")
return {
"dirs": sorted(dirs),
"files": sorted(files),
"safetensors_dirs": sorted(safetensors_dirs),
"pth_files": sorted(pth_files),
}
def get_dit_choices(model_path, model_type="wan2.1"):
"""获取 Diffusion 模型可选项(根据模型类型筛选)"""
contents = scan_model_path_contents(model_path)
excluded_keywords = ["vae", "tae", "clip", "t5", "high_noise", "low_noise"]
fp8_supported = is_fp8_supported_gpu()
if model_type == "wan2.1":
# wan2.1: keep files/directories whose name contains wan2.1 or Wan2.1
def is_valid(name):
name_lower = name.lower()
if "wan2.1" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
else:
paths_to_check.append(os.path.join(model_path, subdir))
# wan2.2: keep files/directories whose name contains wan2.2 or Wan2.2
def is_valid(name):
name_lower = name.lower()
if "wan2.2" not in name_lower:
return False
if not fp8_supported and "fp8" in name_lower:
return False
return not any(kw in name_lower for kw in excluded_keywords)
for path in paths_to_check:
safetensors_pattern = os.path.join(path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if safetensors_files:
logger.info(f"Found Hugging Face model files in: {path}")
return path
raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
# Keep the directories and files that match the filter
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def find_torch_model_path(model_path, filename=None, subdir=["original", "fp8", "int8"]):
paths_to_check = [
os.path.join(model_path, filename),
]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.append(os.path.join(model_path, sub, filename))
def get_high_noise_choices(model_path):
"""获取高噪模型可选项(包含 high_noise 的文件/目录)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "high_noise" in name_lower or "high-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_low_noise_choices(model_path):
"""获取低噪模型可选项(包含 low_noise 的文件/目录)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
def is_valid(name):
name_lower = name.lower()
if not fp8_supported and "fp8" in name_lower:
return False
return "low_noise" in name_lower or "low-noise" in name_lower
dir_choices = [d for d in contents["dirs"] if is_valid(d)]
file_choices = [f for f in contents["files"] if is_valid(f)]
choices = dir_choices + file_choices
return choices if choices else [""]
def get_t5_choices(model_path):
"""获取 T5 模型可选项(.pth 或 .safetensors 文件,包含 t5 关键字)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "t5" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories that contain safetensors files
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "t5" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_clip_choices(model_path):
"""获取 CLIP 模型可选项(.pth 或 .safetensors 文件,包含 clip 关键字)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and "clip" in f.lower() and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories that contain safetensors files
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if "clip" in d.lower() and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def get_vae_choices(model_path):
"""获取 VAE 模型可选项(.pth 或 .safetensors 文件,包含 vae/VAE/tae 关键字)"""
contents = scan_model_path_contents(model_path)
fp8_supported = is_fp8_supported_gpu()
# Filter from .pth files
pth_choices = [f for f in contents["pth_files"] if any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from .safetensors files
safetensors_choices = [f for f in contents["files"] if f.endswith(".safetensors") and any(kw in f.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in f.lower())]
# Filter from directories that contain safetensors files
safetensors_dir_choices = [d for d in contents["safetensors_dirs"] if any(kw in d.lower() for kw in ["vae", "tae"]) and (fp8_supported or "fp8" not in d.lower())]
choices = pth_choices + safetensors_choices + safetensors_dir_choices
return choices if choices else [""]
def detect_quant_scheme(model_name):
"""根据模型名字自动检测量化精度
- 如果模型名字包含 "int8" → "int8"
- 如果模型名字包含 "fp8" 且设备支持 → "fp8"
- 否则返回 None(表示不使用量化)
"""
if not model_name:
return None
name_lower = model_name.lower()
if "int8" in name_lower:
return "int8"
elif "fp8" in name_lower:
if is_fp8_supported_gpu():
return "fp8"
else:
paths_to_check.append(os.path.join(model_path, subdir, filename))
print(paths_to_check)
for path in paths_to_check:
if os.path.exists(path):
logger.info(f"Found PyTorch model checkpoint: {path}")
return path
raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
# The GPU does not support fp8; return None (use the default precision)
return None
return None
def update_model_path_options(model_path, model_type="wan2.1"):
"""当 model_path 或 model_type 改变时,更新所有模型路径选择器"""
dit_choices = get_dit_choices(model_path, model_type)
high_noise_choices = get_high_noise_choices(model_path)
low_noise_choices = get_low_noise_choices(model_path)
t5_choices = get_t5_choices(model_path)
clip_choices = get_clip_choices(model_path)
vae_choices = get_vae_choices(model_path)
return (
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
gr.update(choices=high_noise_choices, value=high_noise_choices[0] if high_noise_choices else ""),
gr.update(choices=low_noise_choices, value=low_noise_choices[0] if low_noise_choices else ""),
gr.update(choices=t5_choices, value=t5_choices[0] if t5_choices else ""),
gr.update(choices=clip_choices, value=clip_choices[0] if clip_choices else ""),
gr.update(choices=vae_choices, value=vae_choices[0] if vae_choices else ""),
)
def generate_random_seed():
......@@ -109,12 +279,18 @@ def get_available_attn_ops():
else:
available_ops.append(("flash_attn3", False))
q8f_installed = is_module_installed("sageattention")
if q8f_installed:
sage_installed = is_module_installed("sageattention")
if sage_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
sage3_installed = is_module_installed("sageattn3")
if sage3_installed:
available_ops.append(("sage_attn3", True))
else:
available_ops.append(("sage_attn3", False))
torch_installed = is_module_installed("torch")
if torch_installed:
available_ops.append(("torch_sdpa", True))
......@@ -165,7 +341,7 @@ def cleanup_memory():
def generate_unique_filename(output_dir):
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return os.path.join(output_dir, f"{model_cls}_{timestamp}.mp4")
return os.path.join(output_dir, f"{timestamp}.mp4")
def is_fp8_supported_gpu():
......@@ -233,13 +409,25 @@ def get_quantization_options(model_path):
return {"dit_choices": dit_choices, "dit_default": dit_default, "t5_choices": t5_choices, "t5_default": t5_default, "clip_choices": clip_choices, "clip_default": clip_default}
def determine_model_cls(model_type, dit_name, high_noise_name):
"""根据模型类型和文件名确定 model_cls"""
# 确定要检查的文件名
if model_type == "wan2.1":
check_name = dit_name.lower() if dit_name else ""
is_distill = "4step" in check_name
return "wan2.1_distill" if is_distill else "wan2.1"
else:
# wan2.2
check_name = high_noise_name.lower() if high_noise_name else ""
is_distill = "4step" in check_name
return "wan2.2_moe_distill" if is_distill else "wan2.2_moe"
global_runner = None
current_config = None
cur_dit_quant_scheme = None
cur_clip_quant_scheme = None
cur_t5_quant_scheme = None
cur_precision_mode = None
cur_enable_teacache = None
cur_dit_path = None
cur_t5_path = None
cur_clip_path = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
......@@ -249,8 +437,29 @@ for op_name, is_installed in available_quant_ops:
quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()
# Priority order
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
# Sort by priority: installed operators first, uninstalled ones after
attn_op_choices = []
for op_name, is_installed in available_attn_ops:
attn_op_dict = dict(available_attn_ops)
# First add the installed operators (in priority order)
for op_name in attn_priority:
if op_name in attn_op_dict and attn_op_dict[op_name]:
status_text = "✅ 已安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# Then add the uninstalled operators (in priority order)
for op_name in attn_priority:
if op_name in attn_op_dict and not attn_op_dict[op_name]:
status_text = "❌ 未安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
# Add remaining operators not in the priority list (installed first)
other_ops = [(op_name, is_installed) for op_name, is_installed in available_attn_ops if op_name not in attn_priority]
for op_name, is_installed in sorted(other_ops, key=lambda x: not x[1]):  # installed first
status_text = "✅ 已安装" if is_installed else "❌ 未安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
......@@ -260,36 +469,36 @@ def run_inference(
prompt,
negative_prompt,
save_result_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path=None,
):
cleanup_memory()
......@@ -297,8 +506,23 @@ def run_inference(
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path, task
global cur_dit_quant_scheme, cur_clip_quant_scheme, cur_t5_quant_scheme, cur_precision_mode, cur_enable_teacache
global global_runner, current_config, model_path, model_cls
global cur_dit_path, cur_t5_path, cur_clip_path
task = task_type_input
model_cls = determine_model_cls(model_type_input, dit_path_input, high_noise_path_input)
logger.info(f"自动确定 model_cls: {model_cls} (模型类型: {model_type_input})")
if model_type_input == "wan2.1":
dit_quant_detected = detect_quant_scheme(dit_path_input)
else:
dit_quant_detected = detect_quant_scheme(high_noise_path_input)
t5_quant_detected = detect_quant_scheme(t5_path_input)
clip_quant_detected = detect_quant_scheme(clip_path_input)
logger.info(f"自动检测量化精度 - DIT: {dit_quant_detected}, T5: {t5_quant_detected}, CLIP: {clip_quant_detected}")
if model_path_input and model_path_input.strip():
model_path = model_path_input.strip()
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
......@@ -306,159 +530,88 @@ def run_inference(
else:
model_config = {}
if task == "t2v":
if model_size == "1.3b":
# 1.3B
coefficient = [
[
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
[
2.39676752e03,
-1.31110545e03,
2.01331979e02,
-8.29855975e00,
1.37887774e-01,
],
]
else:
# 14B
coefficient = [
[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
]
elif task == "i2v":
if resolution in [
"1280x720",
"720x1280",
"1280x544",
"544x1280",
"1104x832",
"832x1104",
"960x960",
]:
# 720p
coefficient = [
[
8.10705460e03,
2.13393892e03,
-3.72934672e02,
1.66203073e01,
-4.17769401e-02,
],
[-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
]
else:
# 480p
coefficient = [
[
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
],
[
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
],
]
save_result_path = generate_unique_filename(output_dir)
is_dit_quant = dit_quant_scheme != "bf16"
is_t5_quant = t5_quant_scheme != "bf16"
is_dit_quant = dit_quant_detected is not None
is_t5_quant = t5_quant_detected is not None
is_clip_quant = clip_quant_detected is not None
dit_quantized_ckpt = None
dit_original_ckpt = None
high_noise_quantized_ckpt = None
low_noise_quantized_ckpt = None
high_noise_original_ckpt = None
low_noise_original_ckpt = None
if is_dit_quant:
dit_quant_scheme = f"{dit_quant_detected}-{quant_op}"
if "wan2.1" in model_cls:
dit_quantized_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_quantized_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_quantized_ckpt = os.path.join(model_path, low_noise_path_input)
else:
dit_quantized_ckpt = "Default"
if "wan2.1" in model_cls:
dit_original_ckpt = os.path.join(model_path, dit_path_input)
else:
high_noise_original_ckpt = os.path.join(model_path, high_noise_path_input)
low_noise_original_ckpt = os.path.join(model_path, low_noise_path_input)
# Use the T5 path selected in the UI
if is_t5_quant:
t5_model_name = f"models_t5_umt5-xxl-enc-{t5_quant_scheme}.pth"
t5_quantized_ckpt = find_torch_model_path(model_path, t5_model_name, t5_quant_scheme)
t5_quantized_ckpt = os.path.join(model_path, t5_path_input)
t5_quant_scheme = f"{t5_quant_detected}-{quant_op}"
t5_original_ckpt = None
else:
t5_quantized_ckpt = None
t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
t5_original_ckpt = find_torch_model_path(model_path, t5_model_name, "original")
is_clip_quant = clip_quant_scheme != "fp16"
t5_quant_scheme = None
t5_original_ckpt = os.path.join(model_path, t5_path_input)
# Use the CLIP path selected in the UI
if is_clip_quant:
clip_model_name = f"clip-{t5_quant_scheme}.pth"
clip_quantized_ckpt = find_torch_model_path(model_path, clip_model_name, clip_quant_scheme)
clip_quantized_ckpt = os.path.join(model_path, clip_path_input)
clip_quant_scheme = f"{clip_quant_detected}-{quant_op}"
clip_original_ckpt = None
else:
clip_quantized_ckpt = None
clip_model_name = "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
clip_original_ckpt = find_torch_model_path(model_path, clip_model_name, "original")
clip_quant_scheme = None
clip_original_ckpt = os.path.join(model_path, clip_path_input)
if model_type_input == "wan2.1":
current_dit_path = dit_path_input
else:
current_dit_path = f"{high_noise_path_input}|{low_noise_path_input}" if high_noise_path_input and low_noise_path_input else None
current_t5_path = t5_path_input
current_clip_path = clip_path_input
needs_reinit = (
lazy_load
or unload_modules
or global_runner is None
or current_config is None
or cur_dit_quant_scheme is None
or cur_dit_quant_scheme != dit_quant_scheme
or cur_clip_quant_scheme is None
or cur_clip_quant_scheme != clip_quant_scheme
or cur_t5_quant_scheme is None
or cur_t5_quant_scheme != t5_quant_scheme
or cur_precision_mode is None
or cur_precision_mode != precision_mode
or cur_enable_teacache is None
or cur_enable_teacache != enable_teacache
or cur_dit_path is None
or cur_dit_path != current_dit_path
or cur_t5_path is None
or cur_t5_path != current_t5_path
or cur_clip_path is None
or cur_clip_path != current_clip_path
)
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
if cfg_scale == 1:
enable_cfg = False
else:
os.environ["ENABLE_GRAPH_MODE"] = "false"
if precision_mode == "bf16":
os.environ["DTYPE"] = "BF16"
else:
os.environ.pop("DTYPE", None)
enable_cfg = True
if is_dit_quant:
if quant_op == "vllm":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Vllm"
elif quant_op == "sgl":
if dit_quant_scheme == "int8":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl-ActVllm"
else:
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
elif quant_op == "q8f":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
t5_quant_scheme = f"{t5_quant_scheme}-q8f"
clip_quant_scheme = f"{clip_quant_scheme}-q8f"
dit_quantized_ckpt = find_hf_model_path(model_path, dit_quant_scheme)
if os.path.exists(os.path.join(dit_quantized_ckpt, "config.json")):
with open(os.path.join(dit_quantized_ckpt, "config.json"), "r") as f:
quant_model_config = json.load(f)
else:
quant_model_config = {}
else:
mm_type = "Default"
dit_quantized_ckpt = None
quant_model_config = {}
vae_name_lower = vae_path_input.lower() if vae_path_input else ""
use_tae = "tae" in vae_name_lower or "lighttae" in vae_name_lower
use_lightvae = "lightvae" in vae_name_lower
need_scaled = "lighttae" in vae_name_lower
config = {
logger.info(f"VAE 配置 - use_tae: {use_tae}, use_lightvae: {use_lightvae}, need_scaled: {need_scaled} (VAE: {vae_path_input})")
config_graio = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
"target_width": int(resolution.split("x")[0]),
......@@ -466,26 +619,36 @@ def run_inference(
"self_attn_1_type": attention_type,
"cross_attn_1_type": attention_type,
"cross_attn_2_type": attention_type,
"seed": seed,
"enable_cfg": enable_cfg,
"sample_guide_scale": cfg_scale,
"sample_shift": sample_shift,
"cpu_offload": cpu_offload,
"offload_granularity": offload_granularity,
"offload_ratio": offload_ratio,
"t5_offload_granularity": t5_offload_granularity,
"dit_quantized_ckpt": dit_quantized_ckpt,
"mm_config": {
"mm_type": mm_type,
},
"fps": fps,
"feature_caching": "Tea" if enable_teacache else "NoCaching",
"coefficients": coefficient[0] if use_ret_steps else coefficient[1],
"use_ret_steps": use_ret_steps,
"teacache_thresh": teacache_thresh,
"t5_original_ckpt": t5_original_ckpt,
"feature_caching": "NoCaching",
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
"max_area": False,
"vae_stride": (4, 8, 8),
"patch_size": (1, 2, 2),
"lora_path": None,
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
"denoising_step_list": [1000, 750, 500, 250],
"cpu_offload": True if "wan2.2" in model_cls else cpu_offload,
"offload_granularity": "phase" if "wan2.2" in model_cls else offload_granularity,
"t5_cpu_offload": t5_cpu_offload,
"unload_modules": unload_modules,
"clip_cpu_offload": clip_cpu_offload,
"vae_cpu_offload": vae_cpu_offload,
"dit_quantized": is_dit_quant,
"dit_quant_scheme": dit_quant_scheme,
"dit_quantized_ckpt": dit_quantized_ckpt,
"dit_original_ckpt": dit_original_ckpt,
"high_noise_quantized_ckpt": high_noise_quantized_ckpt,
"low_noise_quantized_ckpt": low_noise_quantized_ckpt,
"high_noise_original_ckpt": high_noise_original_ckpt,
"low_noise_original_ckpt": low_noise_original_ckpt,
"t5_original_ckpt": t5_original_ckpt,
"t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quantized_ckpt,
"t5_quant_scheme": t5_quant_scheme,
......@@ -493,29 +656,27 @@ def run_inference(
"clip_quantized": is_clip_quant,
"clip_quantized_ckpt": clip_quantized_ckpt,
"clip_quant_scheme": clip_quant_scheme,
"vae_path": find_torch_model_path(model_path, "Wan2.1_VAE.pth"),
"vae_path": os.path.join(model_path, vae_path_input),
"use_tiling_vae": use_tiling_vae,
"use_tae": use_tae,
"tae_path": (find_torch_model_path(model_path, "taew2_1.pth") if use_tae else None),
"use_lightvae": use_lightvae,
"need_scaled": need_scaled,
"lazy_load": lazy_load,
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
"max_area": False,
"vae_stride": (4, 8, 8),
"patch_size": (1, 2, 2),
"lora_path": None,
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
"rotary_chunk": rotary_chunk,
"rotary_chunk_size": rotary_chunk_size,
"rope_chunk": rope_chunk,
"rope_chunk_size": rope_chunk_size,
"clean_cuda_cache": clean_cuda_cache,
"denoising_step_list": [1000, 750, 500, 250],
"unload_modules": unload_modules,
"seq_parallel": False,
"warm_up_cpu_buffers": False,
"boundary_step_index": 2,
"boundary": 0.900,
"use_image_encoder": False if "wan2.2" in model_cls else True,
"rope_type": "flashinfer" if apply_rope_with_cos_sin_cache_inplace else "torch",
}
args = argparse.Namespace(
model_cls=model_cls,
seed=seed,
task=task,
model_path=model_path,
prompt_enhancer=None,
......@@ -523,11 +684,13 @@ def run_inference(
negative_prompt=negative_prompt,
image_path=image_path,
save_result_path=save_result_path,
return_result_tensor=False,
)
config = get_default_config()
config.update({k: v for k, v in vars(args).items()})
config.update(model_config)
config.update(quant_model_config)
config.update(config_gradio)
logger.info(f"使用模型: {model_path}")
logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
......@@ -543,28 +706,19 @@ def run_inference(
from lightx2v.infer import init_runner # noqa
runner = init_runner(config)
input_info = set_input_info(args)
current_config = config
cur_dit_quant_scheme = dit_quant_scheme
cur_clip_quant_scheme = clip_quant_scheme
cur_t5_quant_scheme = t5_quant_scheme
cur_precision_mode = precision_mode
cur_enable_teacache = enable_teacache
cur_dit_path = current_dit_path
cur_t5_path = current_t5_path
cur_clip_path = current_clip_path
if not lazy_load:
global_runner = runner
else:
runner.config = config
runner.run_pipeline()
del config, args, model_config, quant_model_config
if "dit_quantized_ckpt" in locals():
del dit_quantized_ckpt
if "t5_quant_ckpt" in locals():
del t5_quant_ckpt
if "clip_quant_ckpt" in locals():
del clip_quant_ckpt
runner.run_pipeline(input_info)
cleanup_memory()
return save_result_path
......@@ -575,44 +729,28 @@ def handle_lazy_load_change(lazy_load_enabled):
return gr.update(value=lazy_load_enabled)
def auto_configure(enable_auto_config, resolution):
def auto_configure(resolution):
"""根据机器配置和分辨率自动设置推理选项"""
default_config = {
"torch_compile_val": False,
"lazy_load_val": False,
"rotary_chunk_val": False,
"rotary_chunk_size_val": 100,
"rope_chunk_val": False,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": False,
"cpu_offload_val": False,
"offload_granularity_val": "block",
"offload_ratio_val": 1,
"t5_cpu_offload_val": False,
"clip_cpu_offload_val": False,
"vae_cpu_offload_val": False,
"unload_modules_val": False,
"t5_offload_granularity_val": "model",
"attention_type_val": attn_op_choices[0][1],
"quant_op_val": quant_op_choices[0][1],
"dit_quant_scheme_val": "bf16",
"t5_quant_scheme_val": "bf16",
"clip_quant_scheme_val": "fp16",
"precision_mode_val": "fp32",
"use_tae_val": False,
"use_tiling_vae_val": False,
"enable_teacache_val": False,
"teacache_thresh_val": 0.26,
"use_ret_steps_val": False,
}
if not enable_auto_config:
return tuple(gr.update(value=default_config[key]) for key in default_config)
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
if is_fp8_supported_gpu():
quant_type = "fp8"
else:
quant_type = "int8"
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
attn_priority = ["sage_attn3", "sage_attn2", "flash_attn3", "flash_attn2", "torch_sdpa"]
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
......@@ -647,25 +785,15 @@ def auto_configure(enable_auto_config, resolution):
else:
res = "480p"
if model_size == "14b":
is_14b = True
else:
is_14b = False
if res == "720p" and is_14b:
if res == "720p":
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "offload_ratio_val": 1, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
},
),
......@@ -673,151 +801,64 @@ def auto_configure(enable_auto_config, resolution):
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
},
),
(
12,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"use_tae_val": True,
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
},
),
(
8,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"rope_chunk_val": True,
"rope_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tae_val": True,
},
),
]
elif is_14b:
else:
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.2, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
(24, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
(40, {"cpu_offload_val": False, "t5_cpu_offload_val": True, "vae_cpu_offload_val": True, "clip_cpu_offload_val": True}),
(32, {"cpu_offload_val": True, "t5_cpu_offload_val": False, "vae_cpu_offload_val": False, "clip_cpu_offload_val": False}),
(
16,
24,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "block",
},
),
(
8,
(
16,
{
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
"use_tae_val": True,
}
if res == "540p"
else {
"cpu_offload_val": True,
"t5_cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tae_val": True,
}
),
},
),
]
else:
gpu_rules = [
(24, {}),
(
8,
{
"t5_cpu_offload_val": True,
"t5_offload_granularity_val": "block",
"t5_quant_scheme_val": quant_type,
"cpu_offload_val": True,
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
},
),
]
if is_14b:
cpu_rules = [
(128, {}),
(64, {"dit_quant_scheme_val": quant_type}),
(32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
(
16,
{
"dit_quant_scheme_val": quant_type,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
},
),
]
else:
cpu_rules = [
(64, {}),
(32, {"unload_modules_val": True}),
(
16,
{
"t5_quant_scheme_val": quant_type,
"lazy_load_val": True,
"unload_modules_val": True,
"use_tae_val": True,
},
),
]
......@@ -832,62 +873,238 @@ def auto_configure(enable_auto_config, resolution):
default_config.update(updates)
break
return tuple(gr.update(value=default_config[key]) for key in default_config)
return (
gr.update(value=default_config["lazy_load_val"]),
gr.update(value=default_config["rope_chunk_val"]),
gr.update(value=default_config["rope_chunk_size_val"]),
gr.update(value=default_config["clean_cuda_cache_val"]),
gr.update(value=default_config["cpu_offload_val"]),
gr.update(value=default_config["offload_granularity_val"]),
gr.update(value=default_config["t5_cpu_offload_val"]),
gr.update(value=default_config["clip_cpu_offload_val"]),
gr.update(value=default_config["vae_cpu_offload_val"]),
gr.update(value=default_config["unload_modules_val"]),
gr.update(value=default_config["attention_type_val"]),
gr.update(value=default_config["quant_op_val"]),
gr.update(value=default_config["use_tiling_vae_val"]),
)
def main():
with gr.Blocks(
title="Lightx2v (轻量级视频推理和生成引擎)",
css="""
.main-content { max-width: 1400px; margin: auto; }
.output-video { max-height: 650px; }
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
.advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
.tab-button { font-size: 16px; padding: 10px 20px; }
.auto-config-title {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4);
background-clip: text;
-webkit-background-clip: text;
color: transparent;
text-align: center;
margin: 0 !important;
padding: 8px;
border: 2px solid #4ecdc4;
border-radius: 8px;
background-color: #f0f8ff;
/* Model configuration section styles */
.model-config {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
}
.auto-config-checkbox {
border: 2px solid #ff6b6b !important;
border-radius: 8px !important;
padding: 10px !important;
background: linear-gradient(135deg, #fff5f5, #f0fff0) !important;
box-shadow: 0 2px 8px rgba(255, 107, 107, 0.2) !important;
/* Input parameters section styles */
.input-params {
margin-bottom: 20px !important;
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 15px;
background: linear-gradient(135deg, #fff5f5 0%, #ffeef0 100%);
}
.auto-config-checkbox label {
font-size: 16px !important;
/* Output video section styles */
.output-video {
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
background: linear-gradient(135deg, #e0f2fe 0%, #bae6fd 100%);
min-height: 400px;
}
/* Generate button styles */
.generate-btn {
width: 100%;
margin-top: 20px;
padding: 15px 30px !important;
font-size: 18px !important;
font-weight: bold !important;
color: #2c3e50 !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
border-radius: 10px !important;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
transition: all 0.3s ease !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
}
/* Accordion header styles */
.model-config .gr-accordion-header,
.input-params .gr-accordion-header,
.output-video .gr-accordion-header {
font-size: 20px !important;
font-weight: bold !important;
padding: 15px !important;
}
/* Spacing tweaks */
.gr-row {
margin-bottom: 15px;
}
/* Video player styles */
.output-video video {
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
""",
) as demo:
gr.Markdown(f"# 🎬 {model_cls} 视频生成器")
gr.Markdown(f"### 使用模型: {model_path}")
gr.Markdown(f"# 🎬 LightX2V 视频生成器")
with gr.Tabs() as tabs:
with gr.Tab("基本设置", id=1):
# Main layout: left/right columns
with gr.Row():
with gr.Column(scale=4):
with gr.Group():
gr.Markdown("## 📥 输入参数")
# Left side: configuration and input area
with gr.Column(scale=5):
# Model configuration area
with gr.Accordion("🗂️ 模型配置", open=True, elem_classes=["model-config"]):
# FP8 support notice
if not is_fp8_supported_gpu():
gr.Markdown("⚠️ **您的设备不支持fp8推理**,已自动隐藏包含fp8的模型选项。")
# Hidden state component
model_path_input = gr.Textbox(value=model_path, visible=False)
# Model type + task type
with gr.Row():
model_type_input = gr.Radio(
label="模型类型",
choices=["wan2.1", "wan2.2"],
value="wan2.1",
info="wan2.2 需要分别指定高噪模型和低噪模型",
)
task_type_input = gr.Radio(
label="任务类型",
choices=["i2v", "t2v"],
value="i2v",
info="i2v: 图生视频, t2v: 文生视频",
)
# wan2.1: Diffusion model (its own row)
with gr.Row() as wan21_row:
dit_path_input = gr.Dropdown(
label="🎨 Diffusion模型",
choices=get_dit_choices(model_path, "wan2.1"),
value=get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else "",
allow_custom_value=True,
visible=True,
)
# wan2.2 only: high-noise + low-noise models (hidden by default)
with gr.Row(visible=False) as wan22_row:
high_noise_path_input = gr.Dropdown(
label="🔊 高噪模型",
choices=get_high_noise_choices(model_path),
value=get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else "",
allow_custom_value=True,
)
low_noise_path_input = gr.Dropdown(
label="🔇 低噪模型",
choices=get_low_noise_choices(model_path),
value=get_low_noise_choices(model_path)[0] if get_low_noise_choices(model_path) else "",
allow_custom_value=True,
)
if task == "i2v":
# Text encoder (its own row)
with gr.Row():
t5_path_input = gr.Dropdown(
label="📝 文本编码器",
choices=get_t5_choices(model_path),
value=get_t5_choices(model_path)[0] if get_t5_choices(model_path) else "",
allow_custom_value=True,
)
# Image encoder + VAE decoder
with gr.Row():
clip_path_input = gr.Dropdown(
label="🖼️ 图像编码器",
choices=get_clip_choices(model_path),
value=get_clip_choices(model_path)[0] if get_clip_choices(model_path) else "",
allow_custom_value=True,
)
vae_path_input = gr.Dropdown(
label="🎞️ VAE解码器",
choices=get_vae_choices(model_path),
value=get_vae_choices(model_path)[0] if get_vae_choices(model_path) else "",
allow_custom_value=True,
)
# Attention operator and quantized matmul operator
with gr.Row():
attention_type = gr.Dropdown(
label="⚡ 注意力算子",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1] if attn_op_choices else "",
info="使用适当的注意力算子加速推理",
)
quant_op = gr.Dropdown(
label="量化矩阵乘法算子",
choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1],
info="选择量化矩阵乘法算子以加速推理",
interactive=True,
)
# Determine whether the model is a distilled version
def is_distill_model(model_type, dit_path, high_noise_path):
"""根据模型类型和路径判断是否是 distill 版本"""
if model_type == "wan2.1":
check_name = dit_path.lower() if dit_path else ""
else:
check_name = high_noise_path.lower() if high_noise_path else ""
return "4step" in check_name
# Model type change event
def on_model_type_change(model_type, model_path_val):
if model_type == "wan2.2":
return gr.update(visible=False), gr.update(visible=True), gr.update()
else:
# Update the Diffusion model choices for wan2.1
dit_choices = get_dit_choices(model_path_val, "wan2.1")
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(choices=dit_choices, value=dit_choices[0] if dit_choices else ""),
)
model_type_input.change(
fn=on_model_type_change,
inputs=[model_type_input, model_path_input],
outputs=[wan21_row, wan22_row, dit_path_input],
)
# Input parameters area
with gr.Accordion("📥 输入参数", open=True, elem_classes=["input-params"]):
# Image input (shown for i2v)
with gr.Row(visible=True) as image_input_row:
image_path = gr.Image(
label="输入图像",
type="filepath",
height=300,
interactive=True,
visible=True,
)
# Task type change event
def on_task_type_change(task_type):
return gr.update(visible=(task_type == "i2v"))
task_type_input.change(
fn=on_task_type_change,
inputs=[task_type_input],
outputs=[image_input_row],
)
with gr.Row():
......@@ -931,15 +1148,6 @@ def main():
label="最大分辨率",
)
with gr.Column():
with gr.Group():
gr.Markdown("### 🚀 **智能配置推荐**", elem_classes=["auto-config-title"])
enable_auto_config = gr.Checkbox(
label="🎯 **自动配置推理选项**",
value=False,
info="💡 **智能优化GPU设置以匹配当前分辨率。修改分辨率后,请重新勾选此选项,否则可能导致性能下降或运行失败。**",
elem_classes=["auto-config-checkbox"],
)
with gr.Column(scale=9):
seed = gr.Slider(
label="随机种子",
......@@ -948,25 +1156,21 @@ def main():
step=1,
value=generate_random_seed(),
)
with gr.Column(scale=1):
randomize_btn = gr.Button("🎲 随机化", variant="secondary")
randomize_btn.click(fn=generate_random_seed, inputs=None, outputs=seed)
with gr.Column():
# Set the default number of inference steps based on the model class
if model_cls == "wan2.1_distill":
default_dit = get_dit_choices(model_path, "wan2.1")[0] if get_dit_choices(model_path, "wan2.1") else ""
default_high_noise = get_high_noise_choices(model_path)[0] if get_high_noise_choices(model_path) else ""
default_is_distill = is_distill_model("wan2.1", default_dit, default_high_noise)
if default_is_distill:
infer_steps = gr.Slider(
label="推理步数",
minimum=4,
maximum=4,
minimum=1,
maximum=100,
step=1,
value=4,
interactive=False,
info="推理步数固定为4,以获得最佳性能(对于蒸馏模型)。",
info="蒸馏模型推理步数默认为4。",
)
elif model_cls == "wan2.1":
if task == "i2v":
else:
infer_steps = gr.Slider(
label="推理步数",
minimum=1,
......@@ -975,31 +1179,45 @@ def main():
value=40,
info="视频生成的推理步数。增加步数可能提高质量但降低速度。",
)
elif task == "t2v":
infer_steps = gr.Slider(
label="推理步数",
minimum=1,
maximum=100,
step=1,
value=50,
info="视频生成的推理步数。增加步数可能提高质量但降低速度。",
# Dynamically update the inference steps when the model path changes
def update_infer_steps(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
return gr.update(minimum=1, maximum=100, value=4, interactive=True)
else:
return gr.update(minimum=1, maximum=100, value=40, interactive=True)
# Listen for model path changes
dit_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_infer_steps(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[infer_steps],
)
# Set the default CFG based on the model class
default_enable_cfg = False if model_cls == "wan2.1_distill" else True
# CFG scale: defaults to 1 for distilled models, otherwise 5
default_cfg_scale = 1 if default_is_distill else 5
# enable_cfg is not exposed in the UI; it is set automatically from cfg_scale
# If cfg_scale == 1 then enable_cfg = False, otherwise enable_cfg = True
default_enable_cfg = False if default_cfg_scale == 1 else True
enable_cfg = gr.Checkbox(
label="启用无分类器引导",
value=default_enable_cfg,
info="启用无分类器引导以控制提示词强度",
)
cfg_scale = gr.Slider(
label="CFG缩放因子",
minimum=1,
maximum=10,
step=1,
value=5,
info="控制提示词的影响强度。值越高,提示词的影响越大。",
visible=False,  # Hidden; not exposed in the UI
)
with gr.Row():
sample_shift = gr.Slider(
label="分布偏移",
value=5,
......@@ -1008,7 +1226,56 @@ def main():
step=1,
info="控制样本分布偏移的程度。值越大表示偏移越明显。",
)
cfg_scale = gr.Slider(
label="CFG缩放因子",
minimum=1,
maximum=10,
step=1,
value=default_cfg_scale,
info="控制提示词的影响强度。值越高,提示词的影响越大。当值为1时,自动禁用CFG。",
)
# Update enable_cfg from cfg_scale
def update_enable_cfg(cfg_scale_val):
"""根据 cfg_scale 的值自动设置 enable_cfg"""
if cfg_scale_val == 1:
return gr.update(value=False)
else:
return gr.update(value=True)
# Dynamically update the CFG scale and enable_cfg when the model path changes
def update_cfg_scale(model_type, dit_path, high_noise_path):
is_distill = is_distill_model(model_type, dit_path, high_noise_path)
if is_distill:
new_cfg_scale = 1
else:
new_cfg_scale = 5
new_enable_cfg = False if new_cfg_scale == 1 else True
return gr.update(value=new_cfg_scale), gr.update(value=new_enable_cfg)
dit_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
high_noise_path_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
model_type_input.change(
fn=lambda mt, dp, hnp: update_cfg_scale(mt, dp, hnp),
inputs=[model_type_input, dit_path_input, high_noise_path_input],
outputs=[cfg_scale, enable_cfg],
)
cfg_scale.change(
fn=update_enable_cfg,
inputs=[cfg_scale],
outputs=[enable_cfg],
)
with gr.Row():
fps = gr.Slider(
label="每秒帧数(FPS)",
minimum=8,
......@@ -1030,282 +1297,109 @@ def main():
label="输出视频路径",
value=generate_unique_filename(output_dir),
info="必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
visible=False,  # Hide the output path; it is generated automatically
)
with gr.Column(scale=6):
gr.Markdown("## 📤 生成的视频")
with gr.Column(scale=4):
with gr.Accordion("📤 生成的视频", open=True, elem_classes=["output-video"]):
output_video = gr.Video(
label="结果",
height=624,
width=360,
label="",
height=600,
autoplay=True,
elem_classes=["output-video"],
)
infer_btn = gr.Button("生成视频", variant="primary", size="lg")
with gr.Tab("⚙️ 高级选项", id=2):
with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### GPU内存优化")
with gr.Row():
rotary_chunk = gr.Checkbox(
label="分块旋转位置编码",
value=False,
info="启用时,将旋转位置编码分块处理以节省GPU内存。",
)
rotary_chunk_size = gr.Slider(
label="旋转编码块大小",
value=100,
minimum=100,
maximum=10000,
step=100,
info="控制应用旋转编码的块大小。较大的值可能提高性能但增加内存使用。仅在'rotary_chunk'勾选时有效。",
)
unload_modules = gr.Checkbox(
label="卸载模块",
value=False,
info="推理后卸载模块(T5、CLIP、DIT等)以减少GPU/CPU内存使用",
)
clean_cuda_cache = gr.Checkbox(
label="清理CUDA内存缓存",
value=False,
info="启用时,及时释放GPU内存但会减慢推理速度。",
)
gr.Markdown("### 异步卸载")
with gr.Row():
cpu_offload = gr.Checkbox(
label="CPU卸载",
value=False,
info="将模型计算的一部分从GPU卸载到CPU以减少GPU内存使用",
show_label=False,
)
lazy_load = gr.Checkbox(
label="启用延迟加载",
value=False,
info="在推理过程中延迟加载模型组件。需要CPU加载和DIT量化。",
)
offload_granularity = gr.Dropdown(
label="Dit卸载粒度",
choices=["block", "phase"],
value="phase",
info="设置Dit模型卸载粒度:块或计算阶段",
)
offload_ratio = gr.Slider(
label="Dit模型卸载比例",
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
info="控制将多少Dit模型卸载到CPU",
)
t5_cpu_offload = gr.Checkbox(
label="T5 CPU卸载",
value=False,
info="将T5编码器模型卸载到CPU以减少GPU内存使用",
)
t5_offload_granularity = gr.Dropdown(
label="T5编码器卸载粒度",
choices=["model", "block"],
value="model",
info="控制将T5编码器模型卸载到CPU时的粒度",
)
gr.Markdown("### 低精度量化")
with gr.Row():
torch_compile = gr.Checkbox(
label="Torch编译",
value=False,
info="使用torch.compile加速推理过程",
)
attention_type = gr.Dropdown(
label="注意力算子",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1],
info="使用适当的注意力算子加速推理",
)
quant_op = gr.Dropdown(
label="量化矩阵乘法算子",
choices=[op[1] for op in quant_op_choices],
value=quant_op_choices[0][1],
info="选择量化矩阵乘法算子以加速推理",
interactive=True,
)
# Get the dynamic quantization options
quant_options = get_quantization_options(model_path)
dit_quant_scheme = gr.Dropdown(
label="Dit",
choices=quant_options["dit_choices"],
value=quant_options["dit_default"],
info="Dit模型的量化精度",
)
t5_quant_scheme = gr.Dropdown(
label="T5编码器",
choices=quant_options["t5_choices"],
value=quant_options["t5_default"],
info="T5编码器模型的量化精度",
)
clip_quant_scheme = gr.Dropdown(
label="Clip编码器",
choices=quant_options["clip_choices"],
value=quant_options["clip_default"],
info="Clip编码器的量化精度",
)
precision_mode = gr.Dropdown(
label="敏感层精度模式",
choices=["fp32", "bf16"],
value="fp32",
info="选择用于关键模型组件(如归一化和嵌入层)的数值精度。FP32提供更高精度,而BF16在兼容硬件上提高性能。",
)
gr.Markdown("### 变分自编码器(VAE)")
with gr.Row():
use_tae = gr.Checkbox(
label="使用轻量级VAE",
value=False,
info="使用轻量级VAE模型加速解码过程",
)
use_tiling_vae = gr.Checkbox(
label="VAE分块推理",
value=False,
info="使用VAE分块推理以减少GPU内存使用",
)
gr.Markdown("### 特征缓存")
with gr.Row():
enable_teacache = gr.Checkbox(
label="Tea Cache",
value=False,
info="在推理过程中缓存特征以减少推理步数",
)
teacache_thresh = gr.Slider(
label="Tea Cache阈值",
value=0.26,
minimum=0,
maximum=1,
info="较高的加速可能导致质量下降 —— 设置为0.1提供约2.0倍加速,设置为0.2提供约3.0倍加速",
)
use_ret_steps = gr.Checkbox(
label="仅缓存关键步骤",
value=False,
info="勾选时,仅在调度器返回结果的关键步骤写入缓存;未勾选时,在所有步骤写入缓存以确保最高质量",
)
enable_auto_config.change(
infer_btn = gr.Button("🎬 生成视频", variant="primary", size="lg", elem_classes=["generate-btn"])
rope_chunk = gr.Checkbox(label="分块旋转位置编码", value=False, visible=False)
rope_chunk_size = gr.Slider(label="旋转编码块大小", value=100, minimum=100, maximum=10000, step=100, visible=False)
unload_modules = gr.Checkbox(label="卸载模块", value=False, visible=False)
clean_cuda_cache = gr.Checkbox(label="清理CUDA内存缓存", value=False, visible=False)
cpu_offload = gr.Checkbox(label="CPU卸载", value=False, visible=False)
lazy_load = gr.Checkbox(label="启用延迟加载", value=False, visible=False)
offload_granularity = gr.Dropdown(label="Dit卸载粒度", choices=["block", "phase"], value="phase", visible=False)
t5_cpu_offload = gr.Checkbox(label="T5 CPU卸载", value=False, visible=False)
clip_cpu_offload = gr.Checkbox(label="CLIP CPU卸载", value=False, visible=False)
vae_cpu_offload = gr.Checkbox(label="VAE CPU卸载", value=False, visible=False)
use_tiling_vae = gr.Checkbox(label="VAE分块推理", value=False, visible=False)
resolution.change(
fn=auto_configure,
inputs=[enable_auto_config, resolution],
inputs=[resolution],
outputs=[
torch_compile,
lazy_load,
rotary_chunk,
rotary_chunk_size,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
precision_mode,
use_tae,
use_tiling_vae,
enable_teacache,
teacache_thresh,
use_ret_steps,
],
)
lazy_load.change(
fn=handle_lazy_load_change,
inputs=[lazy_load],
outputs=[unload_modules],
)
if task == "i2v":
infer_btn.click(
fn=run_inference,
inputs=[
prompt,
negative_prompt,
save_result_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tae,
use_tiling_vae,
demo.load(
fn=lambda res: auto_configure(res),
inputs=[resolution],
outputs=[
lazy_load,
precision_mode,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
image_path,
use_tiling_vae,
],
outputs=output_video,
)
else:
infer_btn.click(
fn=run_inference,
inputs=[
prompt,
negative_prompt,
save_result_path,
torch_compile,
infer_steps,
num_frames,
resolution,
seed,
sample_shift,
enable_teacache,
teacache_thresh,
use_ret_steps,
enable_cfg,
cfg_scale,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
fps,
use_tae,
use_tiling_vae,
lazy_load,
precision_mode,
cpu_offload,
offload_granularity,
offload_ratio,
t5_cpu_offload,
clip_cpu_offload,
vae_cpu_offload,
unload_modules,
t5_offload_granularity,
attention_type,
quant_op,
rotary_chunk,
rotary_chunk_size,
rope_chunk,
rope_chunk_size,
clean_cuda_cache,
model_path_input,
model_type_input,
task_type_input,
dit_path_input,
high_noise_path_input,
low_noise_path_input,
t5_path_input,
clip_path_input,
vae_path_input,
image_path,
],
outputs=output_video,
)
......@@ -1316,25 +1410,14 @@ def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="轻量级视频生成")
parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径")
parser.add_argument(
"--model_cls",
type=str,
choices=["wan2.1", "wan2.1_distill"],
default="wan2.1",
help="要使用的模型类别 (wan2.1: 标准模型, wan2.1_distill: 蒸馏模型,推理更快)",
)
parser.add_argument("--model_size", type=str, required=True, choices=["14b", "1.3b"], help="模型大小:14b 或 1.3b")
parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="指定任务类型。'i2v'用于图像到视频转换,'t2v'用于文本到视频生成。")
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
parser.add_argument("--output_dir", type=str, default="./outputs", help="输出视频保存目录")
args = parser.parse_args()
global model_path, model_cls, model_size, output_dir
global model_path, model_cls, output_dir
model_path = args.model_path
model_cls = args.model_cls
model_size = args.model_size
task = args.task
model_cls = "wan2.1"
output_dir = args.output_dir
main()
......@@ -14,27 +14,15 @@
# Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/data/video_gen/LightX2V
lightx2v_path=/path/to/LightX2V
# Model path configuration
# Image-to-video model path (for i2v tasks)
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
i2v_model_path=/path/to/Wan2.1-I2V-14B-480P-Lightx2v
# Text-to-video model path (for t2v tasks)
# Example: /path/to/Wan2.1-T2V-1.3B
t2v_model_path=/path/to/Wan2.1-T2V-1.3B
# Model size configuration
# Default model size (14b, 1.3b)
model_size="14b"
# Model class configuration
# Default model class (wan2.1, wan2.1_distill)
model_cls="wan2.1"
model_path=/path/to/models
# Server configuration
server_name="0.0.0.0"
server_port=8032
server_port=8033
# Output directory configuration
output_dir="./outputs"
......@@ -50,18 +38,12 @@ export PROFILING_DEBUG_LEVEL=2
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# ==================== Parameter Parsing ====================
# Default task type
task="i2v"
# Default interface language
lang="zh"
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--task)
task="$2"
shift 2
;;
--lang)
lang="$2"
shift 2
......@@ -75,55 +57,32 @@ while [[ $# -gt 0 ]]; do
export CUDA_VISIBLE_DEVICES=$gpu_id
shift 2
;;
--model_size)
model_size="$2"
shift 2
;;
--model_cls)
model_cls="$2"
shift 2
;;
--output_dir)
output_dir="$2"
shift 2
;;
--model_path)
model_path="$2"
shift 2
;;
--help)
echo "🎬 Lightx2v Gradio Demo Startup Script"
echo "=========================================="
echo "Usage: $0 [options]"
echo ""
echo "📋 Available options:"
echo " --task i2v|t2v Task type (default: i2v)"
echo " i2v: Image-to-video generation"
echo " t2v: Text-to-video generation"
echo " --lang zh|en Interface language (default: zh)"
echo " zh: Chinese interface"
echo " en: English interface"
echo " --port PORT Server port (default: 8032)"
echo " --gpu GPU_ID GPU device ID (default: 0)"
echo " --model_size MODEL_SIZE"
echo " Model size (default: 14b)"
echo " 14b: 14 billion parameters model"
echo " 1.3b: 1.3 billion parameters model"
echo " --model_cls MODEL_CLASS"
echo " Model class (default: wan2.1)"
echo " wan2.1: Standard model variant"
echo " wan2.1_distill: Distilled model variant for faster inference"
echo " --output_dir OUTPUT_DIR"
echo " Output video save directory (default: ./saved_videos)"
echo " --model_path PATH Model path (default: configured in script)"
echo " --output_dir DIR Output video save directory (default: ./outputs)"
echo " --help Show this help message"
echo ""
echo "🚀 Usage examples:"
echo " $0 # Default startup for image-to-video mode"
echo " $0 --task i2v --lang zh --port 8032 # Start with specified parameters"
echo " $0 --task t2v --lang en --port 7860 # Text-to-video with English interface"
echo " $0 --task i2v --gpu 1 --port 8032 # Use GPU 1"
echo " $0 --task t2v --model_size 1.3b # Use 1.3B model"
echo " $0 --task i2v --model_size 14b # Use 14B model"
echo " $0 --task i2v --model_cls wan2.1_distill # Use distilled model"
echo " $0 --task i2v --output_dir ./custom_output # Use custom output directory"
echo ""
echo "📝 Notes:"
echo " - Task type (i2v/t2v) and model type are selected in the web UI"
echo " - Model class is auto-detected based on selected diffusion model"
echo " - Edit script to configure model paths before first use"
echo " - Ensure required Python dependencies are installed"
echo " - Recommended to use GPU with 8GB+ VRAM"
......@@ -139,37 +98,11 @@ while [[ $# -gt 0 ]]; do
done
# ==================== Parameter Validation ====================
if [[ "$task" != "i2v" && "$task" != "t2v" ]]; then
echo "Error: Task type must be 'i2v' or 't2v'"
exit 1
fi
if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
echo "Error: Language must be 'zh' or 'en'"
exit 1
fi
# Validate model size
if [[ "$model_size" != "14b" && "$model_size" != "1.3b" ]]; then
echo "Error: Model size must be '14b' or '1.3b'"
exit 1
fi
# Validate model class
if [[ "$model_cls" != "wan2.1" && "$model_cls" != "wan2.1_distill" ]]; then
echo "Error: Model class must be 'wan2.1' or 'wan2.1_distill'"
exit 1
fi
# Select model path based on task type
if [[ "$task" == "i2v" ]]; then
model_path=$i2v_model_path
echo "🎬 Starting Image-to-Video mode"
else
model_path=$t2v_model_path
echo "🎬 Starting Text-to-Video mode"
fi
# Check if model path exists
if [[ ! -d "$model_path" ]]; then
echo "❌ Error: Model path does not exist"
......@@ -208,13 +141,11 @@ echo "🚀 Lightx2v Gradio Demo Starting..."
echo "=========================================="
echo "📁 Project path: $lightx2v_path"
echo "🤖 Model path: $model_path"
echo "🎯 Task type: $task"
echo "🤖 Model size: $model_size"
echo "🤖 Model class: $model_cls"
echo "🌏 Interface language: $lang"
echo "🖥️ GPU device: $gpu_id"
echo "🌐 Server address: $server_name:$server_port"
echo "📁 Output directory: $output_dir"
echo "📝 Note: Task type and model class are selected in web UI"
echo "=========================================="
# Display system resource information
......@@ -239,11 +170,8 @@ echo "=========================================="
# Start Python demo
python $demo_file \
--model_path "$model_path" \
--model_cls "$model_cls" \
--task "$task" \
--server_name "$server_name" \
--server_port "$server_port" \
--model_size "$model_size" \
--output_dir "$output_dir"
# Display final system resource usage
......
......@@ -16,21 +16,9 @@ REM Example: D:\LightX2V
set lightx2v_path=/path/to/LightX2V
REM Model path configuration
REM Image-to-video model path (for i2v tasks)
REM Example: D:\models\Wan2.1-I2V-14B-480P-Lightx2v
set i2v_model_path=/path/to/Wan2.1-I2V-14B-480P-Lightx2v
REM Text-to-video model path (for t2v tasks)
REM Example: D:\models\Wan2.1-T2V-1.3B
set t2v_model_path=/path/to/Wan2.1-T2V-1.3B
REM Model size configuration
REM Default model size (14b, 1.3b)
set model_size=14b
REM Model class configuration
REM Default model class (wan2.1, wan2.1_distill)
set model_cls=wan2.1
REM Model root directory path
REM Example: D:\models\LightX2V
set model_path=/path/to/LightX2V
REM Server configuration
set server_name=127.0.0.1
......@@ -49,20 +37,12 @@ set PROFILING_DEBUG_LEVEL=2
set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
REM ==================== Parameter Parsing ====================
REM Default task type
set task=i2v
REM Default interface language
set lang=zh
REM Parse command line arguments
:parse_args
if "%1"=="" goto :end_parse
if "%1"=="--task" (
set task=%2
shift
shift
goto :parse_args
)
if "%1"=="--lang" (
set lang=%2
shift
......@@ -82,18 +62,6 @@ if "%1"=="--gpu" (
shift
goto :parse_args
)
if "%1"=="--model_size" (
set model_size=%2
shift
shift
goto :parse_args
)
if "%1"=="--model_cls" (
set model_cls=%2
shift
shift
goto :parse_args
)
if "%1"=="--output_dir" (
set output_dir=%2
shift
......@@ -106,38 +74,24 @@ if "%1"=="--help" (
echo Usage: %0 [options]
echo.
echo 📋 Available options:
echo --task i2v^|t2v Task type (default: i2v)
echo i2v: Image-to-video generation
echo t2v: Text-to-video generation
echo --lang zh^|en Interface language (default: zh)
echo zh: Chinese interface
echo en: English interface
echo --port PORT Server port (default: 8032)
echo --gpu GPU_ID GPU device ID (default: 0)
echo --model_size MODEL_SIZE
echo Model size (default: 14b)
echo 14b: 14B parameter model
echo 1.3b: 1.3B parameter model
echo --model_cls MODEL_CLASS
echo Model class (default: wan2.1)
echo wan2.1: Standard model variant
echo wan2.1_distill: Distilled model variant for faster inference
echo --output_dir OUTPUT_DIR
echo Output video save directory (default: ./saved_videos)
echo Output video save directory (default: ./outputs)
echo --help Show this help message
echo.
echo 🚀 Usage examples:
echo %0 # Default startup for image-to-video mode
echo %0 --task i2v --lang zh --port 8032 # Start with specified parameters
echo %0 --task t2v --lang en --port 7860 # Text-to-video with English interface
echo %0 --task i2v --gpu 1 --port 8032 # Use GPU 1
echo %0 --task t2v --model_size 1.3b # Use 1.3B model
echo %0 --task i2v --model_size 14b # Use 14B model
echo %0 --task i2v --model_cls wan2.1_distill # Use distilled model
echo %0 --task i2v --output_dir ./custom_output # Use custom output directory
echo %0 # Default startup
echo %0 --lang zh --port 8032 # Start with specified parameters
echo %0 --lang en --port 7860 # English interface
echo %0 --gpu 1 --port 8032 # Use GPU 1
echo %0 --output_dir ./custom_output # Use custom output directory
echo.
echo 📝 Notes:
echo - Edit script to configure model paths before first use
echo - Edit script to configure model path before first use
echo - Ensure required Python dependencies are installed
echo - Recommended to use GPU with 8GB+ VRAM
echo - 🚨 Strongly recommend storing models on SSD for better performance
......@@ -152,13 +106,6 @@ exit /b 1
:end_parse
REM ==================== Parameter Validation ====================
if "%task%"=="i2v" goto :valid_task
if "%task%"=="t2v" goto :valid_task
echo Error: Task type must be 'i2v' or 't2v'
pause
exit /b 1
:valid_task
if "%lang%"=="zh" goto :valid_lang
if "%lang%"=="en" goto :valid_lang
echo Error: Language must be 'zh' or 'en'
......@@ -166,29 +113,6 @@ pause
exit /b 1
:valid_lang
if "%model_size%"=="14b" goto :valid_size
if "%model_size%"=="1.3b" goto :valid_size
echo Error: Model size must be '14b' or '1.3b'
pause
exit /b 1
:valid_size
if "%model_cls%"=="wan2.1" goto :valid_cls
if "%model_cls%"=="wan2.1_distill" goto :valid_cls
echo Error: Model class must be 'wan2.1' or 'wan2.1_distill'
pause
exit /b 1
:valid_cls
REM Select model path based on task type
if "%task%"=="i2v" (
set model_path=%i2v_model_path%
echo 🎬 Starting Image-to-Video mode
) else (
set model_path=%t2v_model_path%
echo 🎬 Starting Text-to-Video mode
)
REM Check if model path exists
if not exist "%model_path%" (
......@@ -230,9 +154,6 @@ echo 🚀 LightX2V Gradio Starting...
echo ==========================================
echo 📁 Project path: %lightx2v_path%
echo 🤖 Model path: %model_path%
echo 🎯 Task type: %task%
echo 🤖 Model size: %model_size%
echo 🤖 Model class: %model_cls%
echo 🌏 Interface language: %lang%
echo 🖥️ GPU device: %gpu_id%
echo 🌐 Server address: %server_name%:%server_port%
......@@ -262,11 +183,8 @@ echo ==========================================
REM Start Python demo
python %demo_file% ^
--model_path "%model_path%" ^
--model_cls %model_cls% ^
--task %task% ^
--server_name %server_name% ^
--server_port %server_port% ^
--model_size %model_size% ^
--output_dir "%output_dir%"
REM Display final system resource usage
......
......@@ -38,51 +38,52 @@ Follow the [Quick Start Guide](../getting_started/quickstart.md) to install the
- [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
- [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (only supports ADA architecture GPUs)
Install according to the project homepage tutorials for each operator as needed
Install according to the project homepage tutorials for each operator as needed.
### 🤖 Supported Models
### 📥 Model Download
#### 🎬 Image-to-Video Models
Refer to the [Model Structure Documentation](../getting_started/model_structure.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions.
| Model Name | Resolution | Parameters | Features | Recommended Use |
|------------|------------|------------|----------|-----------------|
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) | 480p | 14B | Standard version | Balance speed and quality |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) | 720p | 14B | HD version | Pursue high-quality output |
| ✅ [Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v) | 480p | 14B | Distilled optimized version | Faster inference speed |
| ✅ [Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v) | 720p | 14B | HD distilled version | High quality + fast inference |
#### wan2.1 Model Directory Structure
#### 📝 Text-to-Video Models
| Model Name | Parameters | Features | Recommended Use |
|------------|------------|----------|-----------------|
| ✅ [Wan2.1-T2V-1.3B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-1.3B-Lightx2v) | 1.3B | Lightweight | Fast prototyping and testing |
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-Lightx2v) | 14B | Standard version | Balance speed and quality |
| ✅ [Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v) | 14B | Distilled optimized version | High quality + fast inference |
**💡 Model Selection Recommendations**:
- **First-time use**: Recommend choosing distilled versions (`wan2.1_distill`)
- **Pursuing quality**: Choose 720p resolution or 14B parameter models
- **Pursuing speed**: Choose 480p resolution or 1.3B parameter models, prioritize distilled versions
- **Resource-constrained**: Prioritize distilled versions and lower resolutions
- **Real-time applications**: Strongly recommend using distilled models (`wan2.1_distill`)
**🎯 Model Category Description**:
- **`wan2.1`**: Standard model, provides the best video generation quality, suitable for scenarios with extremely high quality requirements
- **`wan2.1_distill`**: Distilled model, optimized through knowledge distillation technology, significantly improves inference speed, maintains good quality while greatly reducing computation time, suitable for most application scenarios
**📥 Model Download**:
Refer to the [Model Structure Documentation](./model_structure.md) to download complete models (including quantized and non-quantized versions) or download only quantized/non-quantized versions.
**Download Options**:
```
models/
├── wan2.1_i2v_720p_lightx2v_4step.safetensors # Original precision
├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 quantization
├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 quantization
├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 quantization block storage directory
├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 quantization block storage directory
├── Other weights (e.g., t2v)
├── t5/clip/xlm-roberta-large/google # text and image encoder
├── vae/lightvae/lighttae # vae
└── config.json # Model configuration file
```
- **Complete Model**: When downloading complete models with both quantized and non-quantized versions, you can freely choose the quantization precision for DIT/T5/CLIP in the advanced options of the `Gradio` Web frontend.
#### wan2.2 Model Directory Structure
- **Non-quantized Version Only**: When downloading only non-quantized versions, in the `Gradio` Web frontend, the quantization precision for `DIT/T5/CLIP` can only be set to bf16/fp16. If you need to use quantized versions of models, please manually download quantized weights to the `i2v_model_path` or `t2v_model_path` directory where Gradio is started.
```
models/
├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise original precision
├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 quantization
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 quantization
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 quantization block storage directory
├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise original precision
├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 quantization
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 quantization
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 quantization block storage directory
├── t5/clip/xlm-roberta-large/google # text and image encoder
├── vae/lightvae/lighttae # vae
└── config.json # Model configuration file
```
- **Quantized Version Only**: When downloading only quantized versions, in the `Gradio` Web frontend, the quantization precision for `DIT/T5/CLIP` can only be set to fp8 or int8 (depending on the weights you downloaded). If you need to use non-quantized versions of models, please manually download non-quantized weights to the `i2v_model_path` or `t2v_model_path` directory where Gradio is started.
**📝 Download Instructions**:
- **Note**: Whether you download a complete model or only part of it, the `i2v_model_path` and `t2v_model_path` parameters should point to the first-level directory. For example: `Wan2.1-I2V-14B-480P-Lightx2v/`, not `Wan2.1-I2V-14B-480P-Lightx2v/int8`.
- Model weights can be downloaded from HuggingFace:
- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
- Text and Image Encoders can be downloaded from [Encoders](https://huggingface.co/lightx2v/Encoderss)
- VAE can be downloaded from [Autoencoders](https://huggingface.co/lightx2v/Autoencoders)
- The `xxx_split` directories (e.g., `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`) store the model weights as multiple per-block safetensors files, which suits devices with limited system memory (roughly 16GB or less); download them according to your own setup.
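The quantization precision offered in the Web UI follows directly from which weight files are present under the configured model directory. As a rough illustration (not the demo's actual detection logic; `detect_available_precisions` is a hypothetical helper), the available DIT precisions could be inferred from the safetensors filenames:

```python
import glob
import os


def detect_available_precisions(model_path):
    """Guess which DIT precisions can be offered from the weight filenames.

    Illustrative only: filenames containing 'fp8' or 'int8' are treated as
    quantized weights; anything else is assumed to be original precision.
    """
    precisions = set()
    for path in glob.glob(os.path.join(model_path, "*.safetensors")):
        name = os.path.basename(path).lower()
        if "fp8" in name:
            precisions.add("fp8")
        elif "int8" in name:
            precisions.add("int8")
        else:
            precisions.add("bf16")
    return sorted(precisions)
```

A directory that holds only `*int8*` weights would then restrict the precision dropdown to int8, which matches the behavior described in the notes above.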
### Startup Methods
......@@ -96,8 +97,7 @@ vim run_gradio.sh
# Configuration items that need to be modified:
# - lightx2v_path: Lightx2v project root directory path
# - i2v_model_path: Image-to-video model path
# - t2v_model_path: Text-to-video model path
# - model_path: Model root directory path (contains all model files)
# 💾 Important note: Recommend pointing model paths to SSD storage locations
# Example: /mnt/ssd/models/ or /data/ssd/models/
......@@ -105,11 +105,9 @@ vim run_gradio.sh
# 2. Run the startup script
bash run_gradio.sh
# 3. Or start with parameters (recommended using distilled models)
bash run_gradio.sh --task i2v --lang en --model_cls wan2.1 --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang en --model_cls wan2.1 --model_size 1.3b --port 8032
bash run_gradio.sh --task i2v --lang en --model_cls wan2.1_distill --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang en --model_cls wan2.1_distill --model_size 1.3b --port 8032
# 3. Or start with parameters
bash run_gradio.sh --lang en --port 8032
bash run_gradio.sh --lang zh --port 7862
```
**Windows Environment:**
......@@ -120,8 +118,7 @@ notepad run_gradio_win.bat
# Configuration items that need to be modified:
# - lightx2v_path: Lightx2v project root directory path
# - i2v_model_path: Image-to-video model path
# - t2v_model_path: Text-to-video model path
# - model_path: Model root directory path (contains all model files)
# 💾 Important note: Recommend pointing model paths to SSD storage locations
# Example: D:\models\ or E:\models\
......@@ -129,201 +126,101 @@ notepad run_gradio_win.bat
# 2. Run the startup script
run_gradio_win.bat
# 3. Or start with parameters (recommended using distilled models)
run_gradio_win.bat --task i2v --lang en --model_cls wan2.1 --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang en --model_cls wan2.1 --model_size 1.3b --port 8032
run_gradio_win.bat --task i2v --lang en --model_cls wan2.1_distill --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang en --model_cls wan2.1_distill --model_size 1.3b --port 8032
# 3. Or start with parameters
run_gradio_win.bat --lang en --port 8032
run_gradio_win.bat --lang zh --port 7862
```
#### Method 2: Direct Command Line Startup
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
**Linux Environment:**
**Image-to-Video Mode:**
**English Interface Version:**
```bash
python gradio_demo.py \
--model_path /path/to/Wan2.1-I2V-14B-480P-Lightx2v \
--model_cls wan2.1 \
--model_size 14b \
--task i2v \
--model_path /path/to/models \
--server_name 0.0.0.0 \
--server_port 7862
```
**English Interface Version:**
**Chinese Interface Version:**
```bash
python gradio_demo.py \
--model_path /path/to/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v \
--model_cls wan2.1_distill \
--model_size 14b \
--task t2v \
python gradio_demo_zh.py \
--model_path /path/to/models \
--server_name 0.0.0.0 \
--server_port 7862
```
**Windows Environment:**
**Image-to-Video Mode:**
**English Interface Version:**
```cmd
python gradio_demo.py ^
--model_path D:\models\Wan2.1-I2V-14B-480P-Lightx2v ^
--model_cls wan2.1 ^
--model_size 14b ^
--task i2v ^
--model_path D:\models ^
--server_name 127.0.0.1 ^
--server_port 7862
```
**English Interface Version:**
**Chinese Interface Version:**
```cmd
python gradio_demo.py ^
--model_path D:\models\Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v ^
--model_cls wan2.1_distill ^
--model_size 14b ^
--task t2v ^
python gradio_demo_zh.py ^
--model_path D:\models ^
--server_name 127.0.0.1 ^
--server_port 7862
```
**💡 Tip**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface.
## 📋 Command Line Parameters
| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| `--model_path` | str | ✅ | - | Model folder path |
| `--model_cls` | str | ❌ | wan2.1 | Model class: `wan2.1` (standard model) or `wan2.1_distill` (distilled model, faster inference) |
| `--model_size` | str | ✅ | - | Model size: `14b (image-to-video or text-to-video)` or `1.3b (text-to-video)` |
| `--task` | str | ✅ | - | Task type: `i2v` (image-to-video) or `t2v` (text-to-video) |
| `--model_path` | str | ✅ | - | Model root directory path (directory containing all model files) |
| `--server_port` | int | ❌ | 7862 | Server port |
| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address |
| `--output_dir` | str | ❌ | ./outputs | Output video save directory |
**💡 Note**: Model type (wan2.1/wan2.2), task type (i2v/t2v), and specific model file selection are all configured in the Web interface.
## 🎯 Features
### Basic Settings
### Model Configuration
- **Model Type**: Supports wan2.1 and wan2.2 model architectures
- **Task Type**: Supports Image-to-Video (i2v) and Text-to-Video (t2v) generation modes
- **Model Selection**: Frontend automatically identifies and filters available model files, supports automatic quantization precision detection
- **Encoder Configuration**: Supports selection of T5 text encoder, CLIP image encoder, and VAE decoder
- **Operator Selection**: Supports multiple attention operators and quantization matrix multiplication operators, system automatically sorts by installation status
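Because operator availability differs between environments, the ordering mentioned above can be derived simply by probing which backend packages are importable. The mapping below is an assumption for illustration, not the demo's exact backend list:

```python
import importlib.util


def sort_by_installed(candidates):
    """Order operator backends so that installed ones come first.

    `candidates` maps a display name to the Python module backing it
    (illustrative names; the stable sort keeps the original priority otherwise).
    """
    def installed(module_name):
        return importlib.util.find_spec(module_name) is not None

    return sorted(candidates, key=lambda name: not installed(candidates[name]))


attention_backends = {
    "Flash Attention": "flash_attn",
    "Sage Attention": "sageattention",
    "PyTorch SDPA": "torch",
}
print(sort_by_installed(attention_backends))
```

Installed backends sort to the front, so the default selection is always something that can actually run.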
### Input Parameters
#### Input Parameters
- **Prompt**: Describe the expected video content
- **Negative Prompt**: Specify elements you don't want to appear
- **Input Image**: Upload input image required in i2v mode
- **Resolution**: Supports multiple preset resolutions (480p/540p/720p)
- **Random Seed**: Controls the randomness of generation results
- **Inference Steps**: Affects the balance between generation quality and speed
- **Inference Steps**: Affects the balance between generation quality and speed (defaults to 4 steps for distilled models)
### Video Parameters
#### Video Parameters
- **FPS**: Frames per second
- **Total Frames**: Video length
- **CFG Scale Factor**: Controls prompt influence strength (1-10)
- **CFG Scale Factor**: Controls prompt influence strength (1-10, defaults to 1 for distilled models)
- **Distribution Shift**: Controls generation style deviation degree (0-10)
### Advanced Optimization Options
#### GPU Memory Optimization
- **Chunked Rotary Position Embedding**: Saves GPU memory
- **Rotary Embedding Chunk Size**: Controls chunk granularity
- **Clean CUDA Cache**: Promptly frees GPU memory
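Chunked rotary position embedding lowers peak memory by applying RoPE to the sequence in slices instead of materializing the rotated tensor for all tokens at once; the chunk size controls the trade-off between memory savings and launch overhead. A minimal generic sketch of the idea (not LightX2V's implementation):

```python
import torch


def apply_rope_chunked(x, cos, sin, chunk_size=4096):
    """Apply rotary position embedding slice by slice along the sequence dim.

    x:        (batch, seq, heads, head_dim) with an even head_dim
    cos, sin: (seq, head_dim // 2) precomputed caches
    Processing `chunk_size` tokens at a time keeps the temporary rotation
    tensors small, which reduces peak GPU memory.
    """
    out = torch.empty_like(x)
    for start in range(0, x.shape[1], chunk_size):
        end = min(start + chunk_size, x.shape[1])
        xc = x[:, start:end]
        c = cos[start:end].view(1, end - start, 1, -1)
        s = sin[start:end].view(1, end - start, 1, -1)
        x1, x2 = xc.chunk(2, dim=-1)
        out[:, start:end] = torch.cat((x1 * c - x2 * s, x1 * s + x2 * c), dim=-1)
    return out
```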
#### Asynchronous Offloading
- **CPU Offloading**: Transfers partial computation to CPU
- **Lazy Loading**: Loads model components on demand, significantly reducing system memory consumption
- **Offload Granularity Control**: Fine-grained control of offloading strategies
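Block-wise CPU offloading keeps most weights in system memory and prefetches the next transformer block to the GPU on a background thread while the current block computes, so transfer time overlaps with compute. A simplified sketch of that pattern (the project's actual manager also uses pinned buffers and CUDA streams, omitted here):

```python
from concurrent.futures import ThreadPoolExecutor


class SimpleBlockOffloader:
    """Stream transformer blocks through the GPU one at a time (illustrative)."""

    def __init__(self, blocks, device="cuda"):
        self.blocks = [b.to("cpu") for b in blocks]   # master copies live on CPU
        self.device = device
        self.executor = ThreadPoolExecutor(max_workers=1)

    def _fetch(self, idx):
        # with pinned CPU memory this copy can genuinely overlap with compute
        return self.blocks[idx].to(self.device, non_blocking=True)

    def run(self, x):
        future = self.executor.submit(self._fetch, 0)
        for i in range(len(self.blocks)):
            block = future.result()                   # wait for the prefetched block
            if i + 1 < len(self.blocks):
                future = self.executor.submit(self._fetch, i + 1)
            x = block(x)
            block.to("cpu")                           # free GPU memory for this block
        return x
```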
#### Low-Precision Quantization
- **Attention Operators**: Flash Attention, Sage Attention, etc.
- **Quantization Operators**: vLLM, SGL, Q8F, etc.
- **Precision Modes**: FP8, INT8, BF16, etc.
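For the quantized precision modes, linear-layer weights are typically stored as INT8/FP8 values together with per-channel scales, which is the layout the quantized matmul operators consume. A generic per-channel INT8 quantization, shown only to illustrate the storage format (not necessarily the exact scheme used by the shipped weights):

```python
import torch


def quantize_int8_per_channel(weight):
    """Symmetric per-output-channel INT8 quantization of a 2-D weight matrix.

    Returns an int8 tensor plus one float scale per output channel, so that
    weight ≈ q.float() * scale[:, None].
    """
    max_abs = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scale = max_abs / 127.0
    q = torch.clamp(torch.round(weight / scale), -127, 127).to(torch.int8)
    return q, scale.squeeze(1)


w = torch.randn(128, 256)
q, scale = quantize_int8_per_channel(w)
print((w - q.float() * scale[:, None]).abs().max())  # small reconstruction error
```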
#### VAE Optimization
- **Lightweight VAE**: Accelerates decoding process
- **VAE Tiling Inference**: Reduces memory usage
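VAE tiling bounds decoder activation memory by decoding the latent in spatial tiles rather than as one large tensor. A simplified 2-D sketch, assuming a `decode_fn` with a fixed spatial upscale factor and ignoring the tile overlap and seam blending that real implementations add:

```python
import torch


def decode_tiled(decode_fn, latents, tile=64, scale=8, out_channels=3):
    """Decode a (B, C, H, W) latent tile by tile to cap peak memory."""
    b, _, h, w = latents.shape
    out = torch.zeros(b, out_channels, h * scale, w * scale)
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            lt = latents[:, :, y : y + tile, x : x + tile]
            out[:, :, y * scale : (y + lt.shape[-2]) * scale,
                x * scale : (x + lt.shape[-1]) * scale] = decode_fn(lt)
    return out
```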
## 🔧 Auto-Configuration Feature
#### Feature Caching
- **Tea Cache**: Caches intermediate features to accelerate generation
- **Cache Threshold**: Controls cache trigger conditions
- **Key Step Caching**: Writes cache only at key steps
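Tea-Cache-style caching reuses a block's previous output when its current input has drifted only slightly since the last cached step, and key-step caching restricts when the cache may be refreshed. A rough sketch of the decision logic, with an illustrative drift metric and threshold:

```python
import torch


class FeatureCache:
    """Reuse the previous output while the input drift stays under a threshold."""

    def __init__(self, threshold=0.05, key_steps=frozenset()):
        self.threshold = threshold
        self.key_steps = key_steps          # steps at which the cache may be rewritten
        self.ref_input = None
        self.cached_output = None

    def __call__(self, step, x, compute_fn):
        if self.ref_input is not None:
            drift = (x - self.ref_input).abs().mean() / (self.ref_input.abs().mean() + 1e-8)
            if drift < self.threshold:
                return self.cached_output   # cheap path: reuse cached features
        out = compute_fn(x)
        if not self.key_steps or step in self.key_steps:
            self.ref_input, self.cached_output = x, out
        return out
```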
The system automatically configures optimal inference options based on your hardware configuration (GPU VRAM and CPU memory) without manual adjustment. The best configuration is automatically applied on startup, including:
## 🔧 Auto-Configuration Feature
- **GPU Memory Optimization**: Automatically enables CPU offloading, VAE tiling inference, etc. based on VRAM size
- **CPU Memory Optimization**: Automatically enables lazy loading, module unloading, etc. based on system memory
- **Operator Selection**: Automatically selects the best installed operators (sorted by priority)
- **Quantization Configuration**: Automatically detects and applies quantization precision based on model file names
After enabling "Auto-configure Inference Options", the system will automatically optimize parameters based on your hardware configuration:
### GPU Memory Rules
- **80GB+**: Default configuration, no optimization needed
- **48GB**: Enable CPU offloading, offload ratio 50%
- **40GB**: Enable CPU offloading, offload ratio 80%
- **32GB**: Enable CPU offloading, offload ratio 100%
- **24GB**: Enable BF16 precision, VAE tiling
- **16GB**: Enable chunked offloading, rotary embedding chunking
- **12GB**: Enable cache cleaning, lightweight VAE
- **8GB**: Enable quantization, lazy loading
### CPU Memory Rules
- **128GB+**: Default configuration
- **64GB**: Enable DIT quantization
- **32GB**: Enable lazy loading
- **16GB**: Enable full model quantization
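Taken together, the VRAM and RAM thresholds above form a small rule table. The sketch below shows how such rules might map detected hardware to configuration flags; the cutoffs and flag names are illustrative and do not mirror the demo's exact configuration:

```python
import psutil
import torch


def auto_configure():
    """Map detected VRAM / system RAM to inference flags (illustrative rules)."""
    if torch.cuda.is_available():
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    else:
        vram_gb = 0.0
    ram_gb = psutil.virtual_memory().total / 1024**3

    cfg = {"cpu_offload": False, "offload_ratio": 0.0, "vae_tiling": False,
           "tiny_vae": False, "lazy_load": False, "quantize_dit": False}
    if vram_gb < 80:                        # below ~80GB, offload part of the DIT
        cfg["cpu_offload"] = True
        cfg["offload_ratio"] = 0.5 if vram_gb >= 48 else 0.8 if vram_gb >= 40 else 1.0
    if vram_gb <= 24:
        cfg["vae_tiling"] = True
    if vram_gb <= 12:
        cfg["tiny_vae"] = True
    if vram_gb <= 8 or ram_gb <= 64:
        cfg["quantize_dit"] = True
    if ram_gb <= 32:
        cfg["lazy_load"] = True
    return cfg
```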
## ⚠️ Important Notes
### 🚀 Low-Resource Device Optimization Recommendations
**💡 For devices with insufficient VRAM or performance constraints**:
- **🎯 Model Selection**: Prioritize using distilled version models (`wan2.1_distill`)
- **⚡ Inference Steps**: Recommend setting to 4 steps
- **🔧 CFG Settings**: Recommend disabling CFG option to improve generation speed
- **🔄 Auto-Configuration**: Enable "Auto-configure Inference Options"
- **💾 Storage Optimization**: Ensure models are stored on SSD for optimal loading performance
## 🎨 Interface Description
### Basic Settings Tab
- **Input Parameters**: Prompts, resolution, and other basic settings
- **Video Parameters**: FPS, frame count, CFG, and other video generation parameters
- **Output Settings**: Video save path configuration
### Advanced Options Tab
- **GPU Memory Optimization**: Memory management related options
- **Asynchronous Offloading**: CPU offloading and lazy loading
- **Low-Precision Quantization**: Various quantization optimization options
- **VAE Optimization**: Variational Autoencoder optimization
- **Feature Caching**: Cache strategy configuration
## 🔍 Troubleshooting
### Common Issues
**💡 Tip**: In most cases, enabling "Auto-configure Inference Options" lets the system tune parameters to your hardware automatically, so performance problems are uncommon. If you do run into issues, try the following solutions:
1. **Gradio Webpage Opens Blank**
- Try upgrading gradio: `pip install --upgrade gradio`
2. **CUDA Memory Insufficient**
- Enable CPU offloading
- Reduce resolution
- Enable quantization options
3. **System Memory Insufficient**
- Enable CPU offloading
- Enable lazy loading option
- Enable quantization options
4. **Slow Generation Speed**
- Reduce inference steps
- Enable auto-configuration
- Use lightweight models
- Enable Tea Cache
- Use quantization operators
- 💾 **Check if models are stored on SSD**
5. **Slow Model Loading**
- 💾 **Migrate models to SSD storage**
- Enable lazy loading option
- Check disk I/O performance
- Consider using NVMe SSD
6. **Poor Video Quality**
- Increase inference steps
- Increase CFG scale factor
- Use 14B models
- Optimize prompts
### Log Viewing
......
......@@ -38,51 +38,53 @@ LightX2V/app/
- [sglang-kernel](https://github.com/sgl-project/sglang/tree/main/sgl-kernel)
- [q8-kernel](https://github.com/KONAKONA666/q8_kernels) (仅支持ADA架构的GPU)
可根据需要,按照各算子的项目主页教程进行安装
可根据需要,按照各算子的项目主页教程进行安装
### 🤖 支持的模型
### 📥 模型下载
#### 🎬 图像到视频模型 (Image-to-Video)
可参考[模型结构文档](../getting_started/model_structure.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。
| 模型名称 | 分辨率 | 参数量 | 特点 | 推荐场景 |
|----------|--------|--------|------|----------|
| ✅ [Wan2.1-I2V-14B-480P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-Lightx2v) | 480p | 14B | 标准版本 | 平衡速度和质量 |
| ✅ [Wan2.1-I2V-14B-720P-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-Lightx2v) | 720p | 14B | 高清版本 | 追求高质量输出 |
| ✅ [Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v) | 480p | 14B | 蒸馏优化版 | 更快的推理速度 |
| ✅ [Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-720P-StepDistill-CfgDistill-Lightx2v) | 720p | 14B | 高清蒸馏版 | 高质量+快速推理 |
#### wan2.1 模型目录结构
#### 📝 文本到视频模型 (Text-to-Video)
| 模型名称 | 参数量 | 特点 | 推荐场景 |
|----------|--------|------|----------|
| ✅ [Wan2.1-T2V-1.3B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-1.3B-Lightx2v) | 1.3B | 轻量级 | 快速原型测试 |
| ✅ [Wan2.1-T2V-14B-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-Lightx2v) | 14B | 标准版本 | 平衡速度和质量 |
| ✅ [Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v) | 14B | 蒸馏优化版 | 高质量+快速推理 |
**💡 模型选择建议**:
- **首次使用**: 建议选择蒸馏版本 (`wan2.1_distill`)
- **追求质量**: 选择720p分辨率或14B参数模型
- **追求速度**: 选择480p分辨率或1.3B参数模型,优先使用蒸馏版本
- **资源受限**: 优先选择蒸馏版本和较低分辨率
- **实时应用**: 强烈推荐使用蒸馏模型 (`wan2.1_distill`)
**🎯 模型类别说明**:
- **`wan2.1`**: 标准模型,提供最佳的视频生成质量,适合对质量要求极高的场景
- **`wan2.1_distill`**: 蒸馏模型,通过知识蒸馏技术优化,推理速度显著提升,在保持良好质量的同时大幅减少计算时间,适合大多数应用场景
**📥 下载模型**:
可参考[模型结构文档](./model_structure.md)下载完整模型(包含量化和非量化版本)或仅下载量化/非量化版本。
```
models/
├── wan2.1_i2v_720p_lightx2v_4step.safetensors # 原始精度
├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step.safetensors # FP8 量化
├── wan2.1_i2v_720p_int8_lightx2v_4step.safetensors # INT8 量化
├── wan2.1_i2v_720p_int8_lightx2v_4step_split # INT8 量化分block存储目录
├── wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split # FP8 量化分block存储目录
├── 其他权重(例如t2v)
├── t5/clip/xlm-roberta-large/google # text和image encoder
├── vae/lightvae/lighttae # vae
└── config.json # 模型配置文件
```
**下载选项说明**
#### wan2.2 模型目录结构
- **完整模型**:下载包含量化和非量化版本的完整模型时,在`Gradio` Web前端的高级选项中可以自由选择DIT/T5/CLIP的量化精度。
```
models/
├── wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors # high noise 原始精度
├── wan2.2_i2v_A14b_high_noise_fp8_e4m3_lightx2v_4step_1030.safetensors # high noise FP8 量化
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030.safetensors # high noise INT8 量化
├── wan2.2_i2v_A14b_high_noise_int8_lightx2v_4step_1030_split # high noise INT8 量化分block存储目录
├── wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors # low noise 原始精度
├── wan2.2_i2v_A14b_low_noise_fp8_e4m3_lightx2v_4step.safetensors # low noise FP8 量化
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step.safetensors # low noise INT8 量化
├── wan2.2_i2v_A14b_low_noise_int8_lightx2v_4step_split # low noise INT8 量化分block存储目录
├── t5/clip/xlm-roberta-large/google # text和image encoder
├── vae/lightvae/lighttae # vae
└── config.json # 模型配置文件
```
- **仅非量化版本**:仅下载非量化版本时,在`Gradio` Web前端中,`DIT/T5/CLIP`的量化精度只能选择bf16/fp16。如需使用量化版本的模型,请手动下载量化权重到Gradio启动的`i2v_model_path`或者`t2v_model_path`目录下。
**📝 下载说明**
- **仅量化版本**:仅下载量化版本时,在`Gradio` Web前端中,`DIT/T5/CLIP`的量化精度只能选择fp8或int8(取决于您下载的权重)。如需使用非量化版本的模型,请手动下载非量化权重到Gradio启动的`i2v_model_path`或者`t2v_model_path`目录下。
- 模型权重可从 HuggingFace 下载:
- [Wan2.1-Distill-Models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)
- [Wan2.2-Distill-Models](https://huggingface.co/lightx2v/Wan2.2-Distill-Models)
- Text 和 Image Encoder 可从 [Encoders](https://huggingface.co/lightx2v/Encoderss) 下载
- VAE 可从 [Autoencoders](https://huggingface.co/lightx2v/Autoencoders) 下载
- 对于 `xxx_split` 目录(例如 `wan2.1_i2v_720p_scaled_fp8_e4m3_lightx2v_4step_split`),即按照 block 存储的多个 safetensors,适用于内存不足的设备。例如内存 16GB 以内,请根据自身情况下载
- **注意**:无论是下载了完整模型还是部分模型,`i2v_model_path` 和 `t2v_model_path` 参数的值都应该是一级目录的路径。例如:`Wan2.1-I2V-14B-480P-Lightx2v/`,而不是 `Wan2.1-I2V-14B-480P-Lightx2v/int8`
### 启动方式
......@@ -96,8 +98,7 @@ vim run_gradio.sh
# 需要修改的配置项:
# - lightx2v_path: Lightx2v项目根目录路径
# - i2v_model_path: 图像到视频模型路径
# - t2v_model_path: 文本到视频模型路径
# - model_path: 模型根目录路径(包含所有模型文件)
# 💾 重要提示:建议将模型路径指向SSD存储位置
# 例如:/mnt/ssd/models/ 或 /data/ssd/models/
......@@ -105,11 +106,9 @@ vim run_gradio.sh
# 2. 运行启动脚本
bash run_gradio.sh
# 3. 或使用参数启动(推荐使用蒸馏模型)
bash run_gradio.sh --task i2v --lang zh --model_cls wan2.1 --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang zh --model_cls wan2.1 --model_size 1.3b --port 8032
bash run_gradio.sh --task i2v --lang zh --model_cls wan2.1_distill --model_size 14b --port 8032
bash run_gradio.sh --task t2v --lang zh --model_cls wan2.1_distill --model_size 1.3b --port 8032
# 3. 或使用参数启动
bash run_gradio.sh --lang zh --port 8032
bash run_gradio.sh --lang en --port 7862
```
**Windows 环境:**
......@@ -120,8 +119,7 @@ notepad run_gradio_win.bat
# 需要修改的配置项:
# - lightx2v_path: Lightx2v项目根目录路径
# - i2v_model_path: 图像到视频模型路径
# - t2v_model_path: 文本到视频模型路径
# - model_path: 模型根目录路径(包含所有模型文件)
# 💾 重要提示:建议将模型路径指向SSD存储位置
# 例如:D:\models\ 或 E:\models\
......@@ -129,24 +127,23 @@ notepad run_gradio_win.bat
# 2. 运行启动脚本
run_gradio_win.bat
# 3. 或使用参数启动(推荐使用蒸馏模型)
run_gradio_win.bat --task i2v --lang zh --model_cls wan2.1 --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang zh --model_cls wan2.1 --model_size 1.3b --port 8032
run_gradio_win.bat --task i2v --lang zh --model_cls wan2.1_distill --model_size 14b --port 8032
run_gradio_win.bat --task t2v --lang zh --model_cls wan2.1_distill --model_size 1.3b --port 8032
# 3. 或使用参数启动
run_gradio_win.bat --lang zh --port 8032
run_gradio_win.bat --lang en --port 7862
```
#### 方式二:直接命令行启动
```bash
pip install -v git+https://github.com/ModelTC/LightX2V.git
```
**Linux 环境:**
**图像到视频模式:**
**中文界面版本:**
```bash
python gradio_demo_zh.py \
--model_path /path/to/Wan2.1-I2V-14B-480P-Lightx2v \
--model_cls wan2.1 \
--model_size 14b \
--task i2v \
--model_path /path/to/models \
--server_name 0.0.0.0 \
--server_port 7862
```
......@@ -154,176 +151,77 @@ python gradio_demo_zh.py \
**英文界面版本:**
```bash
python gradio_demo.py \
--model_path /path/to/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v \
--model_cls wan2.1_distill \
--model_size 14b \
--task t2v \
--model_path /path/to/models \
--server_name 0.0.0.0 \
--server_port 7862
```
**Windows 环境:**
**图像到视频模式:**
**中文界面版本:**
```cmd
python gradio_demo_zh.py ^
--model_path D:\models\Wan2.1-I2V-14B-480P-Lightx2v ^
--model_cls wan2.1 ^
--model_size 14b ^
--task i2v ^
--model_path D:\models ^
--server_name 127.0.0.1 ^
--server_port 7862
```
**英文界面版本:**
```cmd
python gradio_demo_zh.py ^
--model_path D:\models\Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v ^
--model_cls wan2.1_distill ^
--model_size 14b ^
--task i2v ^
python gradio_demo.py ^
--model_path D:\models ^
--server_name 127.0.0.1 ^
--server_port 7862
```
**💡 提示**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。
## 📋 命令行参数
| 参数 | 类型 | 必需 | 默认值 | 说明 |
|------|------|------|--------|------|
| `--model_path` | str | ✅ | - | 模型文件夹路径 |
| `--model_cls` | str | ❌ | wan2.1 | 模型类别:`wan2.1`(标准模型)或 `wan2.1_distill`(蒸馏模型,推理更快) |
| `--model_size` | str | ✅ | - | 模型大小:`14b` 或 `1.3b` |
| `--task` | str | ✅ | - | 任务类型:`i2v`(图像到视频)或 `t2v`(文本到视频) |
| `--model_path` | str | ✅ | - | 模型根目录路径(包含所有模型文件的目录) |
| `--server_port` | int | ❌ | 7862 | 服务器端口 |
| `--server_name` | str | ❌ | 0.0.0.0 | 服务器IP地址 |
| `--output_dir` | str | ❌ | ./outputs | 输出视频保存目录 |
**💡 说明**:模型类型(wan2.1/wan2.2)、任务类型(i2v/t2v)以及具体的模型文件选择均在 Web 界面中进行配置。
## 🎯 功能特性
### 基本设置
### 模型配置
- **模型类型**: 支持 wan2.1 和 wan2.2 两种模型架构
- **任务类型**: 支持图像到视频(i2v)和文本到视频(t2v)两种生成模式
- **模型选择**: 前端自动识别并筛选可用的模型文件,支持自动检测量化精度
- **编码器配置**: 支持选择 T5 文本编码器、CLIP 图像编码器和 VAE 解码器
- **算子选择**: 支持多种注意力算子和量化矩阵乘法算子,系统会根据安装状态自动排序
### 输入参数
#### 输入参数
- **提示词 (Prompt)**: 描述期望的视频内容
- **负向提示词 (Negative Prompt)**: 指定不希望出现的元素
- **输入图像**: i2v 模式下需要上传输入图像
- **分辨率**: 支持多种预设分辨率(480p/540p/720p)
- **随机种子**: 控制生成结果的随机性
- **推理步数**: 影响生成质量和速度的平衡
- **推理步数**: 影响生成质量和速度的平衡(蒸馏模型默认为 4 步)
### 视频参数
#### 视频参数
- **FPS**: 每秒帧数
- **总帧数**: 视频长度
- **CFG缩放因子**: 控制提示词影响强度(1-10)
- **CFG缩放因子**: 控制提示词影响强度(1-10,蒸馏模型默认为 1
- **分布偏移**: 控制生成风格偏离程度(0-10)
### 高级优化选项
#### GPU内存优化
- **分块旋转位置编码**: 节省GPU内存
- **旋转编码块大小**: 控制分块粒度
- **清理CUDA缓存**: 及时释放GPU内存
#### 异步卸载
- **CPU卸载**: 将部分计算转移到CPU
- **延迟加载**: 按需加载模型组件,显著节省系统内存消耗
- **卸载粒度控制**: 精细控制卸载策略
#### 低精度量化
- **注意力算子**: Flash Attention、Sage Attention等
- **量化算子**: vLLM、SGL、Q8F等
- **精度模式**: FP8、INT8、BF16等
#### VAE优化
- **轻量级VAE**: 加速解码过程
- **VAE分块推理**: 减少内存占用
## 🔧 自动配置功能
#### 特征缓存
- **Tea Cache**: 缓存中间特征加速生成
- **缓存阈值**: 控制缓存触发条件
- **关键步缓存**: 仅在关键步骤写入缓存
系统会根据您的硬件配置(GPU 显存和 CPU 内存)自动配置最优推理选项,无需手动调整。启动时会自动应用最佳配置,包括:
## 🔧 自动配置功能
- **GPU 内存优化**: 根据显存大小自动启用 CPU 卸载、VAE 分块推理等
- **CPU 内存优化**: 根据系统内存自动启用延迟加载、模块卸载等
- **算子选择**: 自动选择已安装的最优算子(按优先级排序)
- **量化配置**: 根据模型文件名自动检测并应用量化精度
启用"自动配置推理选项"后,系统会根据您的硬件配置自动优化参数:
### GPU内存规则
- **80GB+**: 默认配置,无需优化
- **48GB**: 启用CPU卸载,卸载比例50%
- **40GB**: 启用CPU卸载,卸载比例80%
- **32GB**: 启用CPU卸载,卸载比例100%
- **24GB**: 启用BF16精度、VAE分块
- **16GB**: 启用分块卸载、旋转编码分块
- **12GB**: 启用清理缓存、轻量级VAE
- **8GB**: 启用量化、延迟加载
### CPU内存规则
- **128GB+**: 默认配置
- **64GB**: 启用DIT量化
- **32GB**: 启用延迟加载
- **16GB**: 启用全模型量化
## ⚠️ 重要注意事项
### 🚀 低资源设备优化建议
**💡 针对显存不足或性能受限的设备**:
- **🎯 模型选择**: 优先使用蒸馏版本模型 (`wan2.1_distill`)
- **⚡ 推理步数**: 建议设置为 4 步
- **🔧 CFG设置**: 建议关闭CFG选项以提升生成速度
- **🔄 自动配置**: 启用"自动配置推理选项"
- **💾 存储优化**: 确保模型存储在SSD上以获得最佳加载性能
## 🎨 界面说明
### 基本设置标签页
- **输入参数**: 提示词、分辨率等基本设置
- **视频参数**: FPS、帧数、CFG等视频生成参数
- **输出设置**: 视频保存路径配置
### 高级选项标签页
- **GPU内存优化**: 内存管理相关选项
- **异步卸载**: CPU卸载和延迟加载
- **低精度量化**: 各种量化优化选项
- **VAE优化**: 变分自编码器优化
- **特征缓存**: 缓存策略配置
## 🔍 故障排除
### 常见问题
**💡 提示**: 一般情况下,启用"自动配置推理选项"后,系统会根据您的硬件配置自动优化参数设置,通常不会出现性能问题。如果遇到问题,请参考以下解决方案:
1. **Gradio网页打开空白**
- 尝试升级 gradio:`pip install --upgrade gradio`
2. **CUDA内存不足**
- 启用CPU卸载
- 降低分辨率
- 启用量化选项
3. **系统内存不足**
- 启用CPU卸载
- 启用延迟加载选项
- 启用量化选项
4. **生成速度慢**
- 减少推理步数
- 启用自动配置
- 使用轻量级模型
- 启用Tea Cache
- 使用量化算子
- 💾 **检查模型是否存放在SSD上**
5. **模型加载缓慢**
- 💾 **将模型迁移到SSD存储**
- 启用延迟加载选项
- 检查磁盘I/O性能
- 考虑使用NVMe SSD
6. **视频质量不佳**
- 增加推理步数
- 提高CFG缩放因子
- 使用14B模型
- 优化提示词
### 日志查看
......
import time
from concurrent.futures import ThreadPoolExecutor
import torch
......@@ -115,8 +116,6 @@ class WeightAsyncStreamManager(object):
self.prefetch_futures.append(future)
def swap_cpu_buffers(self):
import time
wait_start = time.time()
already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures:
......@@ -125,25 +124,11 @@ class WeightAsyncStreamManager(object):
logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def shutdown(self, wait=True):
"""Shutdown the thread pool executor and wait for all pending tasks to complete."""
def __del__(self):
if hasattr(self, "executor") and self.executor is not None:
# Wait for all pending futures to complete before shutting down
if hasattr(self, "prefetch_futures"):
for f in self.prefetch_futures:
try:
if not f.done():
f.result()
except Exception:
pass
self.executor.shutdown(wait=wait)
self.executor.shutdown(wait=False)
self.executor = None
logger.debug("ThreadPoolExecutor shut down successfully.")
def __del__(self):
"""Cleanup method to ensure executor is shut down when object is destroyed."""
try:
if hasattr(self, "executor") and self.executor is not None:
self.executor.shutdown(wait=False)
except Exception:
pass
......@@ -178,7 +178,7 @@ class WanModel(CompiledMethodsMixin):
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}")
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path. Lazy load mode only supports loading chunked model weights.")
weight_dict = {}
for file_path in safetensors_files:
......@@ -221,7 +221,7 @@ class WanModel(CompiledMethodsMixin):
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}, Please check the lazy load model path")
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path. Lazy load mode only supports loading chunked model weights.")
weight_dict = {}
for safetensor_path in safetensors_files:
......