"vscode:/vscode.git/clone" did not exist on "72054ca05870174b18a1421479e76a3da5e39782"
Commit 61420957 authored by gushiqiao

Support gradio demo for wan i2v.

parent fe9ccbda
@@ -12,8 +12,6 @@ Set `"mm_config": {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}`
- `mm_type`: Specifies the quantized operator
- `weight_auto_quant: true`: Enables automatic model quantization
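For reference, this is roughly how that block sits in a JSON config file (all other keys omitted; values copied from the setting above):
```json
{
  "mm_config": {
    "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
    "weight_auto_quant": true
  }
}
```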
### Offline Quantization
lightx2v also supports directly loading pre-quantized weights. For offline model quantization, refer to the [documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme.md).
@@ -21,6 +19,9 @@ Configure the [quantization file](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization/wan_i2v_quant_offline.json):
1. Set `dit_quantized_ckpt` to the converted weight path
2. Set `weight_auto_quant` to `false` in `mm_config` (see the sketch below)
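A minimal sketch of the corresponding offline-quantization entries, assuming the converted weights were written to `/path/to/wan2.1_quant_model` (placeholder path):
```json
{
  "dit_quantized_ckpt": "/path/to/wan2.1_quant_model",
  "mm_config": {
    "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
    "weight_auto_quant": false
  }
}
```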
## Quantized Inference
### Automatic Quantization
```shell
bash scripts/run_wan_i2v_quant_auto.sh
```
...
@@ -11,13 +11,12 @@ lightx2v supports automatically quantizing model weights at inference time
Note that **mm_config** must be set in the config file: `"mm_config": {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm", "weight_auto_quant": true}`. **mm_type** specifies the quantized operator to use, and **weight_auto_quant: true** enables automatic conversion to a quantized model.
### Offline Quantization
lightx2v also supports directly loading pre-quantized weights for inference. For offline model quantization, refer to the [documentation](https://github.com/ModelTC/lightx2v/tree/main/tools/convert/readme_zh.md).
Write the converted weight path into `dit_quantized_ckpt` in the [config file](https://github.com/ModelTC/lightx2v/tree/main/configs/quantization/wan_i2v_quant_offline.json), and set `weight_auto_quant` in `mm_config` to `false`.
## Quantized Inference
### Automatic Quantization
```shell
bash scripts/run_wan_i2v_quant_auto.sh
```
...
import os
import gradio as gr
import asyncio
import argparse
import json
import torch
import gc
from easydict import EasyDict
from loguru import logger
from lightx2v.infer import init_runner
from lightx2v.utils.envs import *
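# Persist logs to a rotating file in addition to loguru's default console sink.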
logger.add(
"inference_logs.log",
rotation="100 MB",
encoding="utf-8",
enqueue=True,
backtrace=True,
diagnose=True,
)
SUPPORTED_MODEL = "wan2.1"
TASK = "i2v"
def run_inference(
model_path,
prompt,
negative_prompt,
image_path,
save_video_path,
torch_compile,
infer_steps,
num_frames,
width,
height,
seed,
enable_teacache,
enable_cfg,
cfg_scale,
quant_option,
fps,
use_tiny_vae,
tiny_vae_path,
):
"""Wrapper for wan2.1 I2V inference logic with advanced options"""
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
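# Assemble the lightx2v inference config from the UI inputs; the quantized T5/CLIP
# checkpoints are expected to sit inside the model folder.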
config = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
"target_height": height,
"target_width": width,
"attention_type": "sage_attn2",
"seed": seed,
"enable_cfg": enable_cfg,
"sample_guide_scale": cfg_scale,
"sample_shift": 5,
"cpu_offload": True,
"offload_granularity": "phase",
"t5_offload_granularity": "block",
"dit_quantized_ckpt": model_path,
"mm_config": {
"mm_type": ("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm" if quant_option == "fp8" else "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"),
},
"fps": fps,
"feature_caching": "Tea" if enable_teacache else "NoCaching",
"coefficients": [
[
2.57151496e05,
-3.54229917e04,
1.40286849e03,
-1.35890334e01,
1.32517977e-01,
],
[
-3.02331670e02,
2.23948934e02,
-5.25463970e01,
5.87348440e00,
-2.01973289e-01,
],
],
"use_ret_steps": True,
"teacache_thresh": 0.26,
"t5_quantized": True,
"t5_quantized_ckpt": os.path.join(model_path, "models_t5_umt5-xxl-enc-int8.pth"),
"t5_quant_scheme": "int8",
"clip_quantized": True,
"clip_quantized_ckpt": os.path.join(model_path, "clip-int8.pth"),
"clip_quant_scheme": "int8",
"use_tiling_vae": True,
"tiny_vae": use_tiny_vae,
"tiny_vae_path": tiny_vae_path if use_tiny_vae else None,
"lazy_load": True,
"do_mm_calib": False,
"parallel_attn_type": None,
"parallel_vae": False,
"max_area": False,
"vae_stride": (4, 8, 8),
"patch_size": (1, 2, 2),
"teacache_thresh": 0.26,
"use_bfloat16": True,
"lora_path": None,
"strength_model": 1.0,
"use_prompt_enhancer": False,
"text_len": 512,
}
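# Mirror the CLI-style arguments expected by the lightx2v runner and merge them into the config.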
args = argparse.Namespace(
model_cls=SUPPORTED_MODEL,
task=TASK,
model_path=model_path,
prompt_enhancer=None,
prompt=prompt,
negative_prompt=negative_prompt,
image_path=image_path,
save_video_path=save_video_path,
)
config.update({k: v for k, v in vars(args).items()})
config = EasyDict(config)
config["mode"] = "infer"
if os.path.exists(os.path.join(model_path, "config.json")):
with open(os.path.join(model_path, "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
logger.info(f"Updated inference config:\n{json.dumps(config, indent=4, ensure_ascii=False)}")
runner = init_runner(config)
asyncio.run(runner.run_pipeline())
del runner
gc.collect()
torch.cuda.empty_cache()
return save_video_path
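# Build the Gradio UI: input controls and advanced options on the left, the generated video on the right.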
with gr.Blocks(
title="Wan2.1 I2V Video Generation",
css="""
.advanced-options { background: #f9f9ff; border-radius: 10px; padding: 15px; }
.output-video { max-height: 650px; }
.warning { color: #ff6b6b; font-weight: bold; }
""",
) as demo:
gr.Markdown("# 🎬 Wan2.1 Image-to-Video (I2V) Generator")
with gr.Row():
with gr.Column(scale=4):
with gr.Group():
gr.Markdown("## 📥 Input Parameters")
with gr.Row():
image_path = gr.Image(
label="Input Image",
type="filepath",
height=300,
interactive=True,
)
model_path = gr.Textbox(
label="Model Path",
placeholder="/your/path/to/wan2.1_quant_model",
info="Local model folder path (in8/fp8 quantization supported)",
)
with gr.Row():
with gr.Column():
prompt = gr.Textbox(
label="Prompt",
lines=3,
placeholder="Describe the video content...",
max_lines=5,
)
with gr.Column():
negative_prompt = gr.Textbox(
label="Negative Prompt",
lines=3,
placeholder="Unwanted content...",
max_lines=5,
)
with gr.Column():
tiny_vae_path = gr.Textbox(
label="Tiny vae path",
lines=3,
placeholder="/your/path/to/tiny_vae.pth",
max_lines=5,
)
save_video_path = gr.Textbox(
label="Output Video Path",
value="./save_results/wan2.1_i2v_output.mp4",
info="Must include .mp4 suffix",
)
with gr.Accordion("⚙️ Advanced Options", open=False):
with gr.Group(elem_classes="advanced-options"):
gr.Markdown("### Performance Settings")
with gr.Row():
torch_compile = gr.Checkbox(
label="Enable Torch Compile",
value=False,
info="Use torch.compile for faster inference",
)
quant_option = gr.Radio(
label="Quantization Method",
choices=["fp8", "int8"],
value="fp8",
info="Select quantization method for model",
)
infer_steps = gr.Slider(
label="Inference Steps",
minimum=1,
maximum=100,
step=1,
value=20,
info="Infer steps for video generation",
)
enable_teacache = gr.Checkbox(
label="Enable Teacache",
value=False,
info="Teacache for caching features during inference",
)
enable_cfg = gr.Checkbox(
label="Enable CFG",
value=False,
info="Classifier-Free Guidance for prompt strength control",
)
use_tiny_vae = gr.Checkbox(
label="Use Tiny VAE",
value=False,
info="Tiny VAE for faster inference",
)
cfg_scale = gr.Slider(
label="CFG scale",
minimum=1,
maximum=100,
step=1,
value=5,
info="CFG scale for controlling the strength of the prompt",
)
seed = gr.Slider(
label="Seed",
minimum=-10000000,
maximum=10000000,
step=1,
value=42,
info="Random seed for reproducibility",
)
gr.Markdown("### Video Parameters")
with gr.Row():
fps = gr.Slider(
label="FPS (Frames Per Second)",
minimum=8,
maximum=30,
step=1,
value=16,
info="Higher FPS = smoother video",
)
num_frames = gr.Slider(
label="Number of Frames",
minimum=16,
maximum=120,
step=1,
value=81,
info="More frames = longer video",
)
with gr.Row():
width = gr.Number(
label="Width",
value=832,
precision=0,
minimum=320,
maximum=1920,
info="Output video width",
)
height = gr.Number(
label="Height",
value=480,
precision=0,
minimum=240,
maximum=1080,
info="Output video height",
)
gr.Markdown(
"""
<div class="warning">
⚠️ Note: Changing resolution may affect video quality and performance
</div>
"""
)
infer_btn = gr.Button("Generate Video", variant="primary", size="lg")
with gr.Column(scale=6):
gr.Markdown("## 📤 Generated Video")
output_video = gr.Video(
label="Result",
height=624,
width=360,
autoplay=True,
elem_classes=["output-video"],
)
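# Wire the Generate button to run_inference; the returned video path is shown in the output player.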
infer_btn.click(
fn=run_inference,
inputs=[
model_path,
prompt,
negative_prompt,
image_path,
save_video_path,
torch_compile,
infer_steps,
num_frames,
width,
height,
seed,
enable_teacache,
enable_cfg,
cfg_scale,
quant_option,
fps,
use_tiny_vae,
tiny_vae_path,
],
outputs=output_video,
)
if __name__ == "__main__":
demo.launch(share=False, server_port=7860, server_name="0.0.0.0")
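A usage sketch for launching the demo, assuming the script above is saved as `gradio_demo_wan_i2v.py` (hypothetical filename; use whatever path the repo actually ships):
```shell
# Hypothetical script name; the demo binds to 0.0.0.0:7860 per demo.launch() above.
python gradio_demo_wan_i2v.py
# Then open http://localhost:7860 in a browser and fill in the model path, image, and prompt.
```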
@@ -4,8 +4,13 @@ import torch.nn.functional as F
# from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.ring_comm import RingComm

try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward
except ImportError:
    flash_attn = None
    _flash_attn_forward = None

from typing import Optional, Tuple
...
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
from lightx2v.utils.envs import *
from loguru import logger

try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

try:
    import q8_kernels.functional as Q8F
except ImportError:
...
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

class RMSWeightTemplate(metaclass=ABCMeta):
@@ -83,8 +87,13 @@ class RMSWeightSgl(RMSWeight):
        super().__init__(weight_name, lazy_load, lazy_load_file, eps)

    def apply(self, input_tensor):
        if sgl_kernel is None:
            # sgl_kernel is not available, fall back to the default implementation
            input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
            input_tensor = input_tensor * self.weight
        else:
            input_tensor = input_tensor.contiguous()
            orig_shape = input_tensor.shape
            input_tensor = input_tensor.view(-1, orig_shape[-1])
            input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
        return input_tensor
@@ -31,9 +31,10 @@ def init_runner(config):
    if CHECK_ENABLE_GRAPH_MODE():
        default_runner = RUNNER_REGISTER[config.model_cls](config)
        runner = GraphRunner(default_runner)
        runner.runner.init_modules()
    else:
        runner = RUNNER_REGISTER[config.model_cls](config)
        runner.init_modules()
    return runner
...
import torch
import sgl_kernel
import torch.cuda.amp as amp
import torch.distributed as dist
...
@@ -26,6 +26,7 @@ class DefaultRunner:
            logger.warning("No prompt enhancer server available, disable prompt enhancer.")

    def init_modules(self):
        logger.info("Initializing runner modules...")
        self.set_init_device()
        if self.config["mode"] == "split_server":
            self.tensor_transporter = TensorTransporter()
@@ -164,6 +165,7 @@ class DefaultRunner:
        self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images
...
@@ -77,6 +77,8 @@ def cache_video(
            for frame in tensor.numpy():
                writer.append_data(frame)
            writer.close()
            del tensor
            torch.cuda.empty_cache()
            return cache_file
        except Exception as e:
            error = e
...
@@ -4,6 +4,7 @@ import gc
import glob
import json
import argparse
import shutil
import torch
from safetensors import safe_open, torch as st
from loguru import logger
@@ -516,6 +517,31 @@ def convert_weights(args):
        json.dump(index, f, indent=2)
    logger.info(f"Index file written to: {index_path}")

    if os.path.isdir(args.source):
        copy_non_weight_files(args.source, args.output)


def copy_non_weight_files(source_dir, target_dir):
    ignore_extensions = [".pth", ".pt", ".safetensors"]
    logger.info("Start copying non-weight files and subdirectories...")
    for item in tqdm(os.listdir(source_dir), desc="Copying non-weight files"):
        source_item = os.path.join(source_dir, item)
        target_item = os.path.join(target_dir, item)
        try:
            if os.path.isdir(source_item):
                os.makedirs(target_item, exist_ok=True)
                copy_non_weight_files(source_item, target_item)
            elif os.path.isfile(source_item) and not any(source_item.endswith(ext) for ext in ignore_extensions):
                shutil.copy2(source_item, target_item)
                logger.debug(f"Copied file: {source_item} -> {target_item}")
        except Exception as e:
            logger.error(f"Error while copying {source_item}: {str(e)}")
    logger.info("Non-weight files and subdirectories copied")


def main():
    parser = argparse.ArgumentParser(description="Model weight format converter")
...