Unverified Commit 9b6fd118 authored by Musisoul, committed by GitHub

[Feat] entrypoint like diffusers (#475)



### Single GPU
```bash
python examples/simple_launch.py
```
```python
# examples/simple_launch.py
from lightx2v import LightGenerator

generator = LightGenerator(
    model_path="/path/to/Wan2.1-T2V-1.3B",
    model_cls="wan2.1",
    task="t2v",
)

video_path = generator.generate(
    prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    negative_prompt="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=42,
    save_result_path="output.mp4",
)
```
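
The core of this PR is the new `LightX2VPipeline` class (`lightx2v/pipeline.py` in the diff below), which also exposes optional `enable_*` hooks for offloading, quantization, caching, compilation, LoRA, and parallelism. A minimal sketch of combining a couple of them, assuming the same Wan2.1-T2V-1.3B checkpoint; the hook arguments are illustrative and not taken from this PR:
```python
# Hypothetical sketch using the LightX2VPipeline class added in this PR;
# argument values are illustrative only.
from lightx2v import LightX2VPipeline

pipe = LightX2VPipeline(
    task="t2v",
    model_path="/path/to/Wan2.1-T2V-1.3B",
    model_cls="wan2.1",
)
# Offload DiT blocks and the text encoder to CPU to reduce peak VRAM.
pipe.enable_offload(cpu_offload=True, offload_granularity="block", text_encoder_offload=True)
# Pre-compile the default set of resolution buckets (see enable_compile below).
pipe.enable_compile()
pipe.create_generator(infer_steps=50, num_frames=81, height=480, width=832)
pipe.generate(
    seed=42,
    prompt="A cat walks on the grass, realistic style.",
    negative_prompt="",
    save_result_path="output.mp4",
)
```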
### Multi-GPU
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
torchrun --nproc_per_node=8 examples/multi_launch.py
```
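`examples/multi_launch.py` itself is not included in this diff; below is a plausible sketch, assuming it drives the pipeline class from this PR via its `enable_parallel()` hook. The 8-way sequence-parallel split mirrors the `torchrun` command above and is an assumption, not taken from the PR:
```python
# Hypothetical examples/multi_launch.py, launched with torchrun as shown above.
from lightx2v import LightX2VPipeline

pipe = LightX2VPipeline(
    task="t2v",
    model_path="/path/to/Wan2.1-T2V-1.3B",
    model_cls="wan2.1",
)
# enable_parallel() initializes the NCCL process group and binds each rank to its GPU.
pipe.enable_parallel(seq_p_size=8, seq_p_attn_type="ulysses")
pipe.create_generator(infer_steps=50, num_frames=81, height=480, width=832)
pipe.generate(
    seed=42,
    prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    negative_prompt="",
    save_result_path="output_multi.mp4",
)
```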

---------
Co-authored-by: gushiqiao <975033167@qq.com>
parent d996a81c
@@ -3,6 +3,7 @@ __author__ = "LightX2V Contributors"
__license__ = "Apache 2.0"
from lightx2v import common, deploy, models, utils
from lightx2v.pipeline import LightX2VPipeline
__all__ = [
    "__version__",
@@ -12,4 +13,5 @@ __all__ = [
    "common",
    "deploy",
    "utils",
    "LightX2VPipeline",
]
@@ -54,7 +54,7 @@ class DefaultRunner(BaseRunner):
        elif self.config["task"] == "s2v":
            self.run_input_encoder = self._run_input_encoder_local_s2v
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False):
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))
@@ -295,7 +295,7 @@ class DefaultRunner(BaseRunner):
    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self):
        self.init_run()
        if self.config.get("compile", False):
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
@@ -54,7 +54,7 @@ class MultiDistillModelStruct(MultiModelStruct):
    def get_current_model_index(self):
        if self.scheduler.step_index < self.boundary_step_index:
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
            # self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                if self.cur_model_index == -1:
                    self.to_cuda(model_index=0)
@@ -64,7 +64,7 @@ class MultiDistillModelStruct(MultiModelStruct):
            self.cur_model_index = 0
        else:
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
            # self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                if self.cur_model_index == -1:
                    self.to_cuda(model_index=1)
@@ -103,7 +103,7 @@ class WanRunner(DefaultRunner):
            t5_device = torch.device("cpu")
        else:
            t5_device = torch.device(self.run_device)
        tokenizer_path = os.path.join(self.config["model_path"], "google/umt5-xxl")
        # quant_config
        t5_quantized = self.config.get("t5_quantized", False)
        if t5_quantized:
@@ -113,13 +113,11 @@ class WanRunner(DefaultRunner):
            t5_model_name = f"models_t5_umt5-xxl-enc-{tmp_t5_quant_scheme}.pth"
            t5_quantized_ckpt = find_torch_model_path(self.config, "t5_quantized_ckpt", t5_model_name)
            t5_original_ckpt = None
            tokenizer_path = os.path.join(os.path.dirname(t5_quantized_ckpt), "google/umt5-xxl")
        else:
            t5_quant_scheme = None
            t5_quantized_ckpt = None
            t5_model_name = "models_t5_umt5-xxl-enc-bf16.pth"
            t5_original_ckpt = find_torch_model_path(self.config, "t5_original_ckpt", t5_model_name)
            tokenizer_path = os.path.join(os.path.dirname(t5_original_ckpt), "google/umt5-xxl")
        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
os.environ["PROFILING_DEBUG_LEVEL"] = "2"
import json
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401
from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
from lightx2v.utils.utils import seed_all
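
# dict_like() makes the decorated class usable like a dict (obj["key"], "key" in obj,
# obj.get, obj.update) on top of plain attribute access, presumably so downstream helpers
# such as set_config() and set_input_info() can read the pipeline object as a config mapping.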
def dict_like(cls):
    cls.__getitem__ = lambda self, key: getattr(self, key)
    cls.__setitem__ = lambda self, key, value: setattr(self, key, value)
    cls.__delitem__ = lambda self, key: delattr(self, key)
    cls.__contains__ = lambda self, key: hasattr(self, key)

    def update(self, *args, **kwargs):
        for arg in args:
            if isinstance(arg, dict):
                items = arg.items()
            else:
                items = arg
            for k, v in items:
                setattr(self, k, v)
        for k, v in kwargs.items():
            setattr(self, k, v)

    def get(self, key, default=None):
        return getattr(self, key, default)

    cls.get = get
    cls.update = update
    return cls

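# LightX2VPipeline is the diffusers-style entrypoint added by this PR: construct it with
# task/model_path/model_cls, optionally call the enable_*() hooks below, then
# create_generator() to build the runner and generate() to produce a video.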
@dict_like
class LightX2VPipeline:
    def __init__(
        self,
        task,
        model_path,
        model_cls,
        image_path=None,
        last_frame_path=None,
        audio_path=None,
        sf_model_path=None,
        src_ref_images=None,
        src_video=None,
        src_mask=None,
        dit_original_ckpt=None,
        low_noise_original_ckpt=None,
        high_noise_original_ckpt=None,
        transformer_model_name=None,
    ):
        self.task = task
        self.model_path = model_path
        self.model_cls = model_cls
        self.sf_model_path = sf_model_path
        self.image_path = image_path
        self.last_frame_path = last_frame_path
        self.audio_path = audio_path
        self.src_ref_images = src_ref_images
        self.src_video = src_video
        self.src_mask = src_mask
        self.dit_original_ckpt = dit_original_ckpt
        self.low_noise_original_ckpt = low_noise_original_ckpt
        self.high_noise_original_ckpt = high_noise_original_ckpt
        self.transformer_model_name = transformer_model_name
        if self.model_cls in [
            "wan2.1",
            "wan2.1_distill",
            "wan2.1_vace",
            "wan2.1_sf",
            "wan2.1_sf_mtxg2",
            "seko_talk",
            "wan2.2_moe",
            "wan2.2_moe_audio",
            "wan2.2_audio",
            "wan2.2_moe_distill",
            "wan2.2_animate",
        ]:
            self.vae_stride = (4, 8, 8)
            if self.model_cls.startswith("wan2.2_moe"):
                self.use_image_encoder = False
        elif self.model_cls in ["wan2.2"]:
            self.vae_stride = (4, 16, 16)
            self.num_channels_latents = 48
        elif self.model_cls in ["hunyuan_video_1.5"]:
            self.vae_stride = (4, 16, 16)
            self.num_channels_latents = 32

    def create_generator(
        self,
        attn_mode="flash_attn2",
        infer_steps=50,
        num_frames=81,
        height=480,
        width=832,
        guidance_scale=5.0,
        sample_shift=5.0,
        fps=16,
        aspect_ratio="16:9",
        boundary=0.900,
        boundary_step_index=2,
        denoising_step_list=[1000, 750, 500, 250],
        config_json=None,
    ):
        if config_json is not None:
            self.set_infer_config_json(config_json)
        else:
            self.set_infer_config(
                attn_mode,
                infer_steps,
                num_frames,
                height,
                width,
                guidance_scale,
                sample_shift,
                fps,
                aspect_ratio,
                boundary,
                boundary_step_index,
                denoising_step_list,
            )
        config = set_config(self)
        print_config(config)
        self.runner = self._init_runner(config)
        logger.info(f"Initializing {self.model_cls} runner for {self.task} task...")
        logger.info(f"Model path: {self.model_path}")
        logger.info("LightGenerator initialized successfully!")

    def set_infer_config(
        self,
        attn_mode,
        infer_steps,
        num_frames,
        height,
        width,
        guidance_scale,
        sample_shift,
        fps,
        aspect_ratio,
        boundary,
        boundary_step_index,
        denoising_step_list,
    ):
        self.infer_steps = infer_steps
        self.target_width = width
        self.target_height = height
        self.target_video_length = num_frames
        self.sample_guide_scale = guidance_scale
        self.sample_shift = sample_shift
        if self.sample_guide_scale == 1:
            self.enable_cfg = False
        else:
            self.enable_cfg = True
        self.fps = fps
        self.aspect_ratio = aspect_ratio
        self.boundary = boundary
        self.boundary_step_index = boundary_step_index
        self.denoising_step_list = denoising_step_list
        if self.model_cls.startswith("wan"):
            self.self_attn_1_type = attn_mode
            self.cross_attn_1_type = attn_mode
            self.cross_attn_2_type = attn_mode
        elif self.model_cls in ["hunyuan_video_1.5"]:
            self.attn_type = attn_mode

    def set_infer_config_json(self, config_json):
        logger.info(f"Loading infer config from {config_json}")
        with open(config_json, "r") as f:
            config_json = json.load(f)
        self.update(config_json)

    def enable_lightvae(
        self,
        use_lightvae=False,
        use_tae=False,
        vae_path=None,
        tae_path=None,
    ):
        self.use_lightvae = use_lightvae
        self.use_tae = use_tae
        self.vae_path = vae_path
        self.tae_path = tae_path
        if self.use_tae and self.model_cls.startswith("wan") and "lighttae" in tae_path:
            self.need_scaled = True

    def enable_quantize(
        self,
        dit_quantized=False,
        text_encoder_quantized=False,
        image_encoder_quantized=False,
        dit_quantized_ckpt=None,
        low_noise_quantized_ckpt=None,
        high_noise_quantized_ckpt=None,
        text_encoder_quantized_ckpt=False,
        image_encoder_quantized_ckpt=False,
        quant_scheme="fp8-sgl",
    ):
        self.dit_quantized = dit_quantized
        self.dit_quant_scheme = quant_scheme
        self.dit_quantized_ckpt = dit_quantized_ckpt
        self.low_noise_quantized_ckpt = low_noise_quantized_ckpt
        self.high_noise_quantized_ckpt = high_noise_quantized_ckpt
        if self.model_cls.startswith("wan"):
            self.t5_quant_scheme = quant_scheme
            self.t5_quantized = text_encoder_quantized
            self.t5_quantized_ckpt = text_encoder_quantized_ckpt
            self.clip_quant_scheme = quant_scheme
            self.clip_quantized = image_encoder_quantized
            self.clip_quantized_ckpt = image_encoder_quantized_ckpt
        elif self.model_cls in ["hunyuan_video_1.5"]:
            self.qwen25vl_quantized = text_encoder_quantized
            self.qwen25vl_quantized_ckpt = text_encoder_quantized_ckpt
            self.qwen25vl_quant_scheme = quant_scheme

    def enable_offload(
        self,
        cpu_offload=False,
        offload_granularity="block",
        text_encoder_offload=False,
        image_encoder_offload=False,
        vae_offload=False,
    ):
        self.cpu_offload = cpu_offload
        self.offload_granularity = offload_granularity
        self.vae_offload = vae_offload
        if self.model_cls in [
            "wan2.1",
            "wan2.1_distill",
            "wan2.1_vace",
            "wan2.1_sf",
            "wan2.1_sf_mtxg2",
            "seko_talk",
            "wan2.2_moe",
            "wan2.2",
            "wan2.2_moe_audio",
            "wan2.2_audio",
            "wan2.2_moe_distill",
            "wan2.2_animate",
        ]:
            self.t5_cpu_offload = text_encoder_offload
            self.clip_encoder_offload = image_encoder_offload
        elif self.model_cls in ["hunyuan_video_1.5"]:
            self.qwen25vl_cpu_offload = text_encoder_offload
            self.siglip_cpu_offload = image_encoder_offload
            self.byt5_cpu_offload = image_encoder_offload

    def enable_compile(
        self,
    ):
        self.compile = True
        self.compile_shapes = [
            [480, 832],
            [544, 960],
            [720, 1280],
            [832, 480],
            [960, 544],
            [1280, 720],
            [480, 480],
            [576, 576],
            [704, 704],
            [960, 960],
        ]

    def enable_lora(self, lora_configs):
        self.lora_configs = lora_configs

    def enable_cache(
        self,
        cache_method="Tea",
        coefficients=[],
        teacache_thresh=0.15,
        use_ret_steps=False,
        magcache_calibration=False,
        magcache_K=6,
        magcache_thresh=0.24,
        magcache_retention_ratio=0.2,
        magcache_ratios=[],
    ):
        self.feature_caching = cache_method
        if cache_method == "Tea":
            self.coefficients = coefficients
            self.teacache_thresh = teacache_thresh
            self.use_ret_steps = use_ret_steps
        elif cache_method == "Mag":
            self.magcache_calibration = magcache_calibration
            self.magcache_K = magcache_K
            self.magcache_thresh = magcache_thresh
            self.magcache_retention_ratio = magcache_retention_ratio
            self.magcache_ratios = magcache_ratios

    def enable_parallel(self, cfg_p_size=1, seq_p_size=1, seq_p_attn_type="ulysses"):
        self._init_parallel()
        self.parallel = {
            "cfg_p_size": cfg_p_size,
            "seq_p_size": seq_p_size,
            "seq_p_attn_type": seq_p_attn_type,
        }
        set_parallel_config(self)

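    # generate() runs a single prompt end to end: it seeds all RNGs, builds the input_info
    # from the attributes set above, and executes the runner pipeline; the resulting video
    # is written to save_result_path.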
    @torch.no_grad()
    def generate(
        self,
        seed,
        prompt,
        negative_prompt,
        save_result_path,
        return_result_tensor=False,
    ):
        # Run inference (following LightX2V pattern)
        self.seed = seed
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.save_result_path = save_result_path
        self.return_result_tensor = return_result_tensor
        seed_all(self.seed)
        input_info = set_input_info(self)
        self.runner.run_pipeline(input_info)
        logger.info("Video generated successfully!")
        logger.info(f"Video Saved in {save_result_path}")

    def _init_runner(self, config):
        torch.set_grad_enabled(False)
        runner = RUNNER_REGISTER[config["model_cls"]](config)
        runner.init_modules()
        return runner

    def _init_parallel(self):
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
@@ -47,7 +47,6 @@ dependencies = [
    "imageio-ffmpeg",
    "einops",
    "loguru",
    "sgl-kernel",
    "qtorch",
    "ftfy",
    "gradio",