Unverified Commit 04812de2 authored by Yang Yong (雍洋), committed by GitHub

Refactor Config System (#338)

parent 6a658f42
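This refactor applies two recurring patterns across the hunks below: config reads switch from attribute access (self.config.key) to dict-style access (self.config["key"]), and per-request state (prompt, seed, media paths, save path, latent shape) moves off the shared config onto an input_info dataclass that is passed into run_pipeline. The following is a minimal sketch of the resulting call pattern; the dataclass name, defaults, and field list are assumptions inferred from the attributes used in this diff, not the actual definitions:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ExampleInputInfo:  # hypothetical name; the real per-task input_info dataclasses live elsewhere in the repo
    seed: int = 42
    prompt: str = ""
    prompt_enhanced: str = ""
    negative_prompt: str = ""
    image_path: str = ""
    save_result_path: str = ""
    return_result_tensor: bool = False
    latent_shape: Optional[List[int]] = None  # filled in by the input encoders before the DiT runs

def example_usage(runner):
    # before: lat_h = runner.config.lat_h            (per-request values mutated the shared config)
    # after:  steps = runner.config["infer_steps"]   (config is locked and read dict-style)
    info = ExampleInputInfo(prompt="a cat playing piano", seed=1234, save_result_path="./output.mp4")
    return runner.run_pipeline(info)                 # per-request state travels in input_info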
......@@ -14,8 +14,8 @@ class WanSFModel(WanModel):
self.to_cuda()
def _load_ckpt(self, unified_dtype, sensitive_layer):
sf_confg = self.config.sf_config
file_path = os.path.join(self.config.sf_model_path, f"checkpoints/self_forcing_{sf_confg.sf_type}.pt")
sf_confg = self.config["sf_config"]
file_path = os.path.join(self.config["sf_model_path"], f"checkpoints/self_forcing_{sf_confg['sf_type']}.pt")
_weight_dict = torch.load(file_path)["generator_ema"]
weight_dict = {}
for k, v in _weight_dict.items():
......
......@@ -40,7 +40,7 @@ class WanPreWeights(WeightModule):
MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
)
if config.task in ["i2v", "flf2v", "animate"] and config.get("use_image_encoder", True):
if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True):
self.add_module(
"proj_0",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
......@@ -58,7 +58,7 @@ class WanPreWeights(WeightModule):
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
)
if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
if config["model_cls"] == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
self.add_module(
"cfg_cond_proj_1",
MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"),
......@@ -68,12 +68,12 @@ class WanPreWeights(WeightModule):
MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
)
if config.task == "flf2v" and config.get("use_image_encoder", True):
if config["task"] == "flf2v" and config.get("use_image_encoder", True):
self.add_module(
"emb_pos",
TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
)
if config.task == "animate":
if config["task"] == "animate":
self.add_module(
"pose_patch_embedding",
CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size),
......
......@@ -60,7 +60,7 @@ class WanTransformerAttentionBlock(WeightModule):
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
......@@ -197,7 +197,7 @@ class WanSelfAttention(WeightModule):
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
if self.config["seq_parallel"]:
self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
if self.quant_method in ["advanced_ptq"]:
self.add_module(
......@@ -296,7 +296,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
......
......@@ -3,8 +3,6 @@ from abc import ABC
import torch
import torch.distributed as dist
from lightx2v.utils.utils import save_videos_grid
class BaseRunner(ABC):
"""Abstract base class for all Runners
......@@ -15,6 +13,7 @@ class BaseRunner(ABC):
def __init__(self, config):
self.config = config
self.vae_encoder_need_img_original = False
self.input_info = None
def load_transformer(self):
"""Load transformer model
......@@ -100,26 +99,6 @@ class BaseRunner(ABC):
"""Initialize scheduler"""
pass
def set_target_shape(self):
"""Set target shape
Subclasses can override this method to provide specific implementation
Returns:
Dictionary containing target shape information
"""
return {}
def save_video_func(self, images):
"""Save video implementation
Subclasses can override this method to customize save logic
Args:
images: Image sequence to save
"""
save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
def load_vae_decoder(self):
"""Load VAE decoder
......@@ -146,7 +125,7 @@ class BaseRunner(ABC):
pass
def end_run_segment(self, segment_idx=None):
pass
self.gen_video_final = self.gen_video
def end_run(self):
pass
......
import imageio
import numpy as np
from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.runners.default_runner import DefaultRunner
......@@ -72,9 +69,3 @@ class CogvideoxRunner(DefaultRunner):
)
ret["target_shape"] = self.config.target_shape
return ret
def save_video_func(self, images):
with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
for pil_image in images:
frame_np = np.array(pil_image, dtype=np.uint8)
writer.append_data(frame_np)
......@@ -22,13 +22,13 @@ class DefaultRunner(BaseRunner):
super().__init__(config)
self.has_prompt_enhancer = False
self.progress_callback = None
if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
self.has_prompt_enhancer = True
if not self.check_sub_servers("prompt_enhancer"):
self.has_prompt_enhancer = False
logger.warning("No prompt enhancer server available, disable prompt enhancer.")
if not self.has_prompt_enhancer:
self.config.use_prompt_enhancer = False
self.config["use_prompt_enhancer"] = False
self.set_init_device()
self.init_scheduler()
......@@ -49,12 +49,15 @@ class DefaultRunner(BaseRunner):
self.run_input_encoder = self._run_input_encoder_local_vace
elif self.config["task"] == "animate":
self.run_input_encoder = self._run_input_encoder_local_animate
elif self.config["task"] == "s2v":
self.run_input_encoder = self._run_input_encoder_local_s2v
self.config.lock() # lock config to avoid modification
if self.config.get("compile", False):
logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
self.model.compile(self.config.get("compile_shapes", []))
def set_init_device(self):
if self.config.cpu_offload:
if self.config["cpu_offload"]:
self.init_device = torch.device("cpu")
else:
self.init_device = torch.device("cuda")
......@@ -96,21 +99,23 @@ class DefaultRunner(BaseRunner):
return len(available_servers) > 0
def set_inputs(self, inputs):
self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = False
if self.has_prompt_enhancer:
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from clinet side.
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "")
self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
self.config["audio_path"] = inputs.get("audio_path", "") # for wan-audio
self.config["video_duration"] = inputs.get("video_duration", 5) # for wan-audio
# self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
# self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
self.input_info.seed = inputs.get("seed", 42)
self.input_info.prompt = inputs.get("prompt", "")
if self.config["use_prompt_enhancer"]:
self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
self.input_info.negative_prompt = inputs.get("negative_prompt", "")
if "image_path" in self.input_info.__dataclass_fields__:
self.input_info.image_path = inputs.get("image_path", "")
if "audio_path" in self.input_info.__dataclass_fields__:
self.input_info.audio_path = inputs.get("audio_path", "")
if "video_path" in self.input_info.__dataclass_fields__:
self.input_info.video_path = inputs.get("video_path", "")
self.input_info.save_result_path = inputs.get("save_result_path", "")
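# Hypothetical server-side call (the keys shown are illustrative, not taken from this diff):
#   runner.set_inputs({"prompt": "a cat", "seed": 7, "image_path": "ref.png", "save_result_path": "out.mp4"})
# Path-like fields are copied only when the active task's input_info dataclass declares them
# (checked via __dataclass_fields__ above); keys the dataclass does not declare are never read.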
def set_config(self, config_modify):
logger.info(f"modify config: {config_modify}")
with self.config.temporarily_unlocked():
self.config.update(config_modify)
def set_progress_callback(self, callback):
self.progress_callback = callback
......@@ -146,6 +151,7 @@ class DefaultRunner(BaseRunner):
def end_run(self):
self.model.scheduler.clear()
del self.inputs
self.input_info = None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
self.model.transformer_infer.weights_stream_mgr.clear()
......@@ -162,23 +168,24 @@ class DefaultRunner(BaseRunner):
else:
img_ori = Image.open(img_path).convert("RGB")
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
self.input_info.original_size = img_ori.size
return img, img_ori
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img, img_ori = self.read_image_input(self.config["image_path"])
img, img_ori = self.read_image_input(self.input_info.image_path)
clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
text_encoder_output = self.run_text_encoder(prompt, img)
vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_t2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, None)
self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"]) # Important: set latent_shape in input_info
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return {
......@@ -188,22 +195,21 @@ class DefaultRunner(BaseRunner):
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_flf2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
first_frame, _ = self.read_image_input(self.config["image_path"])
last_frame, _ = self.read_image_input(self.config["last_frame_path"])
first_frame, _ = self.read_image_input(self.input_info.image_path)
last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
text_encoder_output = self.run_text_encoder(prompt, first_frame)
vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_vace(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
src_video = self.config.get("src_video", None)
src_mask = self.config.get("src_mask", None)
src_ref_images = self.config.get("src_ref_images", None)
src_video = self.input_info.src_video
src_mask = self.input_info.src_mask
src_ref_images = self.input_info.src_ref_images
src_video, src_mask, src_ref_images = self.prepare_source(
[src_video],
[src_mask],
......@@ -212,34 +218,38 @@ class DefaultRunner(BaseRunner):
)
self.src_ref_images = src_ref_images
vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
text_encoder_output = self.run_text_encoder(prompt)
vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)
@ProfilingContext4DebugL2("Run Text Encoder")
def _run_input_encoder_local_animate(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, None)
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(None, None, text_encoder_output, None)
def _run_input_encoder_local_s2v(self):
pass
def init_run(self):
self.set_target_shape()
self.gen_video_final = None
self.get_video_segment_num()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
self.inputs["image_encoder_output"]["vae_encoder_out"] = None
@ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None):
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile()
self.model.select_graph_for_compile(self.input_info)
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
......@@ -252,7 +262,9 @@ class DefaultRunner(BaseRunner):
self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing
self.end_run_segment(segment_idx)
gen_video_final = self.process_images_after_vae_decoder()
self.end_run()
return {"video": gen_video_final}
@ProfilingContext4DebugL1("Run VAE Decoder")
def run_vae_decoder(self, latents):
......@@ -281,20 +293,22 @@ class DefaultRunner(BaseRunner):
logger.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def process_images_after_vae_decoder(self, save_video=True):
self.gen_video = vae_to_comfyui_image(self.gen_video)
def process_images_after_vae_decoder(self):
self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)
if "video_frame_interpolation" in self.config:
assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
self.gen_video = self.vfi_model.interpolate_frames(
self.gen_video,
self.gen_video_final = self.vfi_model.interpolate_frames(
self.gen_video_final,
source_fps=self.config.get("fps", 16),
target_fps=target_fps,
)
if save_video:
if self.input_info.return_result_tensor:
return {"video": self.gen_video_final}
elif self.input_info.save_result_path is not None:
if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
fps = self.config["video_frame_interpolation"]["target_fps"]
else:
......@@ -303,22 +317,18 @@ class DefaultRunner(BaseRunner):
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"🎬 Start to save video 🎬")
save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
if self.config.get("return_video", False):
return {"video": self.gen_video}
return {"video": None}
save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
return {"video": None}
def run_pipeline(self, input_info):
self.input_info = input_info
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.input_info.prompt_enhanced = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
self.run_main()
gen_video = self.process_images_after_vae_decoder(save_video=save_video)
torch.cuda.empty_cache()
gc.collect()
gen_video_final = self.run_main()
return gen_video
return gen_video_final
......@@ -14,7 +14,6 @@ from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.video_encoders.hf.hunyuan.hunyuan_vae import HunyuanVAE
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_videos_grid
@RUNNER_REGISTER("hunyuan")
......@@ -152,6 +151,3 @@ class HunyuanRunner(DefaultRunner):
int(self.config.target_width) // vae_scale_factor,
)
return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}
def save_video_func(self, images):
save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
......@@ -204,7 +204,7 @@ class QwenImageRunner(DefaultRunner):
images = self.run_vae_decoder(latents, generator)
image = images[0]
image.save(f"{self.config.save_video_path}")
image.save(f"{self.config.save_result_path}")
del latents, generator
torch.cuda.empty_cache()
......
......@@ -334,7 +334,7 @@ class WanAnimateRunner(WanRunner):
)
if start != 0:
self.model.scheduler.reset()
self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)
def end_run_segment(self, segment_idx):
if segment_idx != 0:
......
import gc
import json
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
......@@ -337,16 +337,14 @@ class WanAudioRunner(WanRunner): # type:ignore
"""Initialize consistency model scheduler"""
self.scheduler = EulerScheduler(self.config)
def read_audio_input(self):
def read_audio_input(self, audio_path):
"""Read audio input - handles both single and multi-person scenarios"""
audio_sr = self.config.get("audio_sr", 16000)
target_fps = self.config.get("target_fps", 16)
self._audio_processor = AudioProcessor(audio_sr, target_fps)
# Get audio files from person objects or legacy format
audio_files = self._get_audio_files_from_config()
if not audio_files:
return [], 0
audio_files, mask_files = self.get_audio_files_from_audio_path(audio_path)
# Load audio based on single or multi-person mode
if len(audio_files) == 1:
......@@ -355,8 +353,6 @@ class WanAudioRunner(WanRunner): # type:ignore
else:
audio_array = self._audio_processor.load_multi_person_audio(audio_files)
self.config.audio_num = audio_array.size(0)
video_duration = self.config.get("video_duration", 5)
audio_len = int(audio_array.shape[1] / audio_sr * target_fps)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
......@@ -364,60 +360,35 @@ class WanAudioRunner(WanRunner): # type:ignore
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
return audio_array.size(0), audio_segments, expected_frames
def _get_audio_files_from_config(self):
talk_objects = self.config.get("talk_objects")
if talk_objects:
audio_files = []
for idx, person in enumerate(talk_objects):
audio_path = person.get("audio")
if audio_path and Path(audio_path).is_file():
audio_files.append(str(audio_path))
else:
logger.warning(f"Person {idx} audio file {audio_path} does not exist or not specified")
if audio_files:
logger.info(f"Loaded {len(audio_files)} audio files from talk_objects")
return audio_files
audio_path = self.config.get("audio_path")
if audio_path:
return [audio_path]
logger.error("config audio_path or talk_objects is not specified")
return []
# Mask latent for multi-person s2v
if mask_files is not None:
mask_latents = [self.process_single_mask(mask_file) for mask_file in mask_files]
mask_latents = torch.cat(mask_latents, dim=0)
else:
mask_latents = None
def read_person_mask(self):
mask_files = self._get_mask_files_from_config()
if not mask_files:
return None
return audio_segments, expected_frames, mask_latents, len(audio_files)
mask_latents = []
for mask_file in mask_files:
mask_latent = self._process_single_mask(mask_file)
mask_latents.append(mask_latent)
def get_audio_files_from_audio_path(self, audio_path):
if os.path.isdir(audio_path):
audio_files = []
mask_files = []
logger.info(f"audio_path is a directory, loading config.json from {audio_path}")
audio_config_path = os.path.join(audio_path, "config.json")
assert os.path.exists(audio_config_path), "config.json not found in audio_path"
with open(audio_config_path, "r") as f:
audio_config = json.load(f)
for talk_object in audio_config["talk_objects"]:
audio_files.append(os.path.join(audio_path, talk_object["audio"]))
mask_files.append(os.path.join(audio_path, talk_object["mask"]))
else:
logger.info(f"audio_path is a file without mask: {audio_path}")
audio_files = [audio_path]
mask_files = None
mask_latents = torch.cat(mask_latents, dim=0)
return mask_latents
return audio_files, mask_files
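# Expected on-disk layout when audio_path is a directory (inferred from the loader above;
# the file names are illustrative):
#   audio_path/
#       config.json  -> {"talk_objects": [{"audio": "person1.wav", "mask": "person1.png"}, ...]}
#       person1.wav, person1.png, ...
# A plain audio file path is treated as single-person input with mask_files = None.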
def _get_mask_files_from_config(self):
talk_objects = self.config.get("talk_objects")
if talk_objects:
mask_files = []
for idx, person in enumerate(talk_objects):
mask_path = person.get("mask")
if mask_path and Path(mask_path).is_file():
mask_files.append(str(mask_path))
elif mask_path:
logger.warning(f"Person {idx} mask file {mask_path} does not exist")
if mask_files:
logger.info(f"Loaded {len(mask_files)} mask files from talk_objects")
return mask_files
logger.info("config talk_objects is not specified")
return None
def _process_single_mask(self, mask_file):
def process_single_mask(self, mask_file):
mask_img = Image.open(mask_file).convert("RGB")
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
......@@ -456,21 +427,21 @@ class WanAudioRunner(WanRunner): # type:ignore
fixed_shape=self.config.get("fixed_shape", None),
)
logger.info(f"[wan_audio] resize_image target_h: {h}, target_w: {w}")
patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1]
patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2]
patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
self.config.lat_h = patched_h * self.config.patch_size[1]
self.config.lat_w = patched_w * self.config.patch_size[2]
latent_h = patched_h * self.config["patch_size"][1]
latent_w = patched_w * self.config["patch_size"][2]
self.config.tgt_h = self.config.lat_h * self.config.vae_stride[1]
self.config.tgt_w = self.config.lat_w * self.config.vae_stride[2]
latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]]
logger.info(f"[wan_audio] tgt_h: {self.config.tgt_h}, tgt_w: {self.config.tgt_w}, lat_h: {self.config.lat_h}, lat_w: {self.config.lat_w}")
logger.info(f"[wan_audio] target_h: {target_shape[0]}, target_w: {target_shape[1]}, latent_h: {latent_h}, latent_w: {latent_w}")
ref_img = torch.nn.functional.interpolate(ref_img, size=(self.config.tgt_h, self.config.tgt_w), mode="bicubic")
return ref_img
ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
return ref_img, latent_shape, target_shape
def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
......@@ -496,20 +467,17 @@ class WanAudioRunner(WanRunner): # type:ignore
return vae_encoder_out
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_r2v_audio(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = self.read_image_input(self.config["image_path"])
def _run_input_encoder_local_s2v(self):
img, latent_shape, target_shape = self.read_image_input(self.input_info.image_path)
self.input_info.latent_shape = latent_shape # Important: set latent_shape in input_info
self.input_info.target_shape = target_shape # Important: set target_shape in input_info
clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(img)
audio_num, audio_segments, expected_frames = self.read_audio_input()
person_mask_latens = self.read_person_mask()
self.config.person_num = 0
if person_mask_latens is not None:
assert audio_num == person_mask_latens.size(0), "audio_num and person_mask_latens.size(0) must be the same"
self.config.person_num = person_mask_latens.size(0)
text_encoder_output = self.run_text_encoder(prompt, None)
audio_segments, expected_frames, person_mask_latens, audio_num = self.read_audio_input(self.input_info.audio_path)
self.input_info.audio_num = audio_num
self.input_info.with_mask = person_mask_latens is not None
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return {
......@@ -528,13 +496,13 @@ class WanAudioRunner(WanRunner): # type:ignore
device = torch.device("cuda")
dtype = GET_DTYPE()
tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=device)
if prev_video is not None:
# Extract and process last frames
last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
if self.config.model_cls != "wan2.2_audio":
if self.config["model_cls"] != "wan2.2_audio":
last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
prev_frames[:, :, :prev_frame_length] = last_frames
prev_len = (prev_frame_length - 1) // 4 + 1
......@@ -546,7 +514,7 @@ class WanAudioRunner(WanRunner): # type:ignore
_, nframe, height, width = self.model.scheduler.latents.shape
with ProfilingContext4DebugL1("vae_encoder in init run segment"):
if self.config.model_cls == "wan2.2_audio":
if self.config["model_cls"] == "wan2.2_audio":
if prev_video is not None:
prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
else:
......@@ -563,7 +531,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if prev_latents is not None:
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={tgt_h}, tgt_w={tgt_w}")
prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
......@@ -592,8 +560,8 @@ class WanAudioRunner(WanRunner): # type:ignore
super().init_run()
self.scheduler.set_audio_adapter(self.audio_adapter)
self.prev_video = None
if self.config.get("return_video", False):
self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.config.tgt_h, self.config.tgt_w, 3), dtype=torch.float32, device="cpu")
if self.input_info.return_result_tensor:
self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.input_info.target_shape[0], self.input_info.target_shape[1], 3), dtype=torch.float32, device="cpu")
self.cut_audio_final = torch.zeros((self.inputs["expected_frames"] * self._audio_processor.audio_frame_rate), dtype=torch.float32, device="cpu")
else:
self.gen_video_final = None
......@@ -608,8 +576,8 @@ class WanAudioRunner(WanRunner): # type:ignore
else:
self.segment = self.inputs["audio_segments"][segment_idx]
self.config.seed = self.config.seed + segment_idx
torch.manual_seed(self.config.seed)
self.input_info.seed = self.input_info.seed + segment_idx
torch.manual_seed(self.input_info.seed)
# logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"):
......@@ -627,7 +595,7 @@ class WanAudioRunner(WanRunner): # type:ignore
# Reset scheduler for non-first segments
if segment_idx > 0:
self.model.scheduler.reset(self.inputs["previmg_encoder_output"])
self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape, self.inputs["previmg_encoder_output"])
@ProfilingContext4DebugL1("End run segment")
def end_run_segment(self, segment_idx):
......@@ -650,7 +618,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if self.va_recorder:
self.va_recorder.pub_livestream(video_seg, audio_seg)
elif self.config.get("return_video", False):
elif self.input_info.return_result_tensor:
self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg)
......@@ -669,7 +637,7 @@ class WanAudioRunner(WanRunner): # type:ignore
return rank, world_size
def init_va_recorder(self):
output_video_path = self.config.get("save_video_path", None)
output_video_path = self.input_info.save_result_path
self.va_recorder = None
if isinstance(output_video_path, dict):
output_video_path = output_video_path["data"]
......@@ -722,7 +690,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile()
self.model.select_graph_for_compile(self.input_info)
self.video_segment_num = "unlimited"
fetch_timeout = self.va_reader.segment_duration + 1
......@@ -760,24 +728,20 @@ class WanAudioRunner(WanRunner): # type:ignore
self.va_recorder = None
@ProfilingContext4DebugL1("Process after vae decoder")
def process_images_after_vae_decoder(self, save_video=False):
if self.config.get("return_video", False):
def process_images_after_vae_decoder(self):
if self.input_info.return_result_tensor:
audio_waveform = self.cut_audio_final.unsqueeze(0).unsqueeze(0)
comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
return {"video": self.gen_video_final, "audio": comfyui_audio}
return {"video": None, "audio": None}
def init_modules(self):
super().init_modules()
self.run_input_encoder = self._run_input_encoder_local_r2v_audio
def load_transformer(self):
"""Load transformer with LoRA support"""
base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device)
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(base_model)
for lora_config in self.config.lora_configs:
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
......@@ -814,7 +778,7 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_adapter.to(device)
load_from_rank0 = self.config.get("load_from_rank0", False)
weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
audio_adapter.load_state_dict(weights_dict, strict=False)
return audio_adapter.to(dtype=GET_DTYPE())
......@@ -824,28 +788,14 @@ class WanAudioRunner(WanRunner): # type:ignore
self.audio_encoder = self.load_audio_encoder()
self.audio_adapter = self.load_audio_adapter()
def set_target_shape(self):
"""Set target shape for generation"""
ret = {}
num_channels_latents = 16
if self.config.model_cls == "wan2.2_audio":
num_channels_latents = self.config.num_channels_latents
if self.config.task == "i2v":
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
self.config.lat_h,
self.config.lat_w,
)
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
else:
error_msg = "t2v task is not supported in WanAudioRunner"
assert False, error_msg
ret["target_shape"] = self.config.target_shape
return ret
def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
latent_shape = [
self.config.get("num_channels_latents", 16),
(self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
latent_h,
latent_w,
]
return latent_shape
@RUNNER_REGISTER("wan2.2_audio")
......@@ -882,7 +832,7 @@ class Wan22AudioRunner(WanAudioRunner):
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
if self.config.task != "i2v":
if self.config.task not in ["i2v", "s2v"]:
return None
else:
return Wan2_2_VAE(**vae_config)
......
......@@ -50,7 +50,7 @@ class WanCausVidRunner(WanRunner):
self.scheduler = WanStepDistillScheduler(self.config)
def set_target_shape(self):
if self.config.task == "i2v":
if self.config.task in ["i2v", "s2v"]:
self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
# For i2v, frame_seq_length must be reset based on the input shape
frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
......
......@@ -17,28 +17,28 @@ class WanDistillRunner(WanRunner):
super().__init__(config)
def load_transformer(self):
if self.config.get("lora_configs") and self.config.lora_configs:
if self.config.get("lora_configs") and self.config["lora_configs"]:
model = WanModel(
self.config.model_path,
self.config["model_path"],
self.config,
self.init_device,
)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
model = WanDistillModel(self.config.model_path, self.config, self.init_device)
model = WanDistillModel(self.config["model_path"], self.config, self.init_device)
return model
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
if self.config["feature_caching"] == "NoCaching":
self.scheduler = WanStepDistillScheduler(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
class MultiDistillModelStruct(MultiModelStruct):
......@@ -54,7 +54,7 @@ class MultiDistillModelStruct(MultiModelStruct):
def get_current_model_index(self):
if self.scheduler.step_index < self.boundary_step_index:
logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1:
self.to_cuda(model_index=0)
......@@ -64,7 +64,7 @@ class MultiDistillModelStruct(MultiModelStruct):
self.cur_model_index = 0
else:
logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1:
self.to_cuda(model_index=1)
......@@ -81,8 +81,8 @@ class Wan22MoeDistillRunner(WanDistillRunner):
def load_transformer(self):
use_high_lora, use_low_lora = False, False
if self.config.get("lora_configs") and self.config.lora_configs:
for lora_config in self.config.lora_configs:
if self.config.get("lora_configs") and self.config["lora_configs"]:
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
use_high_lora = True
elif lora_config.get("name", "") == "low_noise_model":
......@@ -90,12 +90,12 @@ class Wan22MoeDistillRunner(WanDistillRunner):
if use_high_lora:
high_noise_model = WanModel(
os.path.join(self.config.model_path, "high_noise_model"),
os.path.join(self.config["model_path"], "high_noise_model"),
self.config,
self.init_device,
)
high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config.lora_configs:
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
......@@ -104,19 +104,19 @@ class Wan22MoeDistillRunner(WanDistillRunner):
logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
high_noise_model = Wan22MoeDistillModel(
os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
self.config,
self.init_device,
)
if use_low_lora:
low_noise_model = WanModel(
os.path.join(self.config.model_path, "low_noise_model"),
os.path.join(self.config["model_path"], "low_noise_model"),
self.config,
self.init_device,
)
low_lora_wrapper = WanLoraWrapper(low_noise_model)
for lora_config in self.config.lora_configs:
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "low_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
......@@ -125,15 +125,15 @@ class Wan22MoeDistillRunner(WanDistillRunner):
logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
low_noise_model = Wan22MoeDistillModel(
os.path.join(self.config.model_path, "distill_models", "low_noise_model"),
os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
self.config,
self.init_device,
)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
if self.config["feature_caching"] == "NoCaching":
self.scheduler = Wan22StepDistillScheduler(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
......@@ -28,7 +28,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video
from lightx2v.utils.utils import best_output_size
@RUNNER_REGISTER("wan2.1")
......@@ -42,7 +42,7 @@ class WanRunner(DefaultRunner):
def load_transformer(self):
model = WanModel(
self.config.model_path,
self.config["model_path"],
self.config,
self.init_device,
)
......@@ -59,7 +59,7 @@ class WanRunner(DefaultRunner):
def load_image_encoder(self):
image_encoder = None
if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
# offload config
clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
if clip_offload:
......@@ -148,13 +148,13 @@ class WanRunner(DefaultRunner):
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
"device": vae_device,
"parallel": self.config.parallel,
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
"dtype": GET_DTYPE(),
"load_from_rank0": self.config.get("load_from_rank0", False),
}
if self.config.task not in ["i2v", "flf2v", "animate", "vace"]:
if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
return None
else:
return self.vae_cls(**vae_config)
......@@ -170,7 +170,7 @@ class WanRunner(DefaultRunner):
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
"device": vae_device,
"parallel": self.config.parallel,
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
"dtype": GET_DTYPE(),
......@@ -192,9 +192,9 @@ class WanRunner(DefaultRunner):
return vae_encoder, vae_decoder
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
if self.config["feature_caching"] == "NoCaching":
scheduler_class = WanScheduler
elif self.config.feature_caching == "TaylorSeer":
elif self.config["feature_caching"] == "TaylorSeer":
scheduler_class = WanSchedulerTaylorCaching
elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
scheduler_class = WanSchedulerCaching
......@@ -206,26 +206,28 @@ class WanRunner(DefaultRunner):
else:
self.scheduler = scheduler_class(self.config)
def run_text_encoder(self, text, img=None):
def run_text_encoder(self, input_info):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.text_encoders = self.load_text_encoder()
n_prompt = self.config.get("negative_prompt", "")
prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
neg_prompt = input_info.negative_prompt
if self.config["cfg_parallel"]:
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
context = self.text_encoders[0].infer([text])
context = self.text_encoders[0].infer([prompt])
context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
text_encoder_output = {"context": context}
else:
context_null = self.text_encoders[0].infer([n_prompt])
context_null = self.text_encoders[0].infer([neg_prompt])
context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
text_encoder_output = {"context_null": context_null}
else:
context = self.text_encoders[0].infer([text])
context = self.text_encoders[0].infer([prompt])
context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
context_null = self.text_encoders[0].infer([n_prompt])
context_null = self.text_encoders[0].infer([neg_prompt])
context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
text_encoder_output = {
"context": context,
......@@ -255,22 +257,22 @@ class WanRunner(DefaultRunner):
def run_vae_encoder(self, first_frame, last_frame=None):
h, w = first_frame.shape[2:]
aspect_ratio = h / w
max_area = self.config.target_height * self.config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
max_area = self.config["target_height"] * self.config["target_width"]
latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2])
latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w) # Important: latent_shape is used to set the input_info
if self.config.get("changing_resolution", False):
assert last_frame is None
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out_list = []
for i in range(len(self.config["resolution_rate"])):
lat_h, lat_w = (
int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
latent_h_tmp, latent_w_tmp = (
int(latent_h * self.config["resolution_rate"][i]) // 2 * 2,
int(latent_w * self.config["resolution_rate"][i]) // 2 * 2,
)
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, lat_h, lat_w))
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, self.config.lat_h, self.config.lat_w))
return vae_encode_out_list
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp))
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w))
return vae_encode_out_list, latent_shape
else:
if last_frame is not None:
first_frame_size = first_frame.shape[2:]
......@@ -282,16 +284,15 @@ class WanRunner(DefaultRunner):
round(last_frame_size[1] * last_frame_resize_ratio),
]
last_frame = TF.center_crop(last_frame, last_frame_size)
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encoder_out = self.get_vae_encoder_output(first_frame, lat_h, lat_w, last_frame)
return vae_encoder_out
vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame)
return vae_encoder_out, latent_shape
def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2]
h = lat_h * self.config["vae_stride"][1]
w = lat_w * self.config["vae_stride"][2]
msk = torch.ones(
1,
self.config.target_video_length,
self.config["target_video_length"],
lat_h,
lat_w,
device=torch.device("cuda"),
......@@ -312,7 +313,7 @@ class WanRunner(DefaultRunner):
vae_input = torch.concat(
[
torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 2, h, w),
torch.zeros(3, self.config["target_video_length"] - 2, h, w),
torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
],
dim=1,
......@@ -321,7 +322,7 @@ class WanRunner(DefaultRunner):
vae_input = torch.concat(
[
torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 1, h, w),
torch.zeros(3, self.config["target_video_length"] - 1, h, w),
],
dim=1,
).cuda()
......@@ -345,32 +346,23 @@ class WanRunner(DefaultRunner):
"image_encoder_output": image_encoder_output,
}
def set_target_shape(self):
num_channels_latents = self.config.get("num_channels_latents", 16)
if self.config.task in ["i2v", "flf2v", "animate"]:
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
self.config.lat_h,
self.config.lat_w,
)
elif self.config.task == "t2v":
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2],
)
def save_video_func(self, images):
cache_video(
tensor=images,
save_file=self.config.save_video_path,
fps=self.config.get("fps", 16),
nrow=1,
normalize=True,
value_range=(-1, 1),
)
def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
latent_shape = [
self.config.get("num_channels_latents", 16),
(self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
latent_h,
latent_w,
]
return latent_shape
def get_latent_shape_with_target_hw(self, target_h, target_w):
latent_shape = [
self.config.get("num_channels_latents", 16),
(self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
int(target_h) // self.config["vae_stride"][1],
int(target_w) // self.config["vae_stride"][2],
]
return latent_shape
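# Worked example under assumed defaults (the values below are not part of this diff):
#   num_channels_latents = 16, vae_stride = [4, 8, 8], target_video_length = 81
#   get_latent_shape_with_target_hw(720, 1280)
#     -> frames = (81 - 1) // 4 + 1 = 21, latent_h = 720 // 8 = 90, latent_w = 1280 // 8 = 160
#     -> [16, 21, 90, 160]
#   get_latent_shape_with_lat_hw(90, 160) yields the same shape, starting from already-latent H/W.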
class MultiModelStruct:
......@@ -400,7 +392,7 @@ class MultiModelStruct:
def get_current_model_index(self):
if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1:
self.to_cuda(model_index=0)
......@@ -410,7 +402,7 @@ class MultiModelStruct:
self.cur_model_index = 0
else:
logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1:
self.to_cuda(model_index=1)
......@@ -434,20 +426,20 @@ class Wan22MoeRunner(WanRunner):
def load_transformer(self):
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
high_noise_model = WanModel(
os.path.join(self.config.model_path, "high_noise_model"),
os.path.join(self.config["model_path"], "high_noise_model"),
self.config,
self.init_device,
)
low_noise_model = WanModel(
os.path.join(self.config.model_path, "low_noise_model"),
os.path.join(self.config["model_path"], "low_noise_model"),
self.config,
self.init_device,
)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
for lora_config in self.config.lora_configs:
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
base_name = os.path.basename(lora_path)
......@@ -464,7 +456,7 @@ class Wan22MoeRunner(WanRunner):
else:
raise ValueError(f"Unsupported LoRA path: {lora_path}")
return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
@RUNNER_REGISTER("wan2.2")
......
......@@ -9,11 +9,10 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
from lightx2v.utils.envs import *
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
torch.manual_seed(42)
@RUNNER_REGISTER("wan2.1_sf")
class WanSFRunner(WanRunner):
......@@ -59,40 +58,37 @@ class WanSFRunner(WanRunner):
gc.collect()
return images
@ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None):
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile()
total_blocks = self.scheduler.num_blocks
gen_videos = []
for seg_index in range(self.video_segment_num):
logger.info(f"==> segment_index: {seg_index + 1} / {total_blocks}")
total_steps = len(self.scheduler.denoising_step_list)
for step_index in range(total_steps):
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=False)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs)
with ProfilingContext4DebugL1("step_post"):
self.model.scheduler.step_post()
latents = self.model.scheduler.stream_output
gen_videos.append(self.run_vae_decoder(latents))
# rerun with timestep zero to update KV cache using clean context
with ProfilingContext4DebugL1("step_pre_in_rerun"):
self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=True)
with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
def init_run(self):
super().init_run()
@ProfilingContext4DebugL1("End run segment")
def end_run_segment(self, segment_idx=None):
with ProfilingContext4DebugL1("step_pre_in_rerun"):
self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
self.model.infer(self.inputs)
self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
@peak_memory_decorator
def run_segment(self, total_steps=None):
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
# only for single segment, check stop signal every step
if self.video_segment_num == 1:
self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(seg_index=self.segment_idx, step_index=step_index, is_rerun=False)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs)
self.gen_video = torch.cat(gen_videos, dim=0)
with ProfilingContext4DebugL1("step_post"):
self.model.scheduler.step_post()
if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
self.end_run()
return self.model.scheduler.stream_output
......@@ -154,10 +154,10 @@ class WanVaceRunner(WanRunner):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]
def set_target_shape(self):
target_shape = self.latent_shape
target_shape[0] = int(target_shape[0] / 2)
self.config.target_shape = target_shape
def set_input_info_latent_shape(self):
latent_shape = self.latent_shape
latent_shape[0] = int(latent_shape[0] / 2)
return latent_shape
@ProfilingContext4DebugL1("Run VAE Decoder")
def run_vae_decoder(self, latents):
......
......@@ -6,8 +6,8 @@ class BaseScheduler:
self.config = config
self.latents = None
self.step_index = 0
self.infer_steps = config.infer_steps
self.caching_records = [True] * config.infer_steps
self.infer_steps = config["infer_steps"]
self.caching_records = [True] * config["infer_steps"]
self.flag_df = False
self.transformer_infer = None
self.infer_condition = True # cfg status
......
......@@ -13,8 +13,8 @@ class EulerScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
if self.config.parallel:
self.sp_size = self.config.parallel.get("seq_p_size", 1)
if self.config["parallel"]:
self.sp_size = self.config["parallel"].get("seq_p_size", 1)
else:
self.sp_size = 1
......@@ -33,11 +33,11 @@ class EulerScheduler(WanScheduler):
if self.audio_adapter.cpu_offload:
self.audio_adapter.time_embedding.to("cpu")
if self.config.model_cls == "wan2.2_audio":
if self.config["model_cls"] == "wan2.2_audio":
_, lat_f, lat_h, lat_w = self.latents.shape
F = (lat_f - 1) * self.config.vae_stride[0] + 1
per_latent_token_len = lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * per_latent_token_len
F = (lat_f - 1) * self.config["vae_stride"][0] + 1
per_latent_token_len = lat_h * lat_w // (self.config["patch_size"][1] * self.config["patch_size"][2])
max_seq_len = ((F - 1) // self.config["vae_stride"][0] + 1) * per_latent_token_len
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
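A worked example of the `max_seq_len` arithmetic above, with assumed values: Wan-style `vae_stride = (4, 8, 8)` and `patch_size = (1, 2, 2)`, plus illustrative latent dimensions.

```python
import math

# Assumed, illustrative inputs for the sequence-length calculation.
lat_f, lat_h, lat_w = 21, 60, 104
vae_stride = (4, 8, 8)
patch_size = (1, 2, 2)
sp_size = 4  # sequence-parallel world size

F = (lat_f - 1) * vae_stride[0] + 1                                       # 81 pixel-space frames
per_latent_token_len = lat_h * lat_w // (patch_size[1] * patch_size[2])   # 1560 tokens per latent frame
max_seq_len = ((F - 1) // vae_stride[0] + 1) * per_latent_token_len       # 21 * 1560 = 32760
max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size             # round up to a multiple of sp_size
print(max_seq_len)  # 32760 (already a multiple of 4)
```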
......@@ -55,13 +55,13 @@ class EulerScheduler(WanScheduler):
dim=1,
)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(seed)
self.latents = torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
......@@ -71,8 +71,8 @@ class EulerScheduler(WanScheduler):
if self.prev_latents is not None:
self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
def prepare(self, previmg_encoder_output=None):
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
def prepare(self, seed, latent_shape, image_encoder_output=None):
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
......@@ -93,11 +93,11 @@ class EulerScheduler(WanScheduler):
if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
def reset(self, previmg_encoder_output=None):
def reset(self, seed, latent_shape, image_encoder_output=None):
if self.config["model_cls"] == "wan2.2_audio":
self.prev_latents = previmg_encoder_output["prev_latents"]
self.prev_len = previmg_encoder_output["prev_len"]
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
self.prev_latents = image_encoder_output["prev_latents"]
self.prev_len = image_encoder_output["prev_len"]
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
if in_tensor.ndim > tgt_n_dim:
......
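A hedged sketch of the new `prepare(seed, latent_shape, ...)` flow: the seed drives a `torch.Generator` so the same seed reproduces the same latents, and the Euler timestep grid is a plain linspace from `num_train_timesteps` down to 0. The shape and step count below are illustrative.

```python
import numpy as np
import torch

def prepare(seed, latent_shape, infer_steps=4, num_train_timesteps=1000,
            device="cpu", dtype=torch.float32):
    # Seeded generator: identical seeds yield identical starting noise.
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(*latent_shape, dtype=dtype, device=device,
                          generator=generator)
    # Evenly spaced timesteps from num_train_timesteps down to 0.
    timesteps = np.linspace(num_train_timesteps, 0, infer_steps + 1,
                            dtype=np.float32)
    return latents, torch.from_numpy(timesteps).to(device=device)

latents, timesteps = prepare(seed=42, latent_shape=(16, 21, 60, 104))
print(latents.shape)        # torch.Size([16, 21, 60, 104])
print(timesteps.tolist())   # [1000.0, 750.0, 500.0, 250.0, 0.0]
```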
......@@ -19,16 +19,16 @@ class WanScheduler4ChangingResolution:
config["changing_resolution_steps"] = [config.infer_steps // 2]
assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(seed)
self.latents_list = []
for i in range(len(self.config["resolution_rate"])):
self.latents_list.append(
torch.randn(
target_shape[0],
target_shape[1],
int(target_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
int(target_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
latent_shape[0],
latent_shape[1],
int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
dtype=dtype,
device=self.device,
generator=self.generator,
......@@ -38,10 +38,10 @@ class WanScheduler4ChangingResolution:
# add original resolution latents
self.latents_list.append(
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
......
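An illustrative sketch (rates and shape made up) of the changing-resolution setup: one noise tensor per `resolution_rate` entry, with spatial dimensions scaled and rounded down to an even size, plus a final full-resolution tensor, all drawn from a single seeded generator.

```python
import torch

seed, latent_shape = 0, (16, 21, 60, 104)   # illustrative values
resolution_rate = [0.75]

generator = torch.Generator(device="cpu").manual_seed(seed)
latents_list = []
for rate in resolution_rate:
    latents_list.append(torch.randn(
        latent_shape[0],
        latent_shape[1],
        int(latent_shape[2] * rate) // 2 * 2,   # 60 * 0.75 = 45 -> 44 (rounded to even)
        int(latent_shape[3] * rate) // 2 * 2,   # 104 * 0.75 = 78 -> 78
        generator=generator,
    ))
# Append the full-resolution latents last.
latents_list.append(torch.randn(*latent_shape, generator=generator))

print([tuple(t.shape) for t in latents_list])  # [(16, 21, 44, 78), (16, 21, 60, 104)]
```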
import gc
from typing import List, Optional, Union
import numpy as np
......@@ -12,22 +11,22 @@ class WanScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.infer_steps = self.config["infer_steps"]
self.target_video_length = self.config["target_video_length"]
self.sample_shift = self.config["sample_shift"]
self.shift = 1
self.num_train_timesteps = 1000
self.disable_corrector = []
self.solver_order = 2
self.noise_pred = None
self.sample_guide_scale = self.config.sample_guide_scale
self.caching_records_2 = [True] * self.config.infer_steps
self.sample_guide_scale = self.config["sample_guide_scale"]
self.caching_records_2 = [True] * self.config["infer_steps"]
def prepare(self, image_encoder_output=None):
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
def prepare(self, seed, latent_shape, image_encoder_output=None):
if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
......@@ -48,18 +47,18 @@ class WanScheduler(BaseScheduler):
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(seed)
self.latents = torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
self.mask = masks_like(self.latents, zero=True)
self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
......@@ -117,7 +116,7 @@ class WanScheduler(BaseScheduler):
x0_pred = sample - sigma_t * model_output
return x0_pred
def reset(self, step_index=None):
def reset(self, seed, latent_shape, step_index=None):
if step_index is not None:
self.step_index = step_index
self.model_outputs = [None] * self.solver_order
......@@ -126,9 +125,7 @@ class WanScheduler(BaseScheduler):
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
def multistep_uni_p_bh_update(
self,
......@@ -325,7 +322,7 @@ class WanScheduler(BaseScheduler):
def step_pre(self, step_index):
super().step_pre(step_index)
self.timestep_input = torch.stack([self.timesteps[self.step_index]])
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
def step_post(self):
......@@ -367,5 +364,5 @@ class WanScheduler(BaseScheduler):
self.lower_order_nums += 1
self.latents = prev_sample
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
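A hedged sketch of the i2v/s2v mask blend above: wherever the mask is 0 the VAE-encoder conditioning latents are kept, wherever it is 1 the sampled (and later denoised) latents are used. The mask layout here, with only the first latent frame as conditioning, is an assumption for the example.

```python
import torch

latents = torch.randn(16, 21, 8, 8)           # illustrative noise latents
vae_encoder_out = torch.randn(16, 21, 8, 8)   # illustrative conditioning latents

mask = torch.ones_like(latents)
mask[:, 0] = 0.0                              # assume frame 0 is the conditioning frame

blended = (1.0 - mask) * vae_encoder_out + mask * latents
assert torch.equal(blended[:, 0], vae_encoder_out[:, 0])   # conditioning kept where mask == 0
assert torch.equal(blended[:, 1:], latents[:, 1:])         # noise kept where mask == 1
```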
......@@ -9,24 +9,25 @@ class WanSFScheduler(WanScheduler):
super().__init__(config)
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.num_frame_per_block = self.config.sf_config.num_frame_per_block
self.num_output_frames = self.config.sf_config.num_output_frames
self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
self.num_output_frames = self.config["sf_config"]["num_output_frames"]
self.num_blocks = self.num_output_frames // self.num_frame_per_block
self.denoising_step_list = self.config.sf_config.denoising_step_list
self.denoising_step_list = self.config["sf_config"]["denoising_step_list"]
self.infer_steps = len(self.denoising_step_list)
self.all_num_frames = [self.num_frame_per_block] * self.num_blocks
self.num_input_frames = 0
self.denoising_strength = 1.0
self.sigma_max = 1.0
self.sigma_min = 0
self.sf_shift = self.config.sf_config.shift
self.sf_shift = self.config["sf_config"]["shift"]
self.inverse_timesteps = False
self.extra_one_step = True
self.reverse_sigmas = False
self.num_inference_steps = self.config.sf_config.num_inference_steps
self.num_inference_steps = self.config["sf_config"]["num_inference_steps"]
self.context_noise = 0
def prepare(self, image_encoder_output=None):
self.latents = torch.randn(self.config.target_shape, device=self.device, dtype=self.dtype)
def prepare(self, seed, latent_shape, image_encoder_output=None):
self.latents = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
timesteps = []
for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
......@@ -39,7 +40,7 @@ class WanSFScheduler(WanScheduler):
timesteps.append(frame_steps)
self.timesteps = timesteps
self.noise_pred = torch.zeros(self.config.target_shape, device=self.device, dtype=self.dtype)
self.noise_pred = torch.zeros(latent_shape, device=self.device, dtype=self.dtype)
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
if self.extra_one_step:
......@@ -91,7 +92,7 @@ class WanSFScheduler(WanScheduler):
x0_pred = x0_pred.to(original_dtype)
# add noise
if self.step_index < len(self.denoising_step_list) - 1:
if self.step_index < self.infer_steps - 1:
timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
......
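A hedged sketch of the sigma lookup in the self-forcing step: each per-frame timestep is matched to the nearest entry of the scheduler's timestep table via `argmin` over absolute differences, and that index selects `sigma_next`. The tables below are illustrative, not the repository's.

```python
import torch

# Illustrative timestep/sigma tables for the nearest-match lookup.
timesteps_sf = torch.tensor([1000.0, 750.0, 500.0, 250.0, 0.0])
sigmas_sf = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])

num_frame_per_block = 3
timestep_next = 500.0 * torch.ones(num_frame_per_block)

# (1, T) - (F, 1) broadcasts to (F, T); argmin picks the closest table entry per frame.
timestep_id_next = torch.argmin(
    (timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
sigma_next = sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)

print(timestep_id_next.tolist())           # [2, 2, 2]
print(sigma_next.flatten().tolist())       # [0.5, 0.5, 0.5]
```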