Commit 5fc97e4f authored by gushiqiao's avatar gushiqiao

Update gradio

parent 7a8951ba
@@ -8,16 +8,11 @@ import gc
from easydict import EasyDict from easydict import EasyDict
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
import sys
from pathlib import Path
module_path = str(Path(__file__).resolve().parent.parent) import importlib.util
sys.path.append(module_path) import psutil
from lightx2v.infer import init_runner # noqa: E402
from lightx2v.utils.envs import * # noqa: E402
# advance_ptq
logger.add( logger.add(
"inference_logs.log", "inference_logs.log",
rotation="100 MB", rotation="100 MB",
@@ -28,8 +23,78 @@ logger.add(
) )
global_runner = None def is_module_installed(module_name):
current_config = None try:
spec = importlib.util.find_spec(module_name)
return spec is not None
except ModuleNotFoundError:
return False
def get_available_quant_ops():
available_ops = []
vllm_installed = is_module_installed("vllm")
if vllm_installed:
available_ops.append(("vllm", True))
else:
available_ops.append(("vllm", False))
sgl_installed = is_module_installed("sgl_kernel")
if sgl_installed:
available_ops.append(("sgl", True))
else:
available_ops.append(("sgl", False))
q8f_installed = is_module_installed("q8_kernels")
if q8f_installed:
available_ops.append(("q8f", True))
else:
available_ops.append(("q8f", False))
return available_ops
def get_available_attn_ops():
available_ops = []
vllm_installed = is_module_installed("flash_attn")
if vllm_installed:
available_ops.append(("flash_attn2", True))
else:
available_ops.append(("flash_attn2", False))
sgl_installed = is_module_installed("flash_attn_interface")
if sgl_installed:
available_ops.append(("flash_attn3", True))
else:
available_ops.append(("flash_attn3", False))
q8f_installed = is_module_installed("sageattention")
if q8f_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
return available_ops
def get_gpu_memory(gpu_idx=0):
if not torch.cuda.is_available():
return 0
try:
with torch.cuda.device(gpu_idx):
memory_info = torch.cuda.mem_get_info()
total_memory = memory_info[1] / (1024**3) # Convert bytes to GB
return total_memory
except Exception as e:
logger.warning(f"Failed to get GPU memory: {e}")
return 0
def get_cpu_memory():
available_bytes = psutil.virtual_memory().available
return available_bytes / 1024**3
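# Annotation (not part of the commit): torch.cuda.mem_get_info() returns a
# (free_bytes, total_bytes) tuple, so memory_info[1] above is the card's *total*
# VRAM, while get_cpu_memory() reports the currently *available* system RAM.
# Quick check on a CUDA machine (values are machine-dependent):
#
#   free_b, total_b = torch.cuda.mem_get_info()
#   print(f"GPU: {free_b / 1024**3:.1f} / {total_b / 1024**3:.1f} GiB free/total")
#   print(f"CPU: {psutil.virtual_memory().available / 1024**3:.1f} GiB available")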
def generate_unique_filename(base_dir="./saved_videos"): def generate_unique_filename(base_dir="./saved_videos"):
@@ -38,6 +103,32 @@ def generate_unique_filename(base_dir="./saved_videos"):
return os.path.join(base_dir, f"{model_cls}_{timestamp}.mp4") return os.path.join(base_dir, f"{model_cls}_{timestamp}.mp4")
def is_fp8_supported_gpu():
if not torch.cuda.is_available():
return False
compute_capability = torch.cuda.get_device_capability(0)
major, minor = compute_capability
return (major == 8 and minor == 9) or (major >= 9)
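# Annotation (not part of the commit): torch.cuda.get_device_capability() returns
# the CUDA compute capability as a (major, minor) tuple, so this check accepts Ada
# (8.9, e.g. RTX 4090) and Hopper or newer (>= 9.0, e.g. H100), which have native
# FP8 support; Ampere cards such as the A100 (8.0) fall back to int8 quantization.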
global_runner = None
current_config = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
for op_name, is_installed in available_quant_ops:
status_text = "✅ Installed" if is_installed else "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()
attn_op_choices = []
for op_name, is_installed in available_attn_ops:
status_text = "✅ Installed" if is_installed else "❌ Not Installed"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
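# Illustrative example (not part of the commit): each dropdown entry pairs a raw
# op name with a display string such as "vllm (✅ Installed)"; run_inference()
# later strips the status suffix back off, so the round trip must be lossless.
example_display = "sage_attn2 (❌ Not Installed)"
assert example_display.split("(")[0].strip() == "sage_attn2"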
def run_inference( def run_inference(
model_type, model_type,
task, task,
@@ -53,6 +144,7 @@ def run_inference(
sample_shift, sample_shift,
enable_teacache, enable_teacache,
teacache_thresh, teacache_thresh,
use_ret_steps,
enable_cfg, enable_cfg,
cfg_scale, cfg_scale,
dit_quant_scheme, dit_quant_scheme,
@@ -63,25 +155,29 @@ def run_inference(
use_tiling_vae, use_tiling_vae,
lazy_load, lazy_load,
precision_mode, precision_mode,
use_expandable_alloc,
cpu_offload, cpu_offload,
offload_granularity, offload_granularity,
offload_ratio,
t5_offload_granularity, t5_offload_granularity,
attention_type, attention_type,
quant_op, quant_op,
rotary_chunk, rotary_chunk,
rotary_chunk_size,
clean_cuda_cache, clean_cuda_cache,
): ):
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path global global_runner, current_config, model_path
if os.path.exists(os.path.join(model_path, "config.json")): if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f: with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f) model_config = json.load(f)
if task == "Text-to-Video": if task == "Image to Video":
task = "t2v"
elif task == "Image-to-Video":
task = "i2v" task = "i2v"
elif task == "Text to Video":
task = "t2v"
if task == "t2v": if task == "t2v":
if model_type == "Wan2.1 1.3B": if model_type == "Wan2.1 1.3B":
@@ -124,7 +220,6 @@ def run_inference(
if resolution in [ if resolution in [
"1280x720", "1280x720",
"720x1280", "720x1280",
"1024x1024",
"1280x544", "1280x544",
"544x1280", "544x1280",
"1104x832", "1104x832",
@@ -173,7 +268,7 @@ def run_inference(
else: else:
t5_quant_ckpt = None t5_quant_ckpt = None
is_clip_quant = clip_quant_scheme != "bf16" is_clip_quant = clip_quant_scheme != "fp16"
if is_clip_quant: if is_clip_quant:
if clip_quant_scheme == "int8": if clip_quant_scheme == "int8":
clip_quant_ckpt = os.path.join(model_path, "clip-int8.pth") clip_quant_ckpt = os.path.join(model_path, "clip-int8.pth")
@@ -192,10 +287,6 @@ def run_inference(
os.environ["DTYPE"] = "BF16" os.environ["DTYPE"] = "BF16"
else: else:
os.environ.pop("DTYPE", None) os.environ.pop("DTYPE", None)
if use_expandable_alloc:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:true"
else:
os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
if is_dit_quant: if is_dit_quant:
if quant_op == "vllm": if quant_op == "vllm":
@@ -204,8 +295,11 @@ def run_inference(
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl" mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
elif quant_op == "q8f": elif quant_op == "q8f":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F" mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
dit_quantized_ckpt = os.path.join(model_path, dit_quant_scheme)
else: else:
mm_type = "Default" mm_type = "Default"
dit_quantized_ckpt = None
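# Annotation (not part of the commit): with dit_quant_scheme="fp8" and
# quant_op="sgl", the branch above produces
#   mm_type            == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
#   dit_quantized_ckpt == os.path.join(model_path, "fp8")
# i.e. quantized DiT weights are now expected in an "fp8/" (or "int8/")
# subdirectory of the model folder, instead of the old behaviour of passing
# model_path itself whenever quantization was enabled.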
config = { config = {
"infer_steps": infer_steps, "infer_steps": infer_steps,
@@ -219,15 +313,16 @@ def run_inference(
"sample_shift": sample_shift, "sample_shift": sample_shift,
"cpu_offload": cpu_offload, "cpu_offload": cpu_offload,
"offload_granularity": offload_granularity, "offload_granularity": offload_granularity,
"offload_ratio": offload_ratio,
"t5_offload_granularity": t5_offload_granularity, "t5_offload_granularity": t5_offload_granularity,
"dit_quantized_ckpt": model_path if is_dit_quant else None, "dit_quantized_ckpt": dit_quantized_ckpt,
"mm_config": { "mm_config": {
"mm_type": mm_type, "mm_type": mm_type,
}, },
"fps": fps, "fps": fps,
"feature_caching": "Tea" if enable_teacache else "NoCaching", "feature_caching": "Tea" if enable_teacache else "NoCaching",
"coefficients": coefficient, "coefficients": coefficient[0] if use_ret_steps else coefficient[1],
"use_ret_steps": True, "use_ret_steps": use_ret_steps,
"teacache_thresh": teacache_thresh, "teacache_thresh": teacache_thresh,
"t5_quantized": is_t5_quant, "t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quant_ckpt, "t5_quantized_ckpt": t5_quant_ckpt,
@@ -250,6 +345,7 @@ def run_inference(
"use_prompt_enhancer": False, "use_prompt_enhancer": False,
"text_len": 512, "text_len": 512,
"rotary_chunk": rotary_chunk, "rotary_chunk": rotary_chunk,
"rotary_chunk_size": rotary_chunk_size,
"clean_cuda_cache": clean_cuda_cache, "clean_cuda_cache": clean_cuda_cache,
} }
@@ -269,11 +365,10 @@ def run_inference(
config["mode"] = "infer" config["mode"] = "infer"
config.update(model_config) config.update(model_config)
print(config)
logger.info(f"Using model: {model_path}") logger.info(f"Using model: {model_path}")
logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}") logger.info(f"Inference configuration:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
# 初始化或复用runner # Initialize or reuse the runner
runner = global_runner runner = global_runner
if needs_reinit: if needs_reinit:
if runner is not None: if runner is not None:
@@ -281,11 +376,15 @@ def run_inference(
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
from lightx2v.infer import init_runner # noqa
runner = init_runner(config) runner = init_runner(config)
current_config = config current_config = config
if not lazy_load: if not lazy_load:
global_runner = runner global_runner = runner
else:
runner.config = config
asyncio.run(runner.run_pipeline()) asyncio.run(runner.run_pipeline())
@@ -297,35 +396,233 @@ def run_inference(
return save_video_path return save_video_path
def main(): def auto_configure(enable_auto_config, model_type, resolution):
parser = argparse.ArgumentParser(description="Light Video Generation") default_config = {
parser.add_argument("--model_path", type=str, required=True, help="Model folder path") "torch_compile_val": False,
parser.add_argument( "lazy_load_val": False,
"--model_cls", "rotary_chunk_val": False,
type=str, "rotary_chunk_size_val": 100,
choices=["wan2.1"], "clean_cuda_cache_val": False,
default="wan2.1", "cpu_offload_val": False,
help="Model class to use", "offload_granularity_val": "block",
) "offload_ratio_val": 1,
parser.add_argument("--server_port", type=int, default=7862, help="Server port") "t5_offload_granularity_val": "model",
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server name") "attention_type_val": attn_op_choices[0][1],
args = parser.parse_args() "quant_op_val": quant_op_choices[0][1],
"dit_quant_scheme_val": "bf16",
"t5_quant_scheme_val": "bf16",
"clip_quant_scheme_val": "fp16",
"precision_mode_val": "fp32",
"use_tiny_vae_val": False,
"use_tiling_vae_val": False,
"enable_teacache_val": False,
"teacache_thresh_val": 0.26,
"use_ret_steps_val": False,
}
global model_path, model_cls if not enable_auto_config:
model_path = args.model_path return tuple(gr.update(value=default_config[key]) for key in default_config)
model_cls = args.model_cls
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
if is_fp8_supported_gpu():
quant_type = "fp8"
else:
quant_type = "int8"
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
quant_op_priority = ["sgl", "vllm", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
default_config["attention_type_val"] = dict(attn_op_choices)[op]
break
for op in quant_op_priority:
if dict(available_quant_ops).get(op):
default_config["quant_op_val"] = dict(quant_op_choices)[op]
break
if resolution in [
"1280x720",
"720x1280",
"1280x544",
"544x1280",
"1104x832",
"832x1104",
"960x960",
]:
res = "720p"
elif resolution in [
"960x544",
"544x960",
]:
res = "540p"
else:
res = "480p"
if model_type in ["Wan2.1 14B"]:
is_14b = True
else:
is_14b = False
if res == "720p" and is_14b:
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
(32, {"cpu_offload_val": True, "offload_ratio_val": 1}),
(
24,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
},
),
(
12,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
},
),
(
8,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
},
),
]
elif is_14b:
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.2}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
(24, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
(
16,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "block",
},
),
(
8,
(
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
}
if res == "540p"
else {
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
}
),
),
]
else:
    gpu_rules = []  # no GPU-memory presets for non-14B models; keep the defaults above

if is_14b:
cpu_rules = [
(128, {}),
(64, {"dit_quant_scheme_val": quant_type}),
(32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
(
16,
{
"dit_quant_scheme_val": quant_type,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"lazy_load_val": True,
"dit_quant_scheme_val": quant_type,
},
),
]
else:
    cpu_rules = []  # same fallback when the memory-based CPU rules do not apply
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
default_config.update(updates)
break
for threshold, updates in cpu_rules:
if cpu_memory >= threshold:
default_config.update(updates)
break
return tuple(gr.update(value=default_config[key]) for key in default_config)
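# Annotation (not part of the commit): both rule lists are ordered from the
# largest memory budget downwards and the first satisfied threshold wins, so a
# hypothetical 24 GB card running the 14B model at 720p picks the 24 GB preset
# (full CPU offload, block-level T5 offload, bf16 sensitive layers, tiled VAE)
# and stops. Sketch of the selection logic:
#
#   gpu_memory = 24
#   for threshold, updates in gpu_rules:   # e.g. [(80, {...}), (48, {...}), ...]
#       if gpu_memory >= threshold:
#           default_config.update(updates)
#           break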
def main():
def update_model_type(task_type): def update_model_type(task_type):
if task_type == "Image-to-Video": if task_type == "Image to Video":
return gr.update(choices=["Wan2.1 14B"], value="Wan2.1 14B") return gr.update(choices=["Wan2.1 14B"], value="Wan2.1 14B")
elif task_type == "Text-to-Video": elif task_type == "Text to Video":
return gr.update(choices=["Wan2.1 14B", "Wan2.1 1.3B"], value="Wan2.1 14B") return gr.update(choices=["Wan2.1 14B", "Wan2.1 1.3B"], value="Wan2.1 14B")
def toggle_image_input(task): def toggle_image_input(task):
return gr.update(visible=(task == "Image-to-Video")) return gr.update(visible=(task == "Image to Video"))
with gr.Blocks( with gr.Blocks(
title="Lightx2v (Lightweight Video Inference Generation Engine)", title="Lightx2v (Lightweight Video Inference and Generation Engine)",
css=""" css="""
.main-content { max-width: 1400px; margin: auto; } .main-content { max-width: 1400px; margin: auto; }
.output-video { max-height: 650px; } .output-video { max-height: 650px; }
@@ -335,7 +632,7 @@ def main():
""", """,
) as demo: ) as demo:
gr.Markdown(f"# 🎬 {model_cls} Video Generator") gr.Markdown(f"# 🎬 {model_cls} Video Generator")
gr.Markdown(f"### Using model: {model_path}") gr.Markdown(f"### Using Model: {model_path}")
with gr.Tabs() as tabs: with gr.Tabs() as tabs:
with gr.Tab("Basic Settings", id=1): with gr.Tab("Basic Settings", id=1):
@@ -346,11 +643,10 @@ def main():
with gr.Row(): with gr.Row():
task = gr.Dropdown( task = gr.Dropdown(
choices=["Image-to-Video", "Text-to-Video"], choices=["Image to Video", "Text to Video"],
value="Image-to-Video", value="Image to Video",
label="Task Type", label="Task Type",
) )
model_type = gr.Dropdown( model_type = gr.Dropdown(
choices=["Wan2.1 14B"], choices=["Wan2.1 14B"],
value="Wan2.1 14B", value="Wan2.1 14B",
@@ -368,7 +664,7 @@ def main():
type="filepath", type="filepath",
height=300, height=300,
interactive=True, interactive=True,
visible=True, visible=True, # Initially visible
) )
task.change( task.change(
@@ -389,9 +685,9 @@ def main():
negative_prompt = gr.Textbox( negative_prompt = gr.Textbox(
label="Negative Prompt", label="Negative Prompt",
lines=3, lines=3,
placeholder="Content you don't want in the video...", placeholder="What you don't want to appear in the video...",
max_lines=5, max_lines=5,
value="camera shake, garish colors, overexposure, static, blurry details, subtitles, style, work, painting, image, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, deformed limbs, finger fusion, static frame, cluttered background, three legs, crowded background, walking backwards", value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
) )
with gr.Column(): with gr.Column():
resolution = gr.Dropdown( resolution = gr.Dropdown(
@@ -399,7 +695,6 @@ def main():
# 720p # 720p
("1280x720 (16:9, 720p)", "1280x720"), ("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"), ("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (1:1, 720p)", "1024x1024"),
("1280x544 (21:9, 720p)", "1280x544"), ("1280x544 (21:9, 720p)", "1280x544"),
("544x1280 (9:21, 720p)", "544x1280"), ("544x1280 (9:21, 720p)", "544x1280"),
("1104x832 (4:3, 720p)", "1104x832"), ("1104x832 (4:3, 720p)", "1104x832"),
@@ -415,8 +710,8 @@ def main():
("720x720 (1:1, 480p)", "720x720"), ("720x720 (1:1, 480p)", "720x720"),
("512x512 (1:1, 480p)", "512x512"), ("512x512 (1:1, 480p)", "512x512"),
], ],
value="480x832", value="832x480",
label="Max Resolution", label="Maximum Resolution",
) )
with gr.Column(): with gr.Column():
seed = gr.Slider( seed = gr.Slider(
@@ -425,50 +720,60 @@ def main():
maximum=10000000, maximum=10000000,
step=1, step=1,
value=42, value=42,
info="Fix the random seed for reproducible results",
) )
infer_steps = gr.Slider( infer_steps = gr.Slider(
label="Inference Steps", label="Inference Steps",
minimum=1, minimum=1,
maximum=100, maximum=100,
step=1, step=1,
value=20, value=40,
info="Inference steps for video generation. More steps may improve quality but reduce speed", info="Number of inference steps for video generation. Increasing steps may improve quality but reduce speed.",
)
sample_shift = gr.Slider(
label="Distribution Shift",
value=5,
minimum=0,
maximum=10,
step=1,
info="Controls the distribution shift of samples. Larger values mean more obvious shifts",
) )
fps = gr.Slider( enable_cfg = gr.Checkbox(
label="Frame Rate (FPS)", label="Enable Classifier-Free Guidance",
minimum=8, value=True,
maximum=30, info="Enable classifier-free guidance to control prompt strength",
step=1, )
value=16, cfg_scale = gr.Slider(
info="Frames per second. Higher FPS produces smoother video", label="CFG Scale Factor",
) minimum=1,
num_frames = gr.Slider( maximum=10,
label="Total Frames", step=1,
minimum=16, value=5,
maximum=120, info="Controls the influence strength of the prompt. Higher values give more influence to the prompt.",
step=1, )
value=81, sample_shift = gr.Slider(
info="Total number of frames. More frames produce longer video", label="Distribution Shift",
) value=5,
minimum=0,
save_video_path = gr.Textbox( maximum=10,
label="Output Video Path", step=1,
value=generate_unique_filename(), info="Controls the degree of distribution shift for samples. Larger values indicate more significant shifts.",
info="Must include .mp4 suffix. If left empty or using default, a unique filename will be automatically generated",
) )
infer_btn = gr.Button("Generate Video", variant="primary", size="lg") fps = gr.Slider(
label="Frames Per Second (FPS)",
minimum=8,
maximum=30,
step=1,
value=16,
info="Frames per second of the video. Higher FPS results in smoother videos.",
)
num_frames = gr.Slider(
label="Total Frames",
minimum=16,
maximum=120,
step=1,
value=81,
info="Total number of frames in the video. More frames result in longer videos.",
)
save_video_path = gr.Textbox(
label="Output Video Path",
value=generate_unique_filename(),
info="Must include .mp4 extension. If left blank or using the default value, a unique filename will be automatically generated.",
)
with gr.Column(scale=6): with gr.Column(scale=6):
gr.Markdown("## 📤 Generated Video") gr.Markdown("## 📤 Generated Video")
output_video = gr.Video( output_video = gr.Video(
@@ -479,133 +784,139 @@ def main():
elem_classes=["output-video"], elem_classes=["output-video"],
) )
infer_btn = gr.Button("Generate Video", variant="primary", size="lg")
with gr.Tab("⚙️ Advanced Options", id=2): with gr.Tab("⚙️ Advanced Options", id=2):
with gr.Group(elem_classes="advanced-options"): with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### Classifier-Free Guidance (CFG)") gr.Markdown("### Auto configuration")
with gr.Row(): with gr.Row():
enable_cfg = gr.Checkbox( enable_auto_config = gr.Checkbox(
label="Enable Classifier-Free Guidance", label="Auto configuration",
value=False, value=False,
info="Enable classifier guidance to control prompt strength", info="Auto-tune optimization settings for your GPU",
)
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1,
maximum=100,
step=1,
value=5,
info="Controls the influence strength of the prompt. Higher values mean stronger influence",
) )
gr.Markdown("### Memory Optimization") gr.Markdown("### GPU Memory Optimization")
with gr.Row(): with gr.Row():
lazy_load = gr.Checkbox( rotary_chunk = gr.Checkbox(
label="Enable Lazy Loading", label="Chunked Rotary Position Embedding",
value=False,
info="Lazily load model components during inference, suitable for memory-constrained environments",
)
torch_compile = gr.Checkbox(
label="Enable Torch Compile",
value=False,
info="Use torch.compile to accelerate the inference process",
)
use_expandable_alloc = gr.Checkbox(
label="Enable Expandable Memory Allocation",
value=False, value=False,
info="Helps reduce memory fragmentation", info="When enabled, processes rotary position embeddings in chunks to save GPU memory.",
) )
rotary_chunk = gr.Checkbox( rotary_chunk_size = gr.Slider(
label="Chunked Rotary Position Encoding", label="Rotary Embedding Chunk Size",
value=False, value=100,
info="When enabled, uses chunked processing for rotary position encoding to save memory.", minimum=100,
maximum=10000,
step=100,
info="Controls the chunk size for applying rotary embeddings. Larger values may improve performance but increase memory usage. Only effective if 'rotary_chunk' is checked.",
) )
clean_cuda_cache = gr.Checkbox( clean_cuda_cache = gr.Checkbox(
label="Clean CUDA Memory Cache", label="Clean CUDA Memory Cache",
value=False, value=False,
info="When enabled, frees up memory in a timely manner but slows down inference.", info="When enabled, frees up GPU memory promptly but slows down inference.",
) )
gr.Markdown("### Asynchronous Offloading")
with gr.Row(): with gr.Row():
cpu_offload = gr.Checkbox( cpu_offload = gr.Checkbox(
label="CPU Offload", label="CPU Offloading",
value=False,
info="Offload parts of the model computation from GPU to CPU to reduce GPU memory usage",
)
lazy_load = gr.Checkbox(
label="Enable Lazy Loading",
value=False, value=False,
info="Offload part of the model computation from GPU to CPU to reduce video memory usage", info="Lazy load model components during inference. Requires CPU loading and DIT quantization.",
) )
offload_granularity = gr.Dropdown( offload_granularity = gr.Dropdown(
label="Dit Offload Granularity", label="Dit Offload Granularity",
choices=["block", "phase"], choices=["block", "phase"],
value="block", value="phase",
info="Controls the granularity of Dit model offloading to CPU", info="Sets Dit model offloading granularity: blocks or computational phases",
)
offload_ratio = gr.Slider(
label="Offload ratio for Dit model",
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
info="Controls how much of the Dit model is offloaded to the CPU",
) )
t5_offload_granularity = gr.Dropdown( t5_offload_granularity = gr.Dropdown(
label="T5 Encoder Offload Granularity", label="T5 Encoder Offload Granularity",
choices=["model", "block"], choices=["model", "block"],
value="block", value="model",
info="Controls the granularity of T5 Encoder model offloading to CPU", info="Controls the granularity when offloading the T5 Encoder model to CPU",
) )
gr.Markdown("### Low-Precision Quantization") gr.Markdown("### Low-Precision Quantization")
with gr.Row(): with gr.Row():
torch_compile = gr.Checkbox(
label="Torch Compile",
value=False,
info="Use torch.compile to accelerate the inference process",
)
attention_type = gr.Dropdown( attention_type = gr.Dropdown(
label="Attention Operator", label="Attention Operator",
choices=["flash_attn2", "flash_attn3", "sage_attn2"], choices=[op[1] for op in attn_op_choices],
value="flash_attn2", value=attn_op_choices[0][1],
info="Using a suitable attention operator can accelerate inference", info="Use appropriate attention operators to accelerate inference",
) )
quant_op = gr.Dropdown( quant_op = gr.Dropdown(
label="Quantization Operator", label="Quantization Matmul Operator",
choices=["vllm", "sgl", "q8f"], choices=[op[1] for op in quant_op_choices],
value="vllm", value=quant_op_choices[0][1],
info="Using a suitable quantization operator can accelerate inference", info="Select the quantization matrix multiplication operator to accelerate inference",
interactive=True,
) )
dit_quant_scheme = gr.Dropdown( dit_quant_scheme = gr.Dropdown(
label="Dit", label="Dit",
choices=["fp8", "int8", "bf16"], choices=["fp8", "int8", "bf16"],
value="bf16", value="bf16",
info="Quantization precision for Dit model", info="Quantization precision for the Dit model",
) )
t5_quant_scheme = gr.Dropdown( t5_quant_scheme = gr.Dropdown(
label="T5 Encoder", label="T5 Encoder",
choices=["fp8", "int8", "bf16"], choices=["fp8", "int8", "bf16"],
value="bf16", value="bf16",
info="Quantization precision for T5 Encoder model", info="Quantization precision for the T5 Encoder model",
) )
clip_quant_scheme = gr.Dropdown( clip_quant_scheme = gr.Dropdown(
label="Clip Encoder", label="Clip Encoder",
choices=["fp8", "int8", "fp16"], choices=["fp8", "int8", "fp16"],
value="fp16", value="fp16",
info="Quantization precision for Clip Encoder", info="Quantization precision for the Clip Encoder",
) )
precision_mode = gr.Dropdown( precision_mode = gr.Dropdown(
label="Sensitive Layer Precision", label="Precision Mode",
choices=["fp32", "bf16"], choices=["fp32", "bf16"],
value="bf16", value="fp32",
info="Select the numerical precision for sensitive layer calculations.", info="Select the numerical precision used for sensitive layers.",
) )
gr.Markdown("### Variational Autoencoder (VAE)") gr.Markdown("### Variational Autoencoder (VAE)")
with gr.Row(): with gr.Row():
use_tiny_vae = gr.Checkbox( use_tiny_vae = gr.Checkbox(
label="Use Lightweight VAE", label="Use Tiny VAE",
value=False, value=False,
info="Use a lightweight VAE model to accelerate the decoding process", info="Use a lightweight VAE model to accelerate the decoding process",
) )
use_tiling_vae = gr.Checkbox( use_tiling_vae = gr.Checkbox(
label="Enable VAE Tiling Inference", label="VAE Tiling Inference",
value=False, value=False,
info="Use VAE tiling inference to reduce video memory usage", info="Use VAE tiling inference to reduce GPU memory usage",
) )
gr.Markdown("### Feature Caching") gr.Markdown("### Feature Caching")
with gr.Row(): with gr.Row():
enable_teacache = gr.Checkbox( enable_teacache = gr.Checkbox(
label="Enable Tea Cache", label="Tea Cache",
value=False, value=False,
info="Cache features during inference to reduce the number of inference steps", info="Cache features during inference to reduce the number of inference steps",
) )
@@ -614,9 +925,41 @@ def main():
value=0.26, value=0.26,
minimum=0, minimum=0,
maximum=1, maximum=1,
info="Higher acceleration may lead to lower quality - setting to 0.1 gives about 2.0x acceleration, setting to 0.2 gives about 3.0x acceleration", info="Higher acceleration may result in lower quality —— Setting to 0.1 provides ~2.0x acceleration, setting to 0.2 provides ~3.0x acceleration",
)
use_ret_steps = gr.Checkbox(
label="Cache Only Key Steps",
value=False,
info="When checked, cache is written only at key steps where the scheduler returns results; when unchecked, cache is written at all steps to ensure the highest quality",
) )
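# Annotation (not part of the commit): this checkbox feeds use_ret_steps into
# run_inference(), which both sets config["use_ret_steps"] and selects between
# the two TeaCache coefficient sets: coefficient[0] when only key (returned)
# steps are cached, coefficient[1] when every step is cached.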
enable_auto_config.change(
fn=auto_configure,
inputs=[enable_auto_config, model_type, resolution],
outputs=[
torch_compile,
lazy_load,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
precision_mode,
use_tiny_vae,
use_tiling_vae,
enable_teacache,
teacache_thresh,
use_ret_steps,
],
)
infer_btn.click( infer_btn.click(
fn=run_inference, fn=run_inference,
inputs=[ inputs=[
@@ -634,6 +977,7 @@ def main():
sample_shift, sample_shift,
enable_teacache, enable_teacache,
teacache_thresh, teacache_thresh,
use_ret_steps,
enable_cfg, enable_cfg,
cfg_scale, cfg_scale,
dit_quant_scheme, dit_quant_scheme,
@@ -644,13 +988,14 @@ def main():
use_tiling_vae, use_tiling_vae,
lazy_load, lazy_load,
precision_mode, precision_mode,
use_expandable_alloc,
cpu_offload, cpu_offload,
offload_granularity, offload_granularity,
offload_ratio,
t5_offload_granularity, t5_offload_granularity,
attention_type, attention_type,
quant_op, quant_op,
rotary_chunk, rotary_chunk,
rotary_chunk_size,
clean_cuda_cache, clean_cuda_cache,
], ],
outputs=output_video, outputs=output_video,
@@ -660,4 +1005,21 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() parser = argparse.ArgumentParser(description="Light Video Generation")
parser.add_argument("--model_path", type=str, required=True, help="Model folder path")
parser.add_argument(
"--model_cls",
type=str,
choices=["wan2.1"],
default="wan2.1",
help="Model class to use",
)
parser.add_argument("--server_port", type=int, default=7862, help="Server port")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server ip")
args = parser.parse_args()
global model_path, model_cls
model_path = args.model_path
model_cls = args.model_cls
main()
\ No newline at end of file
@@ -8,16 +8,11 @@ import gc
from easydict import EasyDict from easydict import EasyDict
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
import sys
from pathlib import Path
module_path = str(Path(__file__).resolve().parent.parent) import importlib.util
sys.path.append(module_path) import psutil
from lightx2v.infer import init_runner # noqa: E402
from lightx2v.utils.envs import * # noqa: E402
# advance_ptq
logger.add( logger.add(
"inference_logs.log", "inference_logs.log",
rotation="100 MB", rotation="100 MB",
@@ -28,8 +23,79 @@ logger.add(
) )
global_runner = None def is_module_installed(module_name):
current_config = None """检查模块是否已安装"""
try:
spec = importlib.util.find_spec(module_name)
return spec is not None
except ModuleNotFoundError:
return False
def get_available_quant_ops():
available_ops = []
vllm_installed = is_module_installed("vllm")
if vllm_installed:
available_ops.append(("vllm", True))
else:
available_ops.append(("vllm", False))
sgl_installed = is_module_installed("sgl_kernel")
if sgl_installed:
available_ops.append(("sgl", True))
else:
available_ops.append(("sgl", False))
q8f_installed = is_module_installed("q8_kernels")
if q8f_installed:
available_ops.append(("q8f", True))
else:
available_ops.append(("q8f", False))
return available_ops
def get_available_attn_ops():
available_ops = []
vllm_installed = is_module_installed("flash_attn")
if vllm_installed:
available_ops.append(("flash_attn2", True))
else:
available_ops.append(("flash_attn2", False))
sgl_installed = is_module_installed("flash_attn_interface")
if sgl_installed:
available_ops.append(("flash_attn3", True))
else:
available_ops.append(("flash_attn3", False))
q8f_installed = is_module_installed("sageattention")
if q8f_installed:
available_ops.append(("sage_attn2", True))
else:
available_ops.append(("sage_attn2", False))
return available_ops
def get_gpu_memory(gpu_idx=0):
if not torch.cuda.is_available():
return 0
try:
with torch.cuda.device(gpu_idx):
memory_info = torch.cuda.mem_get_info()
total_memory = memory_info[1] / (1024**3)
return total_memory
except Exception as e:
logger.warning(f"获取GPU内存失败: {e}")
return 0
def get_cpu_memory():
available_bytes = psutil.virtual_memory().available
return available_bytes / 1024**3
def generate_unique_filename(base_dir="./saved_videos"): def generate_unique_filename(base_dir="./saved_videos"):
@@ -38,6 +104,32 @@ def generate_unique_filename(base_dir="./saved_videos"):
return os.path.join(base_dir, f"{model_cls}_{timestamp}.mp4") return os.path.join(base_dir, f"{model_cls}_{timestamp}.mp4")
def is_fp8_supported_gpu():
if not torch.cuda.is_available():
return False
compute_capability = torch.cuda.get_device_capability(0)
major, minor = compute_capability
return (major == 8 and minor == 9) or (major >= 9)
global_runner = None
current_config = None
available_quant_ops = get_available_quant_ops()
quant_op_choices = []
for op_name, is_installed in available_quant_ops:
status_text = "✅ 已安装" if is_installed else "❌ 未安装"
display_text = f"{op_name} ({status_text})"
quant_op_choices.append((op_name, display_text))
available_attn_ops = get_available_attn_ops()
attn_op_choices = []
for op_name, is_installed in available_attn_ops:
status_text = "✅ 已安装" if is_installed else "❌ 未安装"
display_text = f"{op_name} ({status_text})"
attn_op_choices.append((op_name, display_text))
def run_inference( def run_inference(
model_type, model_type,
task, task,
@@ -53,6 +145,7 @@ def run_inference(
sample_shift, sample_shift,
enable_teacache, enable_teacache,
teacache_thresh, teacache_thresh,
use_ret_steps,
enable_cfg, enable_cfg,
cfg_scale, cfg_scale,
dit_quant_scheme, dit_quant_scheme,
@@ -63,25 +156,29 @@ def run_inference(
use_tiling_vae, use_tiling_vae,
lazy_load, lazy_load,
precision_mode, precision_mode,
use_expandable_alloc,
cpu_offload, cpu_offload,
offload_granularity, offload_granularity,
offload_ratio,
t5_offload_granularity, t5_offload_granularity,
attention_type, attention_type,
quant_op, quant_op,
rotary_chunk, rotary_chunk,
rotary_chunk_size,
clean_cuda_cache, clean_cuda_cache,
): ):
quant_op = quant_op.split("(")[0].strip()
attention_type = attention_type.split("(")[0].strip()
global global_runner, current_config, model_path global global_runner, current_config, model_path
if os.path.exists(os.path.join(model_path, "config.json")): if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f: with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f) model_config = json.load(f)
if task == "文生视频": if task == "图像生成视频":
task = "t2v"
elif task == "图生视频":
task = "i2v" task = "i2v"
elif task == "文本生成视频":
task = "t2v"
if task == "t2v": if task == "t2v":
if model_type == "Wan2.1 1.3B": if model_type == "Wan2.1 1.3B":
@@ -124,7 +221,6 @@ def run_inference(
if resolution in [ if resolution in [
"1280x720", "1280x720",
"720x1280", "720x1280",
"1024x1024",
"1280x544", "1280x544",
"544x1280", "544x1280",
"1104x832", "1104x832",
@@ -173,7 +269,7 @@ def run_inference(
else: else:
t5_quant_ckpt = None t5_quant_ckpt = None
is_clip_quant = clip_quant_scheme != "bf16" is_clip_quant = clip_quant_scheme != "fp16"
if is_clip_quant: if is_clip_quant:
if clip_quant_scheme == "int8": if clip_quant_scheme == "int8":
clip_quant_ckpt = os.path.join(model_path, "clip-int8.pth") clip_quant_ckpt = os.path.join(model_path, "clip-int8.pth")
@@ -192,10 +288,6 @@ def run_inference(
os.environ["DTYPE"] = "BF16" os.environ["DTYPE"] = "BF16"
else: else:
os.environ.pop("DTYPE", None) os.environ.pop("DTYPE", None)
if use_expandable_alloc:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:true"
else:
os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
if is_dit_quant: if is_dit_quant:
if quant_op == "vllm": if quant_op == "vllm":
@@ -204,8 +296,11 @@ def run_inference(
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl" mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Sgl"
elif quant_op == "q8f": elif quant_op == "q8f":
mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F" mm_type = f"W-{dit_quant_scheme}-channel-sym-A-{dit_quant_scheme}-channel-sym-dynamic-Q8F"
dit_quantized_ckpt = os.path.join(model_path, dit_quant_scheme)
else: else:
mm_type = "Default" mm_type = "Default"
dit_quantized_ckpt = None
config = { config = {
"infer_steps": infer_steps, "infer_steps": infer_steps,
@@ -219,15 +314,16 @@ def run_inference(
"sample_shift": sample_shift, "sample_shift": sample_shift,
"cpu_offload": cpu_offload, "cpu_offload": cpu_offload,
"offload_granularity": offload_granularity, "offload_granularity": offload_granularity,
"offload_ratio": offload_ratio,
"t5_offload_granularity": t5_offload_granularity, "t5_offload_granularity": t5_offload_granularity,
"dit_quantized_ckpt": model_path if is_dit_quant else None, "dit_quantized_ckpt": dit_quantized_ckpt,
"mm_config": { "mm_config": {
"mm_type": mm_type, "mm_type": mm_type,
}, },
"fps": fps, "fps": fps,
"feature_caching": "Tea" if enable_teacache else "NoCaching", "feature_caching": "Tea" if enable_teacache else "NoCaching",
"coefficients": coefficient, "coefficients": coefficient[0] if use_ret_steps else coefficient[1],
"use_ret_steps": True, "use_ret_steps": use_ret_steps,
"teacache_thresh": teacache_thresh, "teacache_thresh": teacache_thresh,
"t5_quantized": is_t5_quant, "t5_quantized": is_t5_quant,
"t5_quantized_ckpt": t5_quant_ckpt, "t5_quantized_ckpt": t5_quant_ckpt,
@@ -250,6 +346,7 @@ def run_inference(
"use_prompt_enhancer": False, "use_prompt_enhancer": False,
"text_len": 512, "text_len": 512,
"rotary_chunk": rotary_chunk, "rotary_chunk": rotary_chunk,
"rotary_chunk_size": rotary_chunk_size,
"clean_cuda_cache": clean_cuda_cache, "clean_cuda_cache": clean_cuda_cache,
} }
@@ -269,11 +366,9 @@ def run_inference(
config["mode"] = "infer" config["mode"] = "infer"
config.update(model_config) config.update(model_config)
print(config)
logger.info(f"使用模型: {model_path}") logger.info(f"使用模型: {model_path}")
logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}") logger.info(f"推理配置:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
# 初始化或复用runner
runner = global_runner runner = global_runner
if needs_reinit: if needs_reinit:
if runner is not None: if runner is not None:
@@ -281,11 +376,15 @@ def run_inference(
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
from lightx2v.infer import init_runner # noqa
runner = init_runner(config) runner = init_runner(config)
current_config = config current_config = config
if not lazy_load: if not lazy_load:
global_runner = runner global_runner = runner
else:
runner.config = config
asyncio.run(runner.run_pipeline()) asyncio.run(runner.run_pipeline())
@@ -297,35 +396,233 @@ def run_inference(
return save_video_path return save_video_path
def main(): def auto_configure(enable_auto_config, model_type, resolution):
parser = argparse.ArgumentParser(description="Light Video Generation") default_config = {
parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径") "torch_compile_val": False,
parser.add_argument( "lazy_load_val": False,
"--model_cls", "rotary_chunk_val": False,
type=str, "rotary_chunk_size_val": 100,
choices=["wan2.1"], "clean_cuda_cache_val": False,
default="wan2.1", "cpu_offload_val": False,
help="使用的模型类别", "offload_granularity_val": "block",
) "offload_ratio_val": 1,
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口") "t5_offload_granularity_val": "model",
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器名称") "attention_type_val": attn_op_choices[0][1],
args = parser.parse_args() "quant_op_val": quant_op_choices[0][1],
"dit_quant_scheme_val": "bf16",
"t5_quant_scheme_val": "bf16",
"clip_quant_scheme_val": "fp16",
"precision_mode_val": "fp32",
"use_tiny_vae_val": False,
"use_tiling_vae_val": False,
"enable_teacache_val": False,
"teacache_thresh_val": 0.26,
"use_ret_steps_val": False,
}
global model_path, model_cls if not enable_auto_config:
model_path = args.model_path return tuple(gr.update(value=default_config[key]) for key in default_config)
model_cls = args.model_cls
gpu_memory = round(get_gpu_memory())
cpu_memory = round(get_cpu_memory())
if is_fp8_supported_gpu():
quant_type = "fp8"
else:
quant_type = "int8"
attn_priority = ["sage_attn2", "flash_attn3", "flash_attn2"]
quant_op_priority = ["sgl", "vllm", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
default_config["attention_type_val"] = dict(attn_op_choices)[op]
break
for op in quant_op_priority:
if dict(available_quant_ops).get(op):
default_config["quant_op_val"] = dict(quant_op_choices)[op]
break
if resolution in [
"1280x720",
"720x1280",
"1280x544",
"544x1280",
"1104x832",
"832x1104",
"960x960",
]:
res = "720p"
elif resolution in [
"960x544",
"544x960",
]:
res = "540p"
else:
res = "480p"
if model_type in ["Wan2.1 14B"]:
is_14b = True
else:
is_14b = False
if res == "720p" and is_14b:
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
(32, {"cpu_offload_val": True, "offload_ratio_val": 1}),
(
24,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
},
),
(
16,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
},
),
(
12,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
},
),
(
8,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"rotary_chunk_val": True,
"rotary_chunk_size_val": 100,
"clean_cuda_cache_val": True,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
},
),
]
elif is_14b:
gpu_rules = [
(80, {}),
(48, {"cpu_offload_val": True, "offload_ratio_val": 0.2}),
(40, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
(24, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
(
16,
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "block",
},
),
(
8,
(
{
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
"rotary_chunk_val": True,
"rotary_chunk_size_val": 10000,
}
if res == "540p"
else {
"cpu_offload_val": True,
"offload_ratio_val": 1,
"t5_offload_granularity_val": "block",
"precision_mode_val": "bf16",
"use_tiling_vae_val": True,
"offload_granularity_val": "phase",
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"dit_quant_scheme_val": quant_type,
"lazy_load_val": True,
}
),
),
]
else:
    gpu_rules = []  # no GPU-memory presets for non-14B models; keep the defaults above

if is_14b:
cpu_rules = [
(128, {}),
(64, {"dit_quant_scheme_val": quant_type}),
(32, {"dit_quant_scheme_val": quant_type, "lazy_load_val": True}),
(
16,
{
"dit_quant_scheme_val": quant_type,
"t5_quant_scheme_val": quant_type,
"clip_quant_scheme_val": quant_type,
"lazy_load_val": True,
"dit_quant_scheme_val": quant_type,
},
),
]
else:
    cpu_rules = []  # same fallback when the memory-based CPU rules do not apply
for threshold, updates in gpu_rules:
if gpu_memory >= threshold:
default_config.update(updates)
break
for threshold, updates in cpu_rules:
if cpu_memory >= threshold:
default_config.update(updates)
break
return tuple(gr.update(value=default_config[key]) for key in default_config)
def main():
def update_model_type(task_type): def update_model_type(task_type):
if task_type == "图视频": if task_type == "图像生成视频":
return gr.update(choices=["Wan2.1 14B"], value="Wan2.1 14B") return gr.update(choices=["Wan2.1 14B"], value="Wan2.1 14B")
elif task_type == "文视频": elif task_type == "文本生成视频":
return gr.update(choices=["Wan2.1 14B", "Wan2.1 1.3B"], value="Wan2.1 14B") return gr.update(choices=["Wan2.1 14B", "Wan2.1 1.3B"], value="Wan2.1 14B")
def toggle_image_input(task): def toggle_image_input(task):
return gr.update(visible=(task == "图视频")) return gr.update(visible=(task == "图像生成视频"))
with gr.Blocks( with gr.Blocks(
title="Lightx2v(轻量级视频推理生成引擎)", title="Lightx2v (轻量级视频生成推理引擎)",
css=""" css="""
.main-content { max-width: 1400px; margin: auto; } .main-content { max-width: 1400px; margin: auto; }
.output-video { max-height: 650px; } .output-video { max-height: 650px; }
@@ -338,7 +635,7 @@ def main():
gr.Markdown(f"### 使用模型: {model_path}") gr.Markdown(f"### 使用模型: {model_path}")
with gr.Tabs() as tabs: with gr.Tabs() as tabs:
with gr.Tab("基设置", id=1): with gr.Tab("基设置", id=1):
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
with gr.Group(): with gr.Group():
@@ -346,8 +643,8 @@ def main():
with gr.Row(): with gr.Row():
task = gr.Dropdown( task = gr.Dropdown(
choices=["图视频", "文视频"], choices=["图像生成视频", "文本生成视频"],
value="图视频", value="图像生成视频",
label="任务类型", label="任务类型",
) )
model_type = gr.Dropdown( model_type = gr.Dropdown(
@@ -363,11 +660,11 @@ def main():
with gr.Row(): with gr.Row():
image_path = gr.Image( image_path = gr.Image(
label="输入图", label="输入图",
type="filepath", type="filepath",
height=300, height=300,
interactive=True, interactive=True,
visible=True, # Initially visible visible=True,
) )
task.change( task.change(
@@ -388,7 +685,7 @@ def main():
negative_prompt = gr.Textbox( negative_prompt = gr.Textbox(
label="负向提示词", label="负向提示词",
lines=3, lines=3,
placeholder="不希望视频出现的内容...", placeholder="不希望出现在视频中的内容...",
max_lines=5, max_lines=5,
value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
) )
@@ -398,7 +695,6 @@ def main():
# 720p # 720p
("1280x720 (16:9, 720p)", "1280x720"), ("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"), ("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (1:1, 720p)", "1024x1024"),
("1280x544 (21:9, 720p)", "1280x544"), ("1280x544 (21:9, 720p)", "1280x544"),
("544x1280 (9:21, 720p)", "544x1280"), ("544x1280 (9:21, 720p)", "544x1280"),
("1104x832 (4:3, 720p)", "1104x832"), ("1104x832 (4:3, 720p)", "1104x832"),
@@ -424,50 +720,60 @@ def main():
maximum=10000000, maximum=10000000,
step=1, step=1,
value=42, value=42,
info="固定随机种子以获得可复现的结果",
) )
infer_steps = gr.Slider( infer_steps = gr.Slider(
label="推理步数", label="推理步数",
minimum=1, minimum=1,
maximum=100, maximum=100,
step=1, step=1,
value=20, value=40,
info="视频生成的推理步数,增加步数可能提高质量但会降低速度", info="视频生成的推理步数。增加步数可能提高质量但降低速度",
)
sample_shift = gr.Slider(
label="分布偏移程度",
value=5,
minimum=0,
maximum=10,
step=1,
info="用于控制样本的分布偏移程度,数值越大表示偏移越明显",
) )
fps = gr.Slider( enable_cfg = gr.Checkbox(
label="帧率(FPS)", label="启用无分类器引导",
minimum=8, value=True,
maximum=30, info="启用无分类器引导以控制提示词强度",
step=1, )
value=16, cfg_scale = gr.Slider(
info="视频每秒帧数,更高的FPS生成更流畅的视频", label="CFG缩放因子",
) minimum=1,
num_frames = gr.Slider( maximum=10,
label="总帧数", step=1,
minimum=16, value=5,
maximum=120, info="控制提示词的影响强度。值越高,提示词的影响越大",
step=1, )
value=81, sample_shift = gr.Slider(
info="视频总帧数,更多的帧数生成更长的视频", label="分布偏移",
) value=5,
minimum=0,
save_video_path = gr.Textbox( maximum=10,
label="输出视频路径", step=1,
value=generate_unique_filename(), info="控制样本分布偏移的程度。值越大表示偏移越明显",
info="必须包含.mp4后缀,如果留空或使用默认值,将自动生成唯一文件名",
) )
infer_btn = gr.Button("生成视频", variant="primary", size="lg") fps = gr.Slider(
label="每秒帧数(FPS)",
minimum=8,
maximum=30,
step=1,
value=16,
info="视频的每秒帧数。较高的FPS会产生更流畅的视频",
)
num_frames = gr.Slider(
label="总帧数",
minimum=16,
maximum=120,
step=1,
value=81,
info="视频中的总帧数。更多帧数会产生更长的视频",
)
save_video_path = gr.Textbox(
label="输出视频路径",
value=generate_unique_filename(),
info="必须包含.mp4扩展名。如果留空或使用默认值,将自动生成唯一文件名。",
)
with gr.Column(scale=6): with gr.Column(scale=6):
gr.Markdown("## 📤 生成的视频") gr.Markdown("## 📤 生成的视频")
output_video = gr.Video( output_video = gr.Video(
@@ -478,114 +784,120 @@ def main():
elem_classes=["output-video"], elem_classes=["output-video"],
) )
infer_btn = gr.Button("生成视频", variant="primary", size="lg")
with gr.Tab("⚙️ 高级选项", id=2): with gr.Tab("⚙️ 高级选项", id=2):
with gr.Group(elem_classes="advanced-options"): with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### 无分类器引导(CFG)") gr.Markdown("### 自动配置")
with gr.Row(): with gr.Row():
enable_cfg = gr.Checkbox( enable_auto_config = gr.Checkbox(
label="启用无分类器引导", label="自动配置",
value=False, value=False,
info="启用分类器引导,用于控制提示词强度", info="自动调整优化设置以适应您的GPU",
)
cfg_scale = gr.Slider(
label="CFG缩放系数",
minimum=1,
maximum=100,
step=1,
value=5,
info="控制提示词的影响强度,值越高提示词影响越大",
) )
gr.Markdown("### 显存/内存优化") gr.Markdown("### GPU内存优化")
with gr.Row(): with gr.Row():
lazy_load = gr.Checkbox( rotary_chunk = gr.Checkbox(
label="启用延迟加载", label="分块旋转位置编码",
value=False,
info="推理时延迟加载模型组件,适用内存受限环境",
)
torch_compile = gr.Checkbox(
label="启用Torch编译",
value=False,
info="使用torch.compile加速推理过程",
)
use_expandable_alloc = gr.Checkbox(
label="启用可扩展显存分配",
value=False, value=False,
info="有助于减少显存碎片", info="启用时,将旋转位置编码分块处理以节省GPU内存。",
) )
rotary_chunk = gr.Checkbox( rotary_chunk_size = gr.Slider(
label="分块处理旋转位置编码", label="旋转编码块大小",
value=False, value=100,
info="启用后,使用分块处理旋转位置编码节省显存。", minimum=100,
maximum=10000,
step=100,
info="控制应用旋转编码的块大小, 较大的值可能提高性能但增加内存使用, 仅在'rotary_chunk'勾选时有效",
) )
clean_cuda_cache = gr.Checkbox( clean_cuda_cache = gr.Checkbox(
label="清理 CUDA存缓存", label="清理CUDA存缓存",
value=False, value=False,
info="启用后,及时释放显存但推理速度变慢。", info="及时释放GPU内存, 但会减慢推理速度。",
) )
gr.Markdown("### 异步卸载")
with gr.Row(): with gr.Row():
cpu_offload = gr.Checkbox( cpu_offload = gr.Checkbox(
label="CPU卸载", label="CPU卸载",
value=False, value=False,
info="将模型的部分计算从 GPU 卸载到 CPU,以降低显存占用", info="将模型计算的一部分从GPU卸载到CPU以减少GPU内存使用",
) )
lazy_load = gr.Checkbox(
label="启用延迟加载",
value=False,
info="在推理过程中延迟加载模型组件, 仅在'cpu_offload'勾选和使用量化Dit模型时有效",
)
offload_granularity = gr.Dropdown( offload_granularity = gr.Dropdown(
label="Dit 卸载粒度", label="Dit卸载粒度",
choices=["block", "phase"], choices=["block", "phase"],
value="block", value="phase",
info="控制 Dit 模型卸载到 CPU 时的粒度", info="设置Dit模型卸载粒度: 块或计算阶段",
)
offload_ratio = gr.Slider(
label="Dit模型卸载比例",
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
info="控制将多少Dit模型卸载到CPU",
) )
t5_offload_granularity = gr.Dropdown( t5_offload_granularity = gr.Dropdown(
label="T5 Encoder 卸载粒度", label="T5编码器卸载粒度",
choices=["model", "block"], choices=["model", "block"],
value="block", value="model",
info="控制 T5 Encoder 模型卸载到 CPU 时的粒度", info="控制T5编码器模型卸载到CPU时的粒度",
) )
gr.Markdown("### 低精度量化") gr.Markdown("### 低精度量化")
with gr.Row(): with gr.Row():
attention_type = gr.Dropdown( torch_compile = gr.Checkbox(
label="attention 算子", label="Torch编译",
choices=["flash_attn2", "flash_attn3", "sage_attn2"], value=False,
value="flash_attn2", info="使用torch.compile加速推理过程",
info="使用合适的 attention 算子可加速推理",
) )
attention_type = gr.Dropdown(
label="注意力算子",
choices=[op[1] for op in attn_op_choices],
value=attn_op_choices[0][1],
info="使用适当的注意力算子加速推理",
)
quant_op = gr.Dropdown( quant_op = gr.Dropdown(
label="量化算子", label="量化矩阵乘法算子",
choices=["vllm", "sgl", "q8f"], choices=[op[1] for op in quant_op_choices],
value="vllm", value=quant_op_choices[0][1],
info="使用合适的量化算子可加速推理", info="选择量化矩阵乘法算子以加速推理",
interactive=True,
) )
dit_quant_scheme = gr.Dropdown( dit_quant_scheme = gr.Dropdown(
label="Dit", label="Dit",
choices=["fp8", "int8", "bf16"], choices=["fp8", "int8", "bf16"],
value="bf16", value="bf16",
info="Dit模型的量化精度", info="Dit模型的推理精度",
) )
t5_quant_scheme = gr.Dropdown( t5_quant_scheme = gr.Dropdown(
label="T5 Encoder", label="T5编码器",
choices=["fp8", "int8", "bf16"], choices=["fp8", "int8", "bf16"],
value="bf16", value="bf16",
info="T5 Encoder模型的量化精度", info="T5编码器模型的推理精度",
) )
clip_quant_scheme = gr.Dropdown( clip_quant_scheme = gr.Dropdown(
label="Clip Encoder", label="Clip编码器",
choices=["fp8", "int8", "fp16"], choices=["fp8", "int8", "fp16"],
value="fp16", value="fp16",
info="Clip Encoder的量化精度", info="Clip编码器的推理精度",
) )
precision_mode = gr.Dropdown( precision_mode = gr.Dropdown(
label="敏感层精度", label="精度模式",
choices=["fp32", "bf16"], choices=["fp32", "bf16"],
value="bf16", value="fp32",
info="选择用于敏感层计算的数值精度。", info="部分敏感层的推理精度。",
) )
gr.Markdown("### 变分自编码器(VAE)") gr.Markdown("### 变分自编码器(VAE)")
@@ -596,15 +908,15 @@ def main():
info="使用轻量级VAE模型加速解码过程", info="使用轻量级VAE模型加速解码过程",
) )
use_tiling_vae = gr.Checkbox( use_tiling_vae = gr.Checkbox(
label="启用 VAE 平铺推理", label="VAE分块推理",
value=False, value=False,
info="使用 VAE 平铺推理以降低显存占用", info="使用VAE分块推理以减少GPU内存使用",
) )
gr.Markdown("### 特征缓存") gr.Markdown("### 特征缓存")
with gr.Row(): with gr.Row():
enable_teacache = gr.Checkbox( enable_teacache = gr.Checkbox(
label="启用Tea Cache", label="Tea Cache",
value=False, value=False,
info="在推理过程中缓存特征以减少推理步数", info="在推理过程中缓存特征以减少推理步数",
) )
@@ -613,9 +925,41 @@ def main():
value=0.26, value=0.26,
minimum=0, minimum=0,
maximum=1, maximum=1,
info="加速越高,质量可能越差 —— 设置为 0.1 可获得约 2.0 倍加速,设置为 0.2 可获得约 3.0 倍加速", info="较高的加速可能导致质量下降 —— 设置为0.1提供约2.0倍加速,设置为0.2提供约3.0倍加速",
)
use_ret_steps = gr.Checkbox(
label="仅缓存关键步骤",
value=False,
info="勾选时,仅在调度器返回结果的关键步骤写入缓存;未勾选时,在所有步骤写入缓存以确保最高质量",
) )
enable_auto_config.change(
fn=auto_configure,
inputs=[enable_auto_config, model_type, resolution],
outputs=[
torch_compile,
lazy_load,
rotary_chunk,
rotary_chunk_size,
clean_cuda_cache,
cpu_offload,
offload_granularity,
offload_ratio,
t5_offload_granularity,
attention_type,
quant_op,
dit_quant_scheme,
t5_quant_scheme,
clip_quant_scheme,
precision_mode,
use_tiny_vae,
use_tiling_vae,
enable_teacache,
teacache_thresh,
use_ret_steps,
],
)
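For readers unfamiliar with the wiring above: a Gradio .change() callback must return one value (typically a gr.update(...)) per declared output, in the same order as the outputs list. A minimal, self-contained sketch of that pattern with hypothetical controls (illustration only, not the demo's actual auto_configure logic):

import gradio as gr

def toggle_auto(enabled):
    # One gr.update per output, in the order declared in `outputs`.
    if enabled:
        # Auto mode: force conservative defaults and lock the controls.
        return gr.update(value=False, interactive=False), gr.update(value="phase", interactive=False)
    # Manual mode: hand control back to the user.
    return gr.update(interactive=True), gr.update(interactive=True)

with gr.Blocks() as demo:
    auto = gr.Checkbox(label="Auto configure")
    torch_compile_box = gr.Checkbox(label="Torch compile")
    granularity_box = gr.Dropdown(choices=["block", "phase"], value="block")
    auto.change(fn=toggle_auto, inputs=[auto], outputs=[torch_compile_box, granularity_box])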
infer_btn.click( infer_btn.click(
fn=run_inference, fn=run_inference,
inputs=[ inputs=[
...@@ -633,6 +977,7 @@ def main(): ...@@ -633,6 +977,7 @@ def main():
sample_shift, sample_shift,
enable_teacache, enable_teacache,
teacache_thresh, teacache_thresh,
use_ret_steps,
enable_cfg, enable_cfg,
cfg_scale, cfg_scale,
dit_quant_scheme, dit_quant_scheme,
...@@ -643,13 +988,14 @@ def main(): ...@@ -643,13 +988,14 @@ def main():
use_tiling_vae, use_tiling_vae,
lazy_load, lazy_load,
precision_mode, precision_mode,
use_expandable_alloc,
cpu_offload, cpu_offload,
offload_granularity, offload_granularity,
offload_ratio,
t5_offload_granularity, t5_offload_granularity,
attention_type, attention_type,
quant_op, quant_op,
rotary_chunk, rotary_chunk,
rotary_chunk_size,
clean_cuda_cache, clean_cuda_cache,
], ],
outputs=output_video, outputs=output_video,
...@@ -659,4 +1005,21 @@ def main(): ...@@ -659,4 +1005,21 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() parser = argparse.ArgumentParser(description="轻量级视频生成")
parser.add_argument("--model_path", type=str, required=True, help="模型文件夹路径")
parser.add_argument(
"--model_cls",
type=str,
choices=["wan2.1"],
default="wan2.1",
help="要使用的模型类别",
)
parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
args = parser.parse_args()
global model_path, model_cls
model_path = args.model_path
model_cls = args.model_cls
main()
\ No newline at end of file
#!/bin/bash
lightx2v_path=/path/to/lightx2v
model_path=/path/to/wan
export CUDA_VISIBLE_DEVICES=0
export CUDA_LAUNCH_BLOCKING=1
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python gradio_demo.py \
--model_path $model_path \
--server_name 0.0.0.0 \
--server_port 8005
# python gradio_demo_zh.py \
# --model_path $model_path \
# --server_name 0.0.0.0 \
# --server_port 8005
...@@ -56,8 +56,9 @@ class WeightAsyncStreamManager(object): ...@@ -56,8 +56,9 @@ class WeightAsyncStreamManager(object):
class LazyWeightAsyncStreamManager(WeightAsyncStreamManager): class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2): def __init__(self, blocks_num, offload_ratio=1, phases_num=1, num_disk_workers=1, max_memory=2, offload_gra="phase"):
super().__init__(blocks_num, offload_ratio, phases_num) super().__init__(blocks_num, offload_ratio, phases_num)
self.offload_gra = offload_gra
self.worker_stop_event = threading.Event() self.worker_stop_event = threading.Event()
self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3)) self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
self.disk_task_queue = queue.PriorityQueue() self.disk_task_queue = queue.PriorityQueue()
...@@ -72,7 +73,10 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager): ...@@ -72,7 +73,10 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
def _start_disk_workers(self, num_workers): def _start_disk_workers(self, num_workers):
for i in range(num_workers): for i in range(num_workers):
worker = threading.Thread(target=self._disk_worker_loop, daemon=True) if self.offload_gra == "phase":
worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
else:
worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
worker.start() worker.start()
self.disk_workers.append(worker) self.disk_workers.append(worker)
...@@ -96,33 +100,74 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager): ...@@ -96,33 +100,74 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
except Exception as e: except Exception as e:
logger.error(f"Disk worker thread error: {e}") logger.error(f"Disk worker thread error: {e}")
def _disk_worker_loop_block(self):
while not self.worker_stop_event.is_set():
try:
_, task = self.disk_task_queue.get(timeout=0.5)
if task is None:
break
block_idx, block = task
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
with self.task_lock:
if block_idx in self.pending_tasks:
del self.pending_tasks[block_idx]
except queue.Empty:
continue
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, weights): def _async_prefetch_block(self, weights):
next_block_idx = self.pin_memory_buffer.get_max_block_index() next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0: if next_block_idx < 0:
next_block_idx = 0 next_block_idx = 0
for phase_idx in range(self.phases_num): if self.offload_gra == "phase":
obj_key = (next_block_idx, phase_idx) for phase_idx in range(self.phases_num):
obj_key = (next_block_idx, phase_idx)
if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
continue
with self.task_lock:
self.pending_tasks[obj_key] = True
phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
else:
obj_key = next_block_idx
if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks): if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
continue return
with self.task_lock: with self.task_lock:
self.pending_tasks[obj_key] = True self.pending_tasks[obj_key] = True
phase = weights.blocks[next_block_idx].compute_phases[phase_idx] block = weights.blocks[next_block_idx]
self.disk_task_queue.put((obj_key, (next_block_idx, block)))
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
def _sync_prefetch_block(self, weights): def _sync_prefetch_block(self, weights):
block_idx = 0 block_idx = 0
while not self.pin_memory_buffer.is_nearly_full(): while not self.pin_memory_buffer.is_nearly_full():
for phase_idx in range(self.phases_num): if self.offload_gra == "phase":
phase = weights.blocks[block_idx].compute_phases[phase_idx] for phase_idx in range(self.phases_num):
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}") phase = weights.blocks[block_idx].compute_phases[phase_idx]
phase.load_from_disk() logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
self.pin_memory_buffer.push((block_idx, phase_idx), phase) phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
else:
block = weights.blocks[block_idx]
logger.info(f"Synchronous loading: block={block_idx}")
for phase in block.compute_phases:
phase.load_from_disk()
self.pin_memory_buffer.push(block_idx, block)
block_idx += 1 block_idx += 1
def prefetch_weights_from_disk(self, weights): def prefetch_weights_from_disk(self, weights):
...@@ -132,6 +177,37 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager): ...@@ -132,6 +177,37 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
self._sync_prefetch_block(weights) self._sync_prefetch_block(weights)
self.initial_prefetch_done = True self.initial_prefetch_done = True
def prefetch_weights(self, block_idx, blocks):
obj_key = block_idx
if not self.pin_memory_buffer.exists(obj_key):
is_loading = False
with self.task_lock:
if obj_key in self.pending_tasks:
is_loading = True
if is_loading:
start_time = time.time()
while not self.pin_memory_buffer.exists(obj_key):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}")
else:
logger.info("Not find prefetch block={block_idx} task. This is a bug.")
with torch.cuda.stream(self.cuda_load_stream):
block = self.pin_memory_buffer.get(obj_key)
block.to_cuda_async()
self.active_weights[2] = (obj_key, block)
with torch.cuda.stream(self.cpu_load_stream):
if block_idx < self.offload_block_num:
if self.active_weights[1] is not None:
old_key, old_block = self.active_weights[1]
if self.pin_memory_buffer.exists(old_key):
old_block.to_cpu_async()
self.pin_memory_buffer.pop(old_key)
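The block-granularity prefetch above overlaps host-to-device copies (issued on cuda_load_stream) with compute on the main stream. A minimal, self-contained sketch of that double-buffering pattern, assuming per-block weights are pinned CPU tensors and a matmul stands in for the block forward (illustration only, not the manager's code):

import torch

def run_blocks_overlapped(x, cpu_blocks):
    # cpu_blocks: list of pinned-memory (d, d) tensors standing in for per-block weights.
    compute_stream = torch.cuda.current_stream()
    load_stream = torch.cuda.Stream()
    ready = [torch.cuda.Event() for _ in cpu_blocks]
    gpu_blocks = [None] * len(cpu_blocks)

    # Upload the first block eagerly on the side stream.
    with torch.cuda.stream(load_stream):
        gpu_blocks[0] = cpu_blocks[0].to("cuda", non_blocking=True)
        ready[0].record(load_stream)

    for i in range(len(cpu_blocks)):
        # Prefetch the next block while the current one computes.
        if i + 1 < len(cpu_blocks):
            with torch.cuda.stream(load_stream):
                gpu_blocks[i + 1] = cpu_blocks[i + 1].to("cuda", non_blocking=True)
                ready[i + 1].record(load_stream)
        compute_stream.wait_event(ready[i])          # do not compute before the copy lands
        x = x @ gpu_blocks[i]                        # stand-in for the real block forward
        gpu_blocks[i].record_stream(compute_stream)  # let the allocator reuse this buffer safely
        gpu_blocks[i] = None                         # drop the reference once consumed
    return x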
def prefetch_phase(self, block_idx, phase_idx, blocks): def prefetch_phase(self, block_idx, phase_idx, blocks):
obj_key = (block_idx, phase_idx) obj_key = (block_idx, phase_idx)
...@@ -193,32 +269,45 @@ class MemoryBuffer: ...@@ -193,32 +269,45 @@ class MemoryBuffer:
self.cache = OrderedDict() self.cache = OrderedDict()
self.max_mem = max_memory_bytes self.max_mem = max_memory_bytes
self.used_mem = 0 self.used_mem = 0
self.phases_size_map = {} self.obj_size_map = {}
self.lock = threading.Lock() self.lock = threading.Lock()
self.insertion_order = [] self.insertion_order = []
self.insertion_index = 0 self.insertion_index = 0
def push(self, key, phase_obj): def push(self, key, obj):
with self.lock: with self.lock:
if key in self.cache: if key in self.cache:
return return
_, phase_idx = key if hasattr(obj, "compute_phases"):
if phase_idx not in self.phases_size_map: obj_idx = key
self.phases_size_map[phase_idx] = phase_obj.calculate_size() if len(self.obj_size_map) == 0:
size = self.phases_size_map[phase_idx] _size = 0
for phase in obj.compute_phases:
_size += phase.calculate_size()
self.obj_size_map[0] = _size
size = self.obj_size_map[0]
else:
_, obj_idx = key
if obj_idx not in self.obj_size_map:
self.obj_size_map[obj_idx] = obj.calculate_size()
size = self.obj_size_map[obj_idx]
self.cache[key] = (size, phase_obj, self.insertion_index) self.cache[key] = (size, obj, self.insertion_index)
self.insertion_order.append((key, self.insertion_index)) self.insertion_order.append((key, self.insertion_index))
self.insertion_index += 1 self.insertion_index += 1
self.used_mem += size self.used_mem += size
def _remove_key(self, key): def _remove_key(self, key):
if key in self.cache: if key in self.cache:
size, phase, idx = self.cache.pop(key) size, obj, idx = self.cache.pop(key)
try: try:
phase.clear() if hasattr(obj, "compute_phases"):
for phase in obj.compute_phases:
phase.clear()
else:
obj.clear()
except Exception as e: except Exception as e:
logger.info(f"Error clearing phase: {e}") logger.info(f"Error clearing obj: {e}")
self.used_mem -= size self.used_mem -= size
self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key] self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]
...@@ -226,14 +315,22 @@ class MemoryBuffer: ...@@ -226,14 +315,22 @@ class MemoryBuffer:
def get(self, key, default=None): def get(self, key, default=None):
with self.lock: with self.lock:
if key in self.cache: if key in self.cache:
size, phase, idx = self.cache[key] size, obj, idx = self.cache[key]
return phase return obj
return default return default
def exists(self, key): def exists(self, key):
with self.lock: with self.lock:
return key in self.cache return key in self.cache
def pop_front(self):
with self.lock:
if not self.insertion_order:
return False
front_key, _ = self.insertion_order[0]
self._remove_key(front_key)
return True
def pop(self, key): def pop(self, key):
with self.lock: with self.lock:
if key in self.cache: if key in self.cache:
...@@ -249,7 +346,10 @@ class MemoryBuffer: ...@@ -249,7 +346,10 @@ class MemoryBuffer:
with self.lock: with self.lock:
if not self.cache: if not self.cache:
return -1 return -1
return (list(self.cache.keys())[-1][0] + 1) % 40 if isinstance(list(self.cache.keys())[-1], tuple):
return (list(self.cache.keys())[-1][0] + 1) % 40
else:
return (list(self.cache.keys())[-1] + 1) % 40
def clear(self): def clear(self):
with self.lock: with self.lock:
...@@ -260,4 +360,4 @@ class MemoryBuffer: ...@@ -260,4 +360,4 @@ class MemoryBuffer:
self.insertion_index = 0 self.insertion_index = 0
self.used_mem = 0 self.used_mem = 0
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
\ No newline at end of file
import torch import torch
from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import ( from lightx2v.common.offload.manager import (
WeightAsyncStreamManager, WeightAsyncStreamManager,
LazyWeightAsyncStreamManager, LazyWeightAsyncStreamManager,
) )
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from loguru import logger from functools import partial
import os
class WanTransformerInfer(BaseTransformerInfer): class WanTransformerInfer(BaseTransformerInfer):
...@@ -21,10 +20,12 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -21,10 +20,12 @@ class WanTransformerInfer(BaseTransformerInfer):
self.head_dim = config["dim"] // config["num_heads"] self.head_dim = config["dim"] // config["num_heads"]
self.window_size = config.get("window_size", (-1, -1)) self.window_size = config.get("window_size", (-1, -1))
self.parallel_attention = None self.parallel_attention = None
self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb if config.get("rotary_chunk", False):
chunk_size = config.get("rotary_chunk_size", 100)
self.apply_rotary_emb_func = partial(apply_rotary_emb_chunk, chunk_size=chunk_size)
else:
self.apply_rotary_emb_func = apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
if self.config["cpu_offload"]: if self.config["cpu_offload"]:
if "offload_ratio" in self.config: if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"] offload_ratio = self.config["offload_ratio"]
...@@ -32,7 +33,10 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -32,7 +33,10 @@ class WanTransformerInfer(BaseTransformerInfer):
offload_ratio = 1 offload_ratio = 1
offload_granularity = self.config.get("offload_granularity", "block") offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block": if offload_granularity == "block":
self.infer_func = self._infer_with_offload if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_offload
else:
self.infer_func = self._infer_with_lazy_offload
elif offload_granularity == "phase": elif offload_granularity == "phase":
if not self.config.get("lazy_load", False): if not self.config.get("lazy_load", False):
self.infer_func = self._infer_with_phases_offload self.infer_func = self._infer_with_phases_offload
...@@ -52,6 +56,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -52,6 +56,7 @@ class WanTransformerInfer(BaseTransformerInfer):
phases_num=self.phases_num, phases_num=self.phases_num,
num_disk_workers=self.config.get("num_disk_workers", 2), num_disk_workers=self.config.get("num_disk_workers", 2),
max_memory=self.config.get("max_memory", 2), max_memory=self.config.get("max_memory", 2),
offload_gra=offload_granularity,
) )
else: else:
self.infer_func = self._infer_without_offload self.infer_func = self._infer_without_offload
...@@ -68,10 +73,10 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -68,10 +73,10 @@ class WanTransformerInfer(BaseTransformerInfer):
return cu_seqlens_q, cu_seqlens_k return cu_seqlens_q, cu_seqlens_k
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks) return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(self.blocks_num): for block_idx in range(self.blocks_num):
if block_idx == 0: if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0] self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
...@@ -96,7 +101,44 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -96,7 +101,44 @@ class WanTransformerInfer(BaseTransformerInfer):
return x return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(self.blocks_num):
if block_idx == 0:
block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
block.to_cuda()
self.weights_stream_mgr.active_weights[0] = (block_idx, block)
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(
self.weights_stream_mgr.active_weights[0][1],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
self.weights_stream_mgr.swap_weights()
if block_idx == self.blocks_num - 1:
self.weights_stream_mgr.pin_memory_buffer.pop_front()
self.weights_stream_mgr._async_prefetch_block(weights)
if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
for block_idx in range(weights.blocks_num): for block_idx in range(weights.blocks_num):
for phase_idx in range(self.phases_num): for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0: if block_idx == 0 and phase_idx == 0:
...@@ -133,11 +175,18 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -133,11 +175,18 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.swap_phases() self.weights_stream_mgr.swap_phases()
torch.cuda.empty_cache() if self.clean_cuda_cache:
del attn_out, y_out, y
torch.cuda.empty_cache()
if self.clean_cuda_cache:
del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache()
return x return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None): def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights) self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(weights.blocks_num): for block_idx in range(weights.blocks_num):
...@@ -198,22 +247,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -198,22 +247,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x return x
def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None): def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
if rotary_emb is None:
return None
self.use_real = False
rope_t_dim = 44
if self.use_real:
freqs_cos, freqs_sin = rotary_emb
freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
return freqs_cos, freqs_sin
else:
freqs_cis = rotary_emb
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num): for block_idx in range(self.blocks_num):
x = self.infer_block( x = self.infer_block(
weights.blocks[block_idx], weights.blocks[block_idx],
...@@ -225,12 +259,6 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -225,12 +259,6 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs, freqs,
context, context,
) )
if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
...@@ -290,23 +318,14 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -290,23 +318,14 @@ class WanTransformerInfer(BaseTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d) v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention: if not self.parallel_attention:
if self.config.get("audio_sr", False): freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else: else:
if self.config.get("audio_sr", False): freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i) q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i) k = self.apply_rotary_emb_func(k, freqs_i)
k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0)) cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
if self.clean_cuda_cache: if self.clean_cuda_cache:
del freqs_i, norm1_out, norm1_weight, norm1_bias del freqs_i, norm1_out, norm1_weight, norm1_bias
...@@ -322,7 +341,6 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -322,7 +341,6 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_q=q.size(0), max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0), max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"], model_cls=self.config["model_cls"],
mask_map=self.mask_map,
) )
else: else:
attn_out = self.parallel_attention( attn_out = self.parallel_attention(
...@@ -388,6 +406,7 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -388,6 +406,7 @@ class WanTransformerInfer(BaseTransformerInfer):
q, q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
) )
img_attn_out = weights.cross_attn_2.apply( img_attn_out = weights.cross_attn_2.apply(
q=q, q=q,
k=k_img, k=k_img,
...@@ -452,4 +471,4 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -452,4 +471,4 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.clean_cuda_cache: if self.clean_cuda_cache:
del y, c_gate_msa del y, c_gate_msa
torch.cuda.empty_cache() torch.cuda.empty_cache()
return x return x
\ No newline at end of file
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
...@@ -20,45 +19,6 @@ def compute_freqs(c, grid_sizes, freqs): ...@@ -20,45 +19,6 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i return freqs_i
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1 ##for r2v add 1 channel
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
return freqs_i
def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0): def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0].tolist()
...@@ -115,7 +75,7 @@ def apply_rotary_emb(x, freqs_i): ...@@ -115,7 +75,7 @@ def apply_rotary_emb(x, freqs_i):
return x_i.to(torch.bfloat16) return x_i.to(torch.bfloat16)
def apply_rotary_emb_chunk(x, freqs_i, chunk_size=100, remaining_chunk_size=100): def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
n = x.size(1) n = x.size(1)
seq_len = freqs_i.size(0) seq_len = freqs_i.size(0)
......
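On the rotary_chunk path, apply_rotary_emb_chunk now takes an explicit chunk_size, which the transformer binds via functools.partial (see the diff above). The underlying idea is simply to process the sequence dimension in slices so the rotary temporaries stay small. A generic sketch of that chunking pattern (illustration only, not the library's implementation):

import torch
from functools import partial

def apply_in_chunks(x, freqs_i, fn, chunk_size=100):
    # Apply fn(x_chunk, freqs_chunk) over slices of the sequence dim to bound peak memory.
    out = torch.empty_like(x)
    for start in range(0, x.size(0), chunk_size):
        end = min(start + chunk_size, x.size(0))
        out[start:end] = fn(x[start:end], freqs_i[start:end])
    return out

# Bind the chunk size once, mirroring the partial(...) wiring in WanTransformerInfer.
apply_small_chunks = partial(apply_in_chunks, chunk_size=256)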
import os import os
import torch
from functools import lru_cache from functools import lru_cache
......