"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "a385ee27bd0025781eba61578889e470a1c027fb"
Unverified commit 04812de2, authored by Yang Yong (雍洋), committed by GitHub

Refactor Config System (#338)

parent 6a658f42
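
Note: the commit replaces attribute-style config access (`self.config.task`) with dict-style access (`self.config["task"]`) throughout, and adds a `lock()` step after module initialization so later writes must go through an explicit unlock. A minimal sketch of such a lockable config (not the actual lightx2v implementation; only the method names `lock`, `temporarily_unlocked`, `update`, `get` are taken from the diff below) might look like:

```python
from contextlib import contextmanager


class LockableConfig(dict):
    """Dict-style config that can be frozen after initialization (sketch)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._locked = False

    def lock(self):
        # After init, freeze the config to catch accidental writes.
        self._locked = True

    def __setitem__(self, key, value):
        if getattr(self, "_locked", False):
            raise RuntimeError(f"config is locked; cannot set {key!r}")
        super().__setitem__(key, value)

    def update(self, other):
        # Route updates through __setitem__ so the lock is enforced.
        for k, v in other.items():
            self[k] = v

    @contextmanager
    def temporarily_unlocked(self):
        # Explicit escape hatch for deliberate modification (see set_config).
        was_locked, self._locked = self._locked, False
        try:
            yield self
        finally:
            self._locked = was_locked
```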
@@ -14,8 +14,8 @@ class WanSFModel(WanModel):
         self.to_cuda()

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        sf_confg = self.config.sf_config
-        file_path = os.path.join(self.config.sf_model_path, f"checkpoints/self_forcing_{sf_confg.sf_type}.pt")
+        sf_confg = self.config["sf_config"]
+        file_path = os.path.join(self.config["sf_model_path"], f"checkpoints/self_forcing_{sf_confg['sf_type']}.pt")
         _weight_dict = torch.load(file_path)["generator_ema"]
         weight_dict = {}
         for k, v in _weight_dict.items():
...
@@ -40,7 +40,7 @@ class WanPreWeights(WeightModule):
             MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
         )
-        if config.task in ["i2v", "flf2v", "animate"] and config.get("use_image_encoder", True):
+        if config["task"] in ["i2v", "flf2v", "animate", "s2v"] and config.get("use_image_encoder", True):
             self.add_module(
                 "proj_0",
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
@@ -58,7 +58,7 @@ class WanPreWeights(WeightModule):
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
             )
-        if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
+        if config["model_cls"] == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
             self.add_module(
                 "cfg_cond_proj_1",
                 MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"),
@@ -68,12 +68,12 @@ class WanPreWeights(WeightModule):
                 MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
             )
-        if config.task == "flf2v" and config.get("use_image_encoder", True):
+        if config["task"] == "flf2v" and config.get("use_image_encoder", True):
             self.add_module(
                 "emb_pos",
                 TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
             )
-        if config.task == "animate":
+        if config["task"] == "animate":
             self.add_module(
                 "pose_patch_embedding",
                 CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size),
...
@@ -60,7 +60,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
-            lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
+            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
             self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         else:
             self.lazy_load_file = None
@@ -197,7 +197,7 @@ class WanSelfAttention(WeightModule):
         self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
         if self.config["seq_parallel"]:
-            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config.parallel.get("seq_p_attn_type", "ulysses")]())
+            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())

         if self.quant_method in ["advanced_ptq"]:
             self.add_module(
@@ -296,7 +296,7 @@ class WanCrossAttention(WeightModule):
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
-        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             self.add_module(
                 "cross_attn_k_img",
                 MM_WEIGHT_REGISTER[self.mm_type](
...
@@ -3,8 +3,6 @@ from abc import ABC
 import torch
 import torch.distributed as dist

-from lightx2v.utils.utils import save_videos_grid
-

 class BaseRunner(ABC):
     """Abstract base class for all Runners
@@ -15,6 +13,7 @@ class BaseRunner(ABC):
     def __init__(self, config):
         self.config = config
         self.vae_encoder_need_img_original = False
+        self.input_info = None

     def load_transformer(self):
         """Load transformer model
@@ -100,26 +99,6 @@ class BaseRunner(ABC):
         """Initialize scheduler"""
         pass

-    def set_target_shape(self):
-        """Set target shape
-
-        Subclasses can override this method to provide specific implementation
-
-        Returns:
-            Dictionary containing target shape information
-        """
-        return {}
-
-    def save_video_func(self, images):
-        """Save video implementation
-
-        Subclasses can override this method to customize save logic
-
-        Args:
-            images: Image sequence to save
-        """
-        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
-
     def load_vae_decoder(self):
         """Load VAE decoder
@@ -146,7 +125,7 @@ class BaseRunner(ABC):
         pass

     def end_run_segment(self, segment_idx=None):
-        pass
+        self.gen_video_final = self.gen_video

     def end_run(self):
         pass
...
-import imageio
-import numpy as np
 from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
 from lightx2v.models.networks.cogvideox.model import CogvideoxModel
 from lightx2v.models.runners.default_runner import DefaultRunner
@@ -72,9 +69,3 @@ class CogvideoxRunner(DefaultRunner):
         )
         ret["target_shape"] = self.config.target_shape
         return ret
-
-    def save_video_func(self, images):
-        with imageio.get_writer(self.config.save_video_path, fps=16) as writer:
-            for pil_image in images:
-                frame_np = np.array(pil_image, dtype=np.uint8)
-                writer.append_data(frame_np)
@@ -22,13 +22,13 @@ class DefaultRunner(BaseRunner):
         super().__init__(config)
         self.has_prompt_enhancer = False
         self.progress_callback = None
-        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
+        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
             self.has_prompt_enhancer = True
             if not self.check_sub_servers("prompt_enhancer"):
                 self.has_prompt_enhancer = False
                 logger.warning("No prompt enhancer server available, disable prompt enhancer.")
         if not self.has_prompt_enhancer:
-            self.config.use_prompt_enhancer = False
+            self.config["use_prompt_enhancer"] = False
         self.set_init_device()
         self.init_scheduler()
@@ -49,12 +49,15 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_vace
         elif self.config["task"] == "animate":
             self.run_input_encoder = self._run_input_encoder_local_animate
+        elif self.config["task"] == "s2v":
+            self.run_input_encoder = self._run_input_encoder_local_s2v
+        self.config.lock()  # lock config to avoid modification
         if self.config.get("compile", False):
             logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
             self.model.compile(self.config.get("compile_shapes", []))

     def set_init_device(self):
-        if self.config.cpu_offload:
+        if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
             self.init_device = torch.device("cuda")
@@ -96,21 +99,23 @@ class DefaultRunner(BaseRunner):
         return len(available_servers) > 0

     def set_inputs(self, inputs):
-        self.config["prompt"] = inputs.get("prompt", "")
-        self.config["use_prompt_enhancer"] = False
-        if self.has_prompt_enhancer:
-            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
-        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
-        self.config["image_path"] = inputs.get("image_path", "")
-        self.config["save_video_path"] = inputs.get("save_video_path", "")
-        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
-        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
-        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
-        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
-        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio
-        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
-        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
+        self.input_info.seed = inputs.get("seed", 42)
+        self.input_info.prompt = inputs.get("prompt", "")
+        if self.config["use_prompt_enhancer"]:
+            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
+        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
+        if "image_path" in self.input_info.__dataclass_fields__:
+            self.input_info.image_path = inputs.get("image_path", "")
+        if "audio_path" in self.input_info.__dataclass_fields__:
+            self.input_info.audio_path = inputs.get("audio_path", "")
+        if "video_path" in self.input_info.__dataclass_fields__:
+            self.input_info.video_path = inputs.get("video_path", "")
+        self.input_info.save_result_path = inputs.get("save_result_path", "")
+
+    def set_config(self, config_modify):
+        logger.info(f"modify config: {config_modify}")
+        with self.config.temporarily_unlocked():
+            self.config.update(config_modify)

     def set_progress_callback(self, callback):
         self.progress_callback = callback
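
Note: `set_inputs` now fills a per-request `input_info` object instead of mutating the (locked) config, and the `__dataclass_fields__` checks imply it is a dataclass whose optional fields vary by task. A hypothetical sketch, with field names taken only from the accesses in this diff (the real class and its task-specific variants may differ):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class I2VInputInfo:
    """Illustrative per-request input for an i2v task (assumed shape)."""
    seed: int = 42
    prompt: str = ""
    prompt_enhanced: str = ""
    negative_prompt: str = ""
    image_path: str = ""
    save_result_path: Optional[str] = None
    return_result_tensor: bool = False
    latent_shape: Optional[list] = None
    original_size: Any = None


# set_inputs only copies optional fields the task's dataclass actually declares:
info = I2VInputInfo()
assert "image_path" in info.__dataclass_fields__
assert "audio_path" not in info.__dataclass_fields__  # audio-only tasks would add it
```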
@@ -146,6 +151,7 @@ class DefaultRunner(BaseRunner):
     def end_run(self):
         self.model.scheduler.clear()
         del self.inputs
+        self.input_info = None
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                 self.model.transformer_infer.weights_stream_mgr.clear()
@@ -162,23 +168,24 @@ class DefaultRunner(BaseRunner):
         else:
             img_ori = Image.open(img_path).convert("RGB")
             img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        self.input_info.original_size = img_ori.size
         return img, img_ori

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img, img_ori = self.read_image_input(self.config["image_path"])
+        img, img_ori = self.read_image_input(self.input_info.image_path)
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
-        text_encoder_output = self.run_text_encoder(prompt, img)
+        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"])  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return {
@@ -188,22 +195,21 @@ class DefaultRunner(BaseRunner):
     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_flf2v(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        first_frame, _ = self.read_image_input(self.config["image_path"])
-        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
+        first_frame, _ = self.read_image_input(self.input_info.image_path)
+        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
         clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
-        text_encoder_output = self.run_text_encoder(prompt, first_frame)
+        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

     @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_vace(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        src_video = self.config.get("src_video", None)
-        src_mask = self.config.get("src_mask", None)
-        src_ref_images = self.config.get("src_ref_images", None)
+        src_video = self.input_info.src_video
+        src_mask = self.input_info.src_mask
+        src_ref_images = self.input_info.src_ref_images
         src_video, src_mask, src_ref_images = self.prepare_source(
             [src_video],
             [src_mask],
@@ -212,34 +218,38 @@ class DefaultRunner(BaseRunner):
         )
         self.src_ref_images = src_ref_images
-        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
-        text_encoder_output = self.run_text_encoder(prompt)
+        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

     @ProfilingContext4DebugL2("Run Text Encoder")
     def _run_input_encoder_local_animate(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

+    def _run_input_encoder_local_s2v(self):
+        pass
+
     def init_run(self):
-        self.set_target_shape()
+        self.gen_video_final = None
         self.get_video_segment_num()
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
+        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
+        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None

     @ProfilingContext4DebugL2("Run DiT")
     def run_main(self, total_steps=None):
         self.init_run()
         if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
+            self.model.select_graph_for_compile(self.input_info)
         for segment_idx in range(self.video_segment_num):
             logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
             with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
@@ -252,7 +262,9 @@ class DefaultRunner(BaseRunner):
                 self.gen_video = self.run_vae_decoder(latents)
                 # 4. default do nothing
                 self.end_run_segment(segment_idx)
+        gen_video_final = self.process_images_after_vae_decoder()
         self.end_run()
+        return {"video": gen_video_final}

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
@@ -281,20 +293,22 @@ class DefaultRunner(BaseRunner):
         logger.info(f"Enhanced prompt: {enhanced_prompt}")
         return enhanced_prompt

-    def process_images_after_vae_decoder(self, save_video=True):
-        self.gen_video = vae_to_comfyui_image(self.gen_video)
+    def process_images_after_vae_decoder(self):
+        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)
         if "video_frame_interpolation" in self.config:
             assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
             target_fps = self.config["video_frame_interpolation"]["target_fps"]
             logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
-            self.gen_video = self.vfi_model.interpolate_frames(
-                self.gen_video,
+            self.gen_video_final = self.vfi_model.interpolate_frames(
+                self.gen_video_final,
                 source_fps=self.config.get("fps", 16),
                 target_fps=target_fps,
             )
-        if save_video:
+        if self.input_info.return_result_tensor:
+            return {"video": self.gen_video_final}
+        elif self.input_info.save_result_path is not None:
             if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                 fps = self.config["video_frame_interpolation"]["target_fps"]
             else:
@@ -303,22 +317,18 @@ class DefaultRunner(BaseRunner):
             if not dist.is_initialized() or dist.get_rank() == 0:
                 logger.info(f"🎬 Start to save video 🎬")
-                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
-                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
-        if self.config.get("return_video", False):
-            return {"video": self.gen_video}
-        return {"video": None}
+                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
+                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
+        return {"video": None}

-    def run_pipeline(self, save_video=True):
+    def run_pipeline(self, input_info):
+        self.input_info = input_info
         if self.config["use_prompt_enhancer"]:
-            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.input_info.prompt_enhanced = self.post_prompt_enhancer()
         self.inputs = self.run_input_encoder()
-        self.run_main()
-        gen_video = self.process_images_after_vae_decoder(save_video=save_video)
-        torch.cuda.empty_cache()
-        gc.collect()
-        return gen_video
+        gen_video_final = self.run_main()
+        return gen_video_final
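
Note: the calling convention changes accordingly. A hedged usage sketch under the assumptions above (runner construction details and the `I2VInputInfo` dataclass are illustrative, not the library's exact API):

```python
runner = DefaultRunner(config)        # config: locked, dict-style config
info = I2VInputInfo(                  # hypothetical dataclass, see the note above
    seed=42,
    prompt="a cat surfing a wave",
    image_path="./assets/cat.png",
    save_result_path="./output.mp4",
)
result = runner.run_pipeline(info)    # returns {"video": tensor or None}
```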
@@ -14,7 +14,6 @@ from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.hunyuan.hunyuan_vae import HunyuanVAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_videos_grid


 @RUNNER_REGISTER("hunyuan")
@@ -152,6 +151,3 @@ class HunyuanRunner(DefaultRunner):
             int(self.config.target_width) // vae_scale_factor,
         )
         return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}
-
-    def save_video_func(self, images):
-        save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
@@ -204,7 +204,7 @@ class QwenImageRunner(DefaultRunner):
             images = self.run_vae_decoder(latents, generator)
             image = images[0]
-            image.save(f"{self.config.save_video_path}")
+            image.save(f"{self.config.save_result_path}")
             del latents, generator
             torch.cuda.empty_cache()
...
@@ -334,7 +334,7 @@ class WanAnimateRunner(WanRunner):
         )
         if start != 0:
-            self.model.scheduler.reset()
+            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape)

     def end_run_segment(self, segment_idx):
         if segment_idx != 0:
...
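
Note: across the runners, the scheduler interface changes so `prepare()` and `reset()` receive the seed and latent shape explicitly instead of reading them from config. A stub illustrating the signatures as they appear in this diff (bodies are placeholders, not the real scheduler):

```python
class SchedulerStub:
    def prepare(self, seed, latent_shape, image_encoder_output=None):
        # Allocate initial noise latents of `latent_shape`, seeded with `seed`.
        ...

    def reset(self, seed, latent_shape, previmg_encoder_output=None):
        # Re-seed and re-initialize latents between video segments.
        ...
```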
 import gc
+import json
 import os
 import warnings
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import numpy as np
@@ -337,16 +337,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
         """Initialize consistency model scheduler"""
         self.scheduler = EulerScheduler(self.config)

-    def read_audio_input(self):
+    def read_audio_input(self, audio_path):
         """Read audio input - handles both single and multi-person scenarios"""
         audio_sr = self.config.get("audio_sr", 16000)
         target_fps = self.config.get("target_fps", 16)
         self._audio_processor = AudioProcessor(audio_sr, target_fps)

         # Get audio files from person objects or legacy format
-        audio_files = self._get_audio_files_from_config()
-        if not audio_files:
-            return [], 0
+        audio_files, mask_files = self.get_audio_files_from_audio_path(audio_path)

         # Load audio based on single or multi-person mode
         if len(audio_files) == 1:
@@ -355,8 +353,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             audio_array = self._audio_processor.load_multi_person_audio(audio_files)

-        self.config.audio_num = audio_array.size(0)
-
         video_duration = self.config.get("video_duration", 5)
         audio_len = int(audio_array.shape[1] / audio_sr * target_fps)
         expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
@@ -364,60 +360,35 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # Segment audio
         audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81), self.prev_frame_length)
-        return audio_array.size(0), audio_segments, expected_frames

-    def _get_audio_files_from_config(self):
-        talk_objects = self.config.get("talk_objects")
-        if talk_objects:
-            audio_files = []
-            for idx, person in enumerate(talk_objects):
-                audio_path = person.get("audio")
-                if audio_path and Path(audio_path).is_file():
-                    audio_files.append(str(audio_path))
-                else:
-                    logger.warning(f"Person {idx} audio file {audio_path} does not exist or not specified")
-            if audio_files:
-                logger.info(f"Loaded {len(audio_files)} audio files from talk_objects")
-                return audio_files
-        audio_path = self.config.get("audio_path")
-        if audio_path:
-            return [audio_path]
-        logger.error("config audio_path or talk_objects is not specified")
-        return []
+        # Mask latent for multi-person s2v
+        if mask_files is not None:
+            mask_latents = [self.process_single_mask(mask_file) for mask_file in mask_files]
+            mask_latents = torch.cat(mask_latents, dim=0)
+        else:
+            mask_latents = None

-    def read_person_mask(self):
-        mask_files = self._get_mask_files_from_config()
-        if not mask_files:
-            return None
-        mask_latents = []
-        for mask_file in mask_files:
-            mask_latent = self._process_single_mask(mask_file)
-            mask_latents.append(mask_latent)
-        mask_latents = torch.cat(mask_latents, dim=0)
-        return mask_latents
+        return audio_segments, expected_frames, mask_latents, len(audio_files)

-    def _get_mask_files_from_config(self):
-        talk_objects = self.config.get("talk_objects")
-        if talk_objects:
-            mask_files = []
-            for idx, person in enumerate(talk_objects):
-                mask_path = person.get("mask")
-                if mask_path and Path(mask_path).is_file():
-                    mask_files.append(str(mask_path))
-                elif mask_path:
-                    logger.warning(f"Person {idx} mask file {mask_path} does not exist")
-            if mask_files:
-                logger.info(f"Loaded {len(mask_files)} mask files from talk_objects")
-                return mask_files
-        logger.info("config talk_objects is not specified")
-        return None
+    def get_audio_files_from_audio_path(self, audio_path):
+        if os.path.isdir(audio_path):
+            audio_files = []
+            mask_files = []
+            logger.info(f"audio_path is a directory, loading config.json from {audio_path}")
+            audio_config_path = os.path.join(audio_path, "config.json")
+            assert os.path.exists(audio_config_path), "config.json not found in audio_path"
+            with open(audio_config_path, "r") as f:
+                audio_config = json.load(f)
+            for talk_object in audio_config["talk_objects"]:
+                audio_files.append(os.path.join(audio_path, talk_object["audio"]))
+                mask_files.append(os.path.join(audio_path, talk_object["mask"]))
+        else:
+            logger.info(f"audio_path is a file without mask: {audio_path}")
+            audio_files = [audio_path]
+            mask_files = None
+        return audio_files, mask_files

-    def _process_single_mask(self, mask_file):
+    def process_single_mask(self, mask_file):
         mask_img = Image.open(mask_file).convert("RGB")
         mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
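
Note: `get_audio_files_from_audio_path` expects a directory containing a `config.json` that lists one audio (and mask) file per speaker under `talk_objects`. A sketch of producing such a layout; the file names are illustrative, only the `talk_objects` / `audio` / `mask` keys come from the diff:

```python
import json
import os

audio_path = "./multi_person_inputs"
os.makedirs(audio_path, exist_ok=True)
cfg = {
    "talk_objects": [
        {"audio": "person_0.wav", "mask": "person_0_mask.png"},
        {"audio": "person_1.wav", "mask": "person_1_mask.png"},
    ]
}
# Paths in config.json are resolved relative to audio_path by the runner.
with open(os.path.join(audio_path, "config.json"), "w") as f:
    json.dump(cfg, f, indent=2)
```

Passing a plain audio file instead of a directory selects the single-person path with no masks.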
@@ -456,21 +427,21 @@ class WanAudioRunner(WanRunner):  # type:ignore
             fixed_shape=self.config.get("fixed_shape", None),
         )
         logger.info(f"[wan_audio] resize_image target_h: {h}, target_w: {w}")
-        patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
-        patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
+        patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1]
+        patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2]
         patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
-        self.config.lat_h = patched_h * self.config.patch_size[1]
-        self.config.lat_w = patched_w * self.config.patch_size[2]
-        self.config.tgt_h = self.config.lat_h * self.config.vae_stride[1]
-        self.config.tgt_w = self.config.lat_w * self.config.vae_stride[2]
-        logger.info(f"[wan_audio] tgt_h: {self.config.tgt_h}, tgt_w: {self.config.tgt_w}, lat_h: {self.config.lat_h}, lat_w: {self.config.lat_w}")
-        ref_img = torch.nn.functional.interpolate(ref_img, size=(self.config.tgt_h, self.config.tgt_w), mode="bicubic")
-        return ref_img
+        latent_h = patched_h * self.config["patch_size"][1]
+        latent_w = patched_w * self.config["patch_size"][2]
+        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
+        target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]]
+        logger.info(f"[wan_audio] target_h: {target_shape[0]}, target_w: {target_shape[1]}, latent_h: {latent_h}, latent_w: {latent_w}")
+        ref_img = torch.nn.functional.interpolate(ref_img, size=(target_shape[0], target_shape[1]), mode="bicubic")
+        return ref_img, latent_shape, target_shape

     def run_image_encoder(self, first_frame, last_frame=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
@@ -496,20 +467,17 @@ class WanAudioRunner(WanRunner):  # type:ignore
         return vae_encoder_out

     @ProfilingContext4DebugL2("Run Encoders")
-    def _run_input_encoder_local_r2v_audio(self):
-        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img = self.read_image_input(self.config["image_path"])
+    def _run_input_encoder_local_s2v(self):
+        img, latent_shape, target_shape = self.read_image_input(self.input_info.image_path)
+        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
+        self.input_info.target_shape = target_shape  # Important: set target_shape in input_info
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
         vae_encode_out = self.run_vae_encoder(img)
-        audio_num, audio_segments, expected_frames = self.read_audio_input()
-        person_mask_latens = self.read_person_mask()
-        self.config.person_num = 0
-        if person_mask_latens is not None:
-            assert audio_num == person_mask_latens.size(0), "audio_num and person_mask_latens.size(0) must be the same"
-            self.config.person_num = person_mask_latens.size(0)
-        text_encoder_output = self.run_text_encoder(prompt, None)
+        audio_segments, expected_frames, person_mask_latens, audio_num = self.read_audio_input(self.input_info.audio_path)
+        self.input_info.audio_num = audio_num
+        self.input_info.with_mask = person_mask_latens is not None
+        text_encoder_output = self.run_text_encoder(self.input_info)
         torch.cuda.empty_cache()
         gc.collect()
         return {
@@ -528,13 +496,13 @@ class WanAudioRunner(WanRunner):  # type:ignore
         device = torch.device("cuda")
         dtype = GET_DTYPE()
-        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
-        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
+        tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
+        prev_frames = torch.zeros((1, 3, self.config["target_video_length"], tgt_h, tgt_w), device=device)
         if prev_video is not None:
             # Extract and process last frames
             last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
-            if self.config.model_cls != "wan2.2_audio":
+            if self.config["model_cls"] != "wan2.2_audio":
                 last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
             prev_frames[:, :, :prev_frame_length] = last_frames
             prev_len = (prev_frame_length - 1) // 4 + 1
@@ -546,7 +514,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         _, nframe, height, width = self.model.scheduler.latents.shape

         with ProfilingContext4DebugL1("vae_encoder in init run segment"):
-            if self.config.model_cls == "wan2.2_audio":
+            if self.config["model_cls"] == "wan2.2_audio":
                 if prev_video is not None:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))
                 else:
@@ -563,7 +531,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if prev_latents is not None:
             if prev_latents.shape[-2:] != (height, width):
-                logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
+                logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={tgt_h}, tgt_w={tgt_w}")
                 prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)

         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
@@ -592,8 +560,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         super().init_run()
         self.scheduler.set_audio_adapter(self.audio_adapter)
         self.prev_video = None
-        if self.config.get("return_video", False):
-            self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.config.tgt_h, self.config.tgt_w, 3), dtype=torch.float32, device="cpu")
+        if self.input_info.return_result_tensor:
+            self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.input_info.target_shape[0], self.input_info.target_shape[1], 3), dtype=torch.float32, device="cpu")
             self.cut_audio_final = torch.zeros((self.inputs["expected_frames"] * self._audio_processor.audio_frame_rate), dtype=torch.float32, device="cpu")
         else:
             self.gen_video_final = None
@@ -608,8 +576,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             self.segment = self.inputs["audio_segments"][segment_idx]

-        self.config.seed = self.config.seed + segment_idx
-        torch.manual_seed(self.config.seed)
+        self.input_info.seed = self.input_info.seed + segment_idx
+        torch.manual_seed(self.input_info.seed)
         # logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")

         if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and not hasattr(self, "audio_encoder"):
@@ -627,7 +595,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # Reset scheduler for non-first segments
         if segment_idx > 0:
-            self.model.scheduler.reset(self.inputs["previmg_encoder_output"])
+            self.model.scheduler.reset(self.input_info.seed, self.input_info.latent_shape, self.inputs["previmg_encoder_output"])

     @ProfilingContext4DebugL1("End run segment")
     def end_run_segment(self, segment_idx):
@@ -650,7 +618,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             if self.va_recorder:
                 self.va_recorder.pub_livestream(video_seg, audio_seg)
-            elif self.config.get("return_video", False):
+            elif self.input_info.return_result_tensor:
                 self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
                 self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg)
@@ -669,7 +637,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         return rank, world_size

     def init_va_recorder(self):
-        output_video_path = self.config.get("save_video_path", None)
+        output_video_path = self.input_info.save_result_path
         self.va_recorder = None
         if isinstance(output_video_path, dict):
             output_video_path = output_video_path["data"]
@@ -722,7 +690,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.init_run()
         if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
+            self.model.select_graph_for_compile(self.input_info)
         self.video_segment_num = "unlimited"

         fetch_timeout = self.va_reader.segment_duration + 1
@@ -760,24 +728,20 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.va_recorder = None

     @ProfilingContext4DebugL1("Process after vae decoder")
-    def process_images_after_vae_decoder(self, save_video=False):
-        if self.config.get("return_video", False):
+    def process_images_after_vae_decoder(self):
+        if self.input_info.return_result_tensor:
             audio_waveform = self.cut_audio_final.unsqueeze(0).unsqueeze(0)
             comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
             return {"video": self.gen_video_final, "audio": comfyui_audio}
         return {"video": None, "audio": None}

-    def init_modules(self):
-        super().init_modules()
-        self.run_input_encoder = self._run_input_encoder_local_r2v_audio
-
     def load_transformer(self):
         """Load transformer with LoRA support"""
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+        base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device)
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
-            for lora_config in self.config.lora_configs:
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 lora_name = lora_wrapper.load_lora(lora_path)
@@ -814,7 +778,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter.to(device)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
+        weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=audio_adapter_offload, remove_key="ca", load_from_rank0=load_from_rank0)
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())
@@ -824,28 +788,14 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.audio_encoder = self.load_audio_encoder()
         self.audio_adapter = self.load_audio_adapter()

-    def set_target_shape(self):
-        """Set target shape for generation"""
-        ret = {}
-        num_channels_latents = 16
-        if self.config.model_cls == "wan2.2_audio":
-            num_channels_latents = self.config.num_channels_latents
-
-        if self.config.task == "i2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-            ret["lat_h"] = self.config.lat_h
-            ret["lat_w"] = self.config.lat_w
-        else:
-            error_msg = "t2v task is not supported in WanAudioRunner"
-            assert False, error_msg
-        ret["target_shape"] = self.config.target_shape
-        return ret
+    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            latent_h,
+            latent_w,
+        ]
+        return latent_shape


@RUNNER_REGISTER("wan2.2_audio")
@@ -882,7 +832,7 @@ class Wan22AudioRunner(WanAudioRunner):
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
-        if self.config.task != "i2v":
+        if self.config.task not in ["i2v", "s2v"]:
             return None
         else:
             return Wan2_2_VAE(**vae_config)
...
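
Note: a worked example of the latent-shape arithmetic in `get_latent_shape_with_lat_hw`, under assumed but typical Wan settings (target_video_length=81, vae_stride=(4, 8, 8), 16 latent channels; the 480x832 output size is illustrative):

```python
target_video_length = 81
vae_stride = (4, 8, 8)                              # (temporal, H, W)
latent_h = 480 // vae_stride[1]                     # 60
latent_w = 832 // vae_stride[2]                     # 104
latent_shape = [
    16,                                             # num_channels_latents
    (target_video_length - 1) // vae_stride[0] + 1, # 21 latent frames
    latent_h,
    latent_w,
]
print(latent_shape)  # [16, 21, 60, 104]
```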
@@ -50,7 +50,7 @@ class WanCausVidRunner(WanRunner):
         self.scheduler = WanStepDistillScheduler(self.config)

     def set_target_shape(self):
-        if self.config.task == "i2v":
+        if self.config.task in ["i2v", "s2v"]:
             self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
             # For i2v, frame_seq_length must be reset according to the input shape
             frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
...
...@@ -17,28 +17,28 @@ class WanDistillRunner(WanRunner): ...@@ -17,28 +17,28 @@ class WanDistillRunner(WanRunner):
super().__init__(config) super().__init__(config)
def load_transformer(self): def load_transformer(self):
if self.config.get("lora_configs") and self.config.lora_configs: if self.config.get("lora_configs") and self.config["lora_configs"]:
model = WanModel( model = WanModel(
self.config.model_path, self.config["model_path"],
self.config, self.config,
self.init_device, self.init_device,
) )
lora_wrapper = WanLoraWrapper(model) lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs: for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"] lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0) strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path) lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength) lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}") logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else: else:
model = WanDistillModel(self.config.model_path, self.config, self.init_device) model = WanDistillModel(self.config["model_path"], self.config, self.init_device)
return model return model
def init_scheduler(self): def init_scheduler(self):
if self.config.feature_caching == "NoCaching": if self.config["feature_caching"] == "NoCaching":
self.scheduler = WanStepDistillScheduler(self.config) self.scheduler = WanStepDistillScheduler(self.config)
else: else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}") raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
class MultiDistillModelStruct(MultiModelStruct): class MultiDistillModelStruct(MultiModelStruct):
...@@ -54,7 +54,7 @@ class MultiDistillModelStruct(MultiModelStruct): ...@@ -54,7 +54,7 @@ class MultiDistillModelStruct(MultiModelStruct):
def get_current_model_index(self): def get_current_model_index(self):
if self.scheduler.step_index < self.boundary_step_index: if self.scheduler.step_index < self.boundary_step_index:
logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}") logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0] self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model": if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1: if self.cur_model_index == -1:
self.to_cuda(model_index=0) self.to_cuda(model_index=0)
...@@ -64,7 +64,7 @@ class MultiDistillModelStruct(MultiModelStruct): ...@@ -64,7 +64,7 @@ class MultiDistillModelStruct(MultiModelStruct):
self.cur_model_index = 0 self.cur_model_index = 0
else: else:
logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}") logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1] self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model": if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1: if self.cur_model_index == -1:
self.to_cuda(model_index=1) self.to_cuda(model_index=1)
...@@ -81,8 +81,8 @@ class Wan22MoeDistillRunner(WanDistillRunner): ...@@ -81,8 +81,8 @@ class Wan22MoeDistillRunner(WanDistillRunner):
def load_transformer(self): def load_transformer(self):
use_high_lora, use_low_lora = False, False use_high_lora, use_low_lora = False, False
if self.config.get("lora_configs") and self.config.lora_configs: if self.config.get("lora_configs") and self.config["lora_configs"]:
for lora_config in self.config.lora_configs: for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model": if lora_config.get("name", "") == "high_noise_model":
use_high_lora = True use_high_lora = True
elif lora_config.get("name", "") == "low_noise_model": elif lora_config.get("name", "") == "low_noise_model":
...@@ -90,12 +90,12 @@ class Wan22MoeDistillRunner(WanDistillRunner): ...@@ -90,12 +90,12 @@ class Wan22MoeDistillRunner(WanDistillRunner):
if use_high_lora: if use_high_lora:
high_noise_model = WanModel( high_noise_model = WanModel(
os.path.join(self.config.model_path, "high_noise_model"), os.path.join(self.config["model_path"], "high_noise_model"),
self.config, self.config,
self.init_device, self.init_device,
) )
high_lora_wrapper = WanLoraWrapper(high_noise_model) high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config.lora_configs: for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model": if lora_config.get("name", "") == "high_noise_model":
lora_path = lora_config["path"] lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0) strength = lora_config.get("strength", 1.0)
...@@ -104,19 +104,19 @@ class Wan22MoeDistillRunner(WanDistillRunner): ...@@ -104,19 +104,19 @@ class Wan22MoeDistillRunner(WanDistillRunner):
logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}") logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
else: else:
high_noise_model = Wan22MoeDistillModel( high_noise_model = Wan22MoeDistillModel(
os.path.join(self.config.model_path, "distill_models", "high_noise_model"), os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
self.config, self.config,
self.init_device, self.init_device,
) )
if use_low_lora: if use_low_lora:
low_noise_model = WanModel( low_noise_model = WanModel(
os.path.join(self.config.model_path, "low_noise_model"), os.path.join(self.config["model_path"], "low_noise_model"),
self.config, self.config,
self.init_device, self.init_device,
) )
low_lora_wrapper = WanLoraWrapper(low_noise_model) low_lora_wrapper = WanLoraWrapper(low_noise_model)
for lora_config in self.config.lora_configs: for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "low_noise_model": if lora_config.get("name", "") == "low_noise_model":
lora_path = lora_config["path"] lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0) strength = lora_config.get("strength", 1.0)
@@ -125,15 +125,15 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                     logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             low_noise_model = Wan22MoeDistillModel(
-                os.path.join(self.config.model_path, "distill_models", "low_noise_model"),
+                os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
                 self.config,
                 self.init_device,
             )
-        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary_step_index)
+        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             self.scheduler = Wan22StepDistillScheduler(self.config)
         else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
+            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@@ -28,7 +28,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
-from lightx2v.utils.utils import best_output_size, cache_video
+from lightx2v.utils.utils import best_output_size

 @RUNNER_REGISTER("wan2.1")
@@ -42,7 +42,7 @@ class WanRunner(DefaultRunner):
     def load_transformer(self):
         model = WanModel(
-            self.config.model_path,
+            self.config["model_path"],
             self.config,
             self.init_device,
         )
@@ -59,7 +59,7 @@ class WanRunner(DefaultRunner):
     def load_image_encoder(self):
         image_encoder = None
-        if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+        if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             # offload config
             clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
             if clip_offload:
@@ -148,13 +148,13 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
             "device": vae_device,
-            "parallel": self.config.parallel,
+            "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
             "load_from_rank0": self.config.get("load_from_rank0", False),
         }
-        if self.config.task not in ["i2v", "flf2v", "animate", "vace"]:
+        if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
             return None
         else:
             return self.vae_cls(**vae_config)
@@ -170,7 +170,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", self.vae_name),
             "device": vae_device,
-            "parallel": self.config.parallel,
+            "parallel": self.config["parallel"],
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
@@ -192,9 +192,9 @@ class WanRunner(DefaultRunner):
         return vae_encoder, vae_decoder

     def init_scheduler(self):
-        if self.config.feature_caching == "NoCaching":
+        if self.config["feature_caching"] == "NoCaching":
             scheduler_class = WanScheduler
-        elif self.config.feature_caching == "TaylorSeer":
+        elif self.config["feature_caching"] == "TaylorSeer":
             scheduler_class = WanSchedulerTaylorCaching
         elif self.config.feature_caching in ["Tea", "Ada", "Custom", "FirstBlock", "DualBlock", "DynamicBlock", "Mag"]:
             scheduler_class = WanSchedulerCaching
@@ -206,26 +206,28 @@ class WanRunner(DefaultRunner):
         else:
             self.scheduler = scheduler_class(self.config)

-    def run_text_encoder(self, text, img=None):
+    def run_text_encoder(self, input_info):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.text_encoders = self.load_text_encoder()
-        n_prompt = self.config.get("negative_prompt", "")
+        prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
+        neg_prompt = input_info.negative_prompt
         if self.config["cfg_parallel"]:
             cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
             cfg_p_rank = dist.get_rank(cfg_p_group)
             if cfg_p_rank == 0:
-                context = self.text_encoders[0].infer([text])
+                context = self.text_encoders[0].infer([prompt])
                 context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
                 text_encoder_output = {"context": context}
             else:
-                context_null = self.text_encoders[0].infer([n_prompt])
+                context_null = self.text_encoders[0].infer([neg_prompt])
                 context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
                 text_encoder_output = {"context_null": context_null}
         else:
-            context = self.text_encoders[0].infer([text])
+            context = self.text_encoders[0].infer([prompt])
             context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
-            context_null = self.text_encoders[0].infer([n_prompt])
+            context_null = self.text_encoders[0].infer([neg_prompt])
             context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
             text_encoder_output = {
                 "context": context,
@@ -255,22 +257,22 @@ class WanRunner(DefaultRunner):
     def run_vae_encoder(self, first_frame, last_frame=None):
         h, w = first_frame.shape[2:]
         aspect_ratio = h / w
-        max_area = self.config.target_height * self.config.target_width
-        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
-        lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
+        max_area = self.config["target_height"] * self.config["target_width"]
+        latent_h = round(np.sqrt(max_area * aspect_ratio) // self.config["vae_stride"][1] // self.config["patch_size"][1] * self.config["patch_size"][1])
+        latent_w = round(np.sqrt(max_area / aspect_ratio) // self.config["vae_stride"][2] // self.config["patch_size"][2] * self.config["patch_size"][2])
+        latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)  # Important: latent_shape is used to set the input_info
         if self.config.get("changing_resolution", False):
             assert last_frame is None
-            self.config.lat_h, self.config.lat_w = lat_h, lat_w
             vae_encode_out_list = []
             for i in range(len(self.config["resolution_rate"])):
-                lat_h, lat_w = (
-                    int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
-                    int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
+                latent_h_tmp, latent_w_tmp = (
+                    int(latent_h * self.config["resolution_rate"][i]) // 2 * 2,
+                    int(latent_w * self.config["resolution_rate"][i]) // 2 * 2,
                 )
-                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, lat_h, lat_w))
-            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, self.config.lat_h, self.config.lat_w))
-            return vae_encode_out_list
+                vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h_tmp, latent_w_tmp))
+            vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, latent_h, latent_w))
+            return vae_encode_out_list, latent_shape
         else:
             if last_frame is not None:
                 first_frame_size = first_frame.shape[2:]
@@ -282,16 +284,15 @@ class WanRunner(DefaultRunner):
                     round(last_frame_size[1] * last_frame_resize_ratio),
                 ]
                 last_frame = TF.center_crop(last_frame, last_frame_size)
-            self.config.lat_h, self.config.lat_w = lat_h, lat_w
-            vae_encoder_out = self.get_vae_encoder_output(first_frame, lat_h, lat_w, last_frame)
-            return vae_encoder_out
+            vae_encoder_out = self.get_vae_encoder_output(first_frame, latent_h, latent_w, last_frame)
+            return vae_encoder_out, latent_shape

     def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
-        h = lat_h * self.config.vae_stride[1]
-        w = lat_w * self.config.vae_stride[2]
+        h = lat_h * self.config["vae_stride"][1]
+        w = lat_w * self.config["vae_stride"][2]
         msk = torch.ones(
             1,
-            self.config.target_video_length,
+            self.config["target_video_length"],
             lat_h,
             lat_w,
             device=torch.device("cuda"),
@@ -312,7 +313,7 @@ class WanRunner(DefaultRunner):
             vae_input = torch.concat(
                 [
                     torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                    torch.zeros(3, self.config.target_video_length - 2, h, w),
+                    torch.zeros(3, self.config["target_video_length"] - 2, h, w),
                     torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                 ],
                 dim=1,
@@ -321,7 +322,7 @@ class WanRunner(DefaultRunner):
             vae_input = torch.concat(
                 [
                     torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
-                    torch.zeros(3, self.config.target_video_length - 1, h, w),
+                    torch.zeros(3, self.config["target_video_length"] - 1, h, w),
                 ],
                 dim=1,
             ).cuda()
@@ -345,32 +346,23 @@ class WanRunner(DefaultRunner):
             "image_encoder_output": image_encoder_output,
         }

-    def set_target_shape(self):
-        num_channels_latents = self.config.get("num_channels_latents", 16)
-        if self.config.task in ["i2v", "flf2v", "animate"]:
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                self.config.lat_h,
-                self.config.lat_w,
-            )
-        elif self.config.task == "t2v":
-            self.config.target_shape = (
-                num_channels_latents,
-                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
-                int(self.config.target_height) // self.config.vae_stride[1],
-                int(self.config.target_width) // self.config.vae_stride[2],
-            )
-
-    def save_video_func(self, images):
-        cache_video(
-            tensor=images,
-            save_file=self.config.save_video_path,
-            fps=self.config.get("fps", 16),
-            nrow=1,
-            normalize=True,
-            value_range=(-1, 1),
-        )
+    def get_latent_shape_with_lat_hw(self, latent_h, latent_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            latent_h,
+            latent_w,
+        ]
+        return latent_shape
+
+    def get_latent_shape_with_target_hw(self, target_h, target_w):
+        latent_shape = [
+            self.config.get("num_channels_latents", 16),
+            (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
+            int(target_h) // self.config["vae_stride"][1],
+            int(target_w) // self.config["vae_stride"][2],
+        ]
+        return latent_shape

 class MultiModelStruct:
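A worked example of the latent-shape arithmetic in `get_latent_shape_with_target_hw`, using assumed-but-typical Wan settings (16 latent channels, `vae_stride=(4, 8, 8)`, 81 frames at 480x832; these values are illustrative, not taken from this commit):

```python
# Sketch: how the latent shape falls out of the config values.
target_video_length = 81
vae_stride = (4, 8, 8)
target_h, target_w = 480, 832

latent_frames = (target_video_length - 1) // vae_stride[0] + 1  # (81-1)//4 + 1 = 21
latent_h = target_h // vae_stride[1]                            # 480//8 = 60
latent_w = target_w // vae_stride[2]                            # 832//8 = 104
latent_shape = [16, latent_frames, latent_h, latent_w]          # [16, 21, 60, 104]
```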
@@ -400,7 +392,7 @@ class MultiModelStruct:
     def get_current_model_index(self):
         if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][0]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=0)
@@ -410,7 +402,7 @@ class MultiModelStruct:
             self.cur_model_index = 0
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
-            self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
+            self.scheduler.sample_guide_scale = self.config["sample_guide_scale"][1]
             if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
                 if self.cur_model_index == -1:
                     self.to_cuda(model_index=1)
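A simplified sketch of the expert-switching rule in `get_current_model_index` (an assumed reduction of the logic above, not code from the repo): timesteps run from high noise to low, and the high-noise expert is used while the current timestep is still at or above the boundary.

```python
# Boundary-based expert selection, stripped of offload/caching details.
def pick_expert(timestep: float, boundary_timestep: float) -> str:
    return "high_noise_model" if timestep >= boundary_timestep else "low_noise_model"

timesteps = [1000, 900, 700, 500, 300, 100]  # hypothetical denoising schedule
print([pick_expert(t, boundary_timestep=875) for t in timesteps])
# ['high_noise_model', 'high_noise_model', 'low_noise_model', 'low_noise_model', ...]
```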
@@ -434,20 +426,20 @@ class Wan22MoeRunner(WanRunner):
     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
         high_noise_model = WanModel(
-            os.path.join(self.config.model_path, "high_noise_model"),
+            os.path.join(self.config["model_path"], "high_noise_model"),
             self.config,
             self.init_device,
         )
         low_noise_model = WanModel(
-            os.path.join(self.config.model_path, "low_noise_model"),
+            os.path.join(self.config["model_path"], "low_noise_model"),
             self.config,
             self.init_device,
         )
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
-            for lora_config in self.config.lora_configs:
+        if self.config.get("lora_configs") and self.config["lora_configs"]:
+            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
+            for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
                 strength = lora_config.get("strength", 1.0)
                 base_name = os.path.basename(lora_path)
@@ -464,7 +456,7 @@ class Wan22MoeRunner(WanRunner):
             else:
                 raise ValueError(f"Unsupported LoRA path: {lora_path}")

-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
+        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])

 @RUNNER_REGISTER("wan2.2")
...
@@ -9,11 +9,10 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
 from lightx2v.utils.envs import *
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

-torch.manual_seed(42)

 @RUNNER_REGISTER("wan2.1_sf")
 class WanSFRunner(WanRunner):
@@ -59,40 +58,37 @@ class WanSFRunner(WanRunner):
         gc.collect()
         return images

-    @ProfilingContext4DebugL2("Run DiT")
-    def run_main(self, total_steps=None):
-        self.init_run()
-        if self.config.get("compile", False):
-            self.model.select_graph_for_compile()
-        total_blocks = self.scheduler.num_blocks
-        gen_videos = []
-        for seg_index in range(self.video_segment_num):
-            logger.info(f"==> segment_index: {seg_index + 1} / {total_blocks}")
-            total_steps = len(self.scheduler.denoising_step_list)
-            for step_index in range(total_steps):
-                logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
-                with ProfilingContext4DebugL1("step_pre"):
-                    self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=False)
-                with ProfilingContext4DebugL1("🚀 infer_main"):
-                    self.model.infer(self.inputs)
-                with ProfilingContext4DebugL1("step_post"):
-                    self.model.scheduler.step_post()
-            latents = self.model.scheduler.stream_output
-            gen_videos.append(self.run_vae_decoder(latents))
-            # rerun with timestep zero to update KV cache using clean context
-            with ProfilingContext4DebugL1("step_pre_in_rerun"):
-                self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=True)
-            with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
-                self.model.infer(self.inputs)
-        self.gen_video = torch.cat(gen_videos, dim=0)
-        self.end_run()
+    def init_run(self):
+        super().init_run()
+
+    @ProfilingContext4DebugL1("End run segment")
+    def end_run_segment(self, segment_idx=None):
+        with ProfilingContext4DebugL1("step_pre_in_rerun"):
+            self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
+        with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
+            self.model.infer(self.inputs)
+        self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
+
+    @peak_memory_decorator
+    def run_segment(self, total_steps=None):
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
+        for step_index in range(total_steps):
+            # only for single segment, check stop signal every step
+            if self.video_segment_num == 1:
+                self.check_stop()
+            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
+            with ProfilingContext4DebugL1("step_pre"):
+                self.model.scheduler.step_pre(seg_index=self.segment_idx, step_index=step_index, is_rerun=False)
+            with ProfilingContext4DebugL1("🚀 infer_main"):
+                self.model.infer(self.inputs)
+            with ProfilingContext4DebugL1("step_post"):
+                self.model.scheduler.step_post()
+            if self.progress_callback:
+                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
+        return self.model.scheduler.stream_output
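The monolithic `run_main` is split into per-segment pieces: `run_segment` runs the denoising loop and returns the stream output, and `end_run_segment` reruns the final step with `is_rerun=True` (refreshing the KV cache from clean context, per the old comment) and appends the decoded segment to `gen_video_final`. A runnable sketch of just the accumulation pattern, with assumed tensor sizes (the real driver loop lives in the base runner, which is not shown in this diff):

```python
import torch

# Mirrors the gen_video_final update in end_run_segment: start from None,
# then concatenate each segment's frames along dim 0.
gen_video_final = None
for _ in range(3):                     # pretend 3 segments of 8 frames each
    gen_video = torch.rand(8, 3, 64, 64)
    gen_video_final = torch.cat([gen_video_final, gen_video], dim=0) if gen_video_final is not None else gen_video
print(gen_video_final.shape)           # torch.Size([24, 3, 64, 64])
```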
@@ -154,10 +154,10 @@ class WanVaceRunner(WanRunner):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]

-    def set_target_shape(self):
-        target_shape = self.latent_shape
-        target_shape[0] = int(target_shape[0] / 2)
-        self.config.target_shape = target_shape
+    def set_input_info_latent_shape(self):
+        latent_shape = self.latent_shape
+        latent_shape[0] = int(latent_shape[0] / 2)
+        return latent_shape

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
...
@@ -6,8 +6,8 @@ class BaseScheduler:
         self.config = config
         self.latents = None
         self.step_index = 0
-        self.infer_steps = config.infer_steps
-        self.caching_records = [True] * config.infer_steps
+        self.infer_steps = config["infer_steps"]
+        self.caching_records = [True] * config["infer_steps"]
         self.flag_df = False
         self.transformer_infer = None
         self.infer_condition = True  # cfg status
...
@@ -13,8 +13,8 @@ class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        if self.config.parallel:
-            self.sp_size = self.config.parallel.get("seq_p_size", 1)
+        if self.config["parallel"]:
+            self.sp_size = self.config["parallel"].get("seq_p_size", 1)
         else:
             self.sp_size = 1
@@ -33,11 +33,11 @@ class EulerScheduler(WanScheduler):
         if self.audio_adapter.cpu_offload:
             self.audio_adapter.time_embedding.to("cpu")

-        if self.config.model_cls == "wan2.2_audio":
+        if self.config["model_cls"] == "wan2.2_audio":
             _, lat_f, lat_h, lat_w = self.latents.shape
-            F = (lat_f - 1) * self.config.vae_stride[0] + 1
-            per_latent_token_len = lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * per_latent_token_len
+            F = (lat_f - 1) * self.config["vae_stride"][0] + 1
+            per_latent_token_len = lat_h * lat_w // (self.config["patch_size"][1] * self.config["patch_size"][2])
+            max_seq_len = ((F - 1) // self.config["vae_stride"][0] + 1) * per_latent_token_len
             max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
             temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
@@ -55,13 +55,13 @@ class EulerScheduler(WanScheduler):
                 dim=1,
             )

-    def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            target_shape[2],
-            target_shape[3],
+            latent_shape[0],
+            latent_shape[1],
+            latent_shape[2],
+            latent_shape[3],
             dtype=dtype,
             device=self.device,
             generator=self.generator,
@@ -71,8 +71,8 @@ class EulerScheduler(WanScheduler):
         if self.prev_latents is not None:
             self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

-    def prepare(self, previmg_encoder_output=None):
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
         timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
         self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
@@ -93,11 +93,11 @@ class EulerScheduler(WanScheduler):
         if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
             self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

-    def reset(self, previmg_encoder_output=None):
+    def reset(self, seed, latent_shape, image_encoder_output=None):
         if self.config["model_cls"] == "wan2.2_audio":
-            self.prev_latents = previmg_encoder_output["prev_latents"]
-            self.prev_len = previmg_encoder_output["prev_len"]
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+            self.prev_latents = image_encoder_output["prev_latents"]
+            self.prev_len = image_encoder_output["prev_len"]
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

     def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
         if in_tensor.ndim > tgt_n_dim:
...
@@ -19,16 +19,16 @@ class WanScheduler4ChangingResolution:
             config["changing_resolution_steps"] = [config.infer_steps // 2]
         assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])

-    def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents_list = []
         for i in range(len(self.config["resolution_rate"])):
             self.latents_list.append(
                 torch.randn(
-                    target_shape[0],
-                    target_shape[1],
-                    int(target_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
-                    int(target_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
+                    latent_shape[0],
+                    latent_shape[1],
+                    int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
+                    int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
                     dtype=dtype,
                     device=self.device,
                     generator=self.generator,
@@ -38,10 +38,10 @@ class WanScheduler4ChangingResolution:
         # add original resolution latents
         self.latents_list.append(
             torch.randn(
-                target_shape[0],
-                target_shape[1],
-                target_shape[2],
-                target_shape[3],
+                latent_shape[0],
+                latent_shape[1],
+                latent_shape[2],
+                latent_shape[3],
                 dtype=dtype,
                 device=self.device,
                 generator=self.generator,
...
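A quick check of the changing-resolution sizing above: each extra latent is the base latent's spatial size scaled by `resolution_rate[i]`, then rounded down to an even number via `int(x * rate) // 2 * 2`. The values below are assumed for illustration:

```python
# Even-rounded scaling used for the intermediate-resolution latents.
latent_h, latent_w = 60, 104
rate = 0.75
h = int(latent_h * rate) // 2 * 2  # int(45.0) = 45 -> 44 (forced even)
w = int(latent_w * rate) // 2 * 2  # int(78.0) = 78 -> 78 (already even)
print(h, w)                        # 44 78
```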
-import gc
 from typing import List, Optional, Union

 import numpy as np
@@ -12,22 +11,22 @@ class WanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
         self.device = torch.device("cuda")
-        self.infer_steps = self.config.infer_steps
-        self.target_video_length = self.config.target_video_length
-        self.sample_shift = self.config.sample_shift
+        self.infer_steps = self.config["infer_steps"]
+        self.target_video_length = self.config["target_video_length"]
+        self.sample_shift = self.config["sample_shift"]
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []
         self.solver_order = 2
         self.noise_pred = None
-        self.sample_guide_scale = self.config.sample_guide_scale
-        self.caching_records_2 = [True] * self.config.infer_steps
+        self.sample_guide_scale = self.config["sample_guide_scale"]
+        self.caching_records_2 = [True] * self.config["infer_steps"]

-    def prepare(self, image_encoder_output=None):
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
@@ -48,18 +47,18 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

-    def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            target_shape[2],
-            target_shape[3],
+            latent_shape[0],
+            latent_shape[1],
+            latent_shape[2],
+            latent_shape[3],
             dtype=dtype,
             device=self.device,
             generator=self.generator,
         )
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.mask = masks_like(self.latents, zero=True)
             self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
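Threading the seed through `prepare_latents` (instead of reading `self.config.seed`, or the module-level `torch.manual_seed(42)` removed earlier) makes latent initialization deterministic per call. A self-contained sketch of the idea (illustrative; `make_latents` is a hypothetical helper, and the device assumes CUDA is available):

```python
import torch

def make_latents(seed, latent_shape, device="cuda"):  # use "cpu" without CUDA
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(*latent_shape, device=device, generator=gen)

# Same seed and shape -> identical noise, independent of global RNG state.
a = make_latents(42, [16, 21, 60, 104])
b = make_latents(42, [16, 21, 60, 104])
assert torch.equal(a, b)
```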
@@ -117,7 +116,7 @@ class WanScheduler(BaseScheduler):
         x0_pred = sample - sigma_t * model_output
         return x0_pred

-    def reset(self, step_index=None):
+    def reset(self, seed, latent_shape, step_index=None):
         if step_index is not None:
             self.step_index = step_index
         self.model_outputs = [None] * self.solver_order
@@ -126,9 +125,7 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.this_order = None
         self.lower_order_nums = 0
-        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
-        gc.collect()
-        torch.cuda.empty_cache()
+        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

     def multistep_uni_p_bh_update(
         self,
@@ -325,7 +322,7 @@ class WanScheduler(BaseScheduler):
     def step_pre(self, step_index):
         super().step_pre(step_index)
         self.timestep_input = torch.stack([self.timesteps[self.step_index]])
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()

     def step_post(self):
@@ -367,5 +364,5 @@ class WanScheduler(BaseScheduler):
             self.lower_order_nums += 1

         self.latents = prev_sample
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
             self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
@@ -9,24 +9,25 @@ class WanSFScheduler(WanScheduler):
         super().__init__(config)
         self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        self.num_frame_per_block = self.config.sf_config.num_frame_per_block
-        self.num_output_frames = self.config.sf_config.num_output_frames
+        self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
+        self.num_output_frames = self.config["sf_config"]["num_output_frames"]
         self.num_blocks = self.num_output_frames // self.num_frame_per_block
-        self.denoising_step_list = self.config.sf_config.denoising_step_list
+        self.denoising_step_list = self.config["sf_config"]["denoising_step_list"]
+        self.infer_steps = len(self.denoising_step_list)
         self.all_num_frames = [self.num_frame_per_block] * self.num_blocks
         self.num_input_frames = 0
         self.denoising_strength = 1.0
         self.sigma_max = 1.0
         self.sigma_min = 0
-        self.sf_shift = self.config.sf_config.shift
+        self.sf_shift = self.config["sf_config"]["shift"]
         self.inverse_timesteps = False
         self.extra_one_step = True
         self.reverse_sigmas = False
-        self.num_inference_steps = self.config.sf_config.num_inference_steps
+        self.num_inference_steps = self.config["sf_config"]["num_inference_steps"]
         self.context_noise = 0

-    def prepare(self, image_encoder_output=None):
-        self.latents = torch.randn(self.config.target_shape, device=self.device, dtype=self.dtype)
+    def prepare(self, seed, latent_shape, image_encoder_output=None):
+        self.latents = torch.randn(latent_shape, device=self.device, dtype=self.dtype)

         timesteps = []
         for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
@@ -39,7 +40,7 @@ class WanSFScheduler(WanScheduler):
             timesteps.append(frame_steps)
         self.timesteps = timesteps

-        self.noise_pred = torch.zeros(self.config.target_shape, device=self.device, dtype=self.dtype)
+        self.noise_pred = torch.zeros(latent_shape, device=self.device, dtype=self.dtype)

         sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
         if self.extra_one_step:
@@ -91,7 +92,7 @@ class WanSFScheduler(WanScheduler):
         x0_pred = x0_pred.to(original_dtype)

         # add noise
-        if self.step_index < len(self.denoising_step_list) - 1:
+        if self.step_index < self.infer_steps - 1:
             timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
             timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
             sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
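The new invariant in `WanSFScheduler` is that `infer_steps` is derived from the denoising step list, so the last-step check `step_index < self.infer_steps - 1` is equivalent to the old `step_index < len(self.denoising_step_list) - 1`. A small sketch with an assumed schedule:

```python
# Hypothetical 4-step self-forcing schedule; only the last step skips re-noising.
denoising_step_list = [1000, 750, 500, 250]
infer_steps = len(denoising_step_list)  # 4
for step_index in range(infer_steps):
    if step_index < infer_steps - 1:
        print(step_index, "adds noise for the next step")
    else:
        print(step_index, "last step: no re-noising")
```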