Commit 9e3680b7 authored by helloyongyang

fix ci

parent 7367d6c8
@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner, Wan22AudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
@@ -39,7 +39,20 @@ def main():
         "--model_cls",
         type=str,
         required=True,
-        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2", "wan2.2_moe_distill"],
+        choices=[
+            "wan2.1",
+            "hunyuan",
+            "wan2.1_distill",
+            "wan2.1_causvid",
+            "wan2.1_skyreels_v2_df",
+            "cogvideox",
+            "wan2.1_audio",
+            "wan2.2_moe",
+            "wan2.2",
+            "wan2.2_moe_audio",
+            "wan2.2_audio",
+            "wan2.2_moe_distill",
+        ],
         default="wan2.1",
     )
...
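
For context on the reformatted `--model_cls` argument, the sketch below is a minimal, self-contained argparse example that mirrors the choices list from this diff so the accepted values are easy to try in isolation. The parser itself is illustrative only and is not the project's actual entry point.

```python
# Minimal sketch mirroring the reformatted --model_cls choices (illustrative
# parser only; not the project's real CLI entry point).
import argparse

parser = argparse.ArgumentParser(description="sketch of the --model_cls choices")
parser.add_argument(
    "--model_cls",
    type=str,
    required=True,
    choices=[
        "wan2.1",
        "hunyuan",
        "wan2.1_distill",
        "wan2.1_causvid",
        "wan2.1_skyreels_v2_df",
        "cogvideox",
        "wan2.1_audio",
        "wan2.2_moe",
        "wan2.2",
        "wan2.2_moe_audio",
        "wan2.2_audio",
        "wan2.2_moe_distill",
    ],
    default="wan2.1",
)

# Any value from the list parses cleanly:
print(parser.parse_args(["--model_cls", "wan2.2_audio"]).model_cls)  # -> wan2.2_audio
# Anything else makes argparse print an error and exit:
# parser.parse_args(["--model_cls", "wan3.0"])  # SystemExit
```
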
@@ -4,8 +4,8 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 from ..module_io import WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
-from loguru import logger
+from ..utils import masks_like, rope_params, sinusoidal_embedding_1d

 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
@@ -29,12 +29,11 @@ class WanAudioPreInfer(WanPreInfer):
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         if self.config.model_cls == "wan2.2_audio":
             hidden_states = self.scheduler.latents
             mask1, mask2 = masks_like([hidden_states], zero=True, prev_length=hidden_states.shape[1])
-            hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * hidden_states
+            hidden_states = (1.0 - mask2[0]) * prev_latents + mask2[0] * hidden_states
         else:
             prev_latents = prev_latents.unsqueeze(0)
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
@@ -53,7 +52,7 @@ class WanAudioPreInfer(WanPreInfer):
                 "timestep": t,
             }
             audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
-            audio_dit_blocks = None##Debug Drop Audio
+            audio_dit_blocks = None  ##Debug Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]
@@ -66,7 +65,7 @@ class WanAudioPreInfer(WanPreInfer):
         batch_size = len(x)
         num_channels, _, height, width = x[0].shape
         _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
                 (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
...
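
For readers skimming the `wan2.2_audio` branch above: the blend `(1.0 - mask2[0]) * prev_latents + mask2[0] * hidden_states` keeps previous-segment latents where the mask is zero and the freshly sampled latents elsewhere. Below is a standalone sketch; the tensor shapes are invented for illustration and only the blend formula comes from the diff.

```python
# Standalone illustration of the masked blend from the wan2.2_audio branch.
# Shapes are invented for the example; only the blend formula mirrors the diff.
import torch

channels, frames, height, width = 16, 4, 8, 8
prev_latents = torch.randn(channels, frames, height, width)
hidden_states = torch.randn(channels, frames, height, width)

# mask == 0 keeps the previous latents, mask == 1 keeps the new latents.
mask = torch.ones(channels, frames, height, width)
mask[:, :1] = 0.0  # e.g. the first frame is carried over from prev_latents

blended = (1.0 - mask) * prev_latents + mask * hidden_states
assert torch.equal(blended[:, :1], prev_latents[:, :1])
assert torch.equal(blended[:, 1:], hidden_states[:, 1:])
```
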
@@ -15,15 +15,9 @@ def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
     if zero:
         if generator is not None:
             for u, v in zip(out1, out2):
-                random_num = torch.rand(
-                    1, generator=generator, device=generator.device).item()
+                random_num = torch.rand(1, generator=generator, device=generator.device).item()
                 if random_num < p:
-                    u[:, :prev_length] = torch.normal(
-                        mean=-3.5,
-                        std=0.5,
-                        size=(1,),
-                        device=u.device,
-                        generator=generator).expand_as(u[:, :prev_length]).exp()
+                    u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
                     v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
                 else:
                     u[:, :prev_length] = u[:, :prev_length]
...
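
The `masks_like` change above is purely a line-wrapping cleanup; the behaviour should be unchanged. Here is a self-contained sketch of what the zero branch does, using placeholder tensor shapes and the same constants as the diff.

```python
# Sketch of the zero branch of masks_like after the reformat (placeholder
# shapes; the constants mean=-3.5, std=0.5 and the structure mirror the diff).
import torch

p, prev_length = 0.2, 1
u = torch.ones(4, 8)  # stands in for one entry of out1
v = torch.ones(4, 8)  # stands in for one entry of out2
generator = torch.Generator(device="cpu").manual_seed(0)

random_num = torch.rand(1, generator=generator, device=generator.device).item()
if random_num < p:
    # Draw one sample, broadcast it over the first prev_length frames, and
    # exponentiate, giving a small positive fill value; zero the mask prefix.
    u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
```
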
@@ -21,11 +21,12 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, find_torch_model_path
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
+from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image

 @contextmanager
 def memory_efficient_inference():
@@ -322,7 +323,7 @@ class VideoGenerator:
             if segment_idx == 0:
                 # First segment - create zero frames
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                if self.config.model_cls == 'wan2.2_audio':
+                if self.config.model_cls == "wan2.2_audio":
                     prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
                 else:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
@@ -337,7 +338,7 @@ class VideoGenerator:
             else:
                 # Fallback to zeros if prepare_prev_latents fails
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                if self.config.model_cls == 'wan2.2_audio':
+                if self.config.model_cls == "wan2.2_audio":
                     prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
                 else:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
@@ -695,7 +696,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         num_channels_latents = 16
         if self.config.model_cls == "wan2.2_audio":
             num_channels_latents = self.config.num_channels_latents
         if self.config.task == "i2v":
             self.config.target_shape = (
                 num_channels_latents,
@@ -813,6 +814,7 @@ class Wan22AudioRunner(WanAudioRunner):
         vae_decoder = self.load_vae_decoder()
         return vae_encoder, vae_decoder

+
 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
     def __init__(self, config):
...
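
One detail worth calling out in the hunks above: the runner branches on `model_cls` because the wan2.2 audio VAE's `encode` appears to hand back the latent tensor directly, while the other path returns a sequence whose first element is the latent. The hypothetical helper below (names are ours, not the project's) captures the same branching in one place.

```python
# Hypothetical helper (not part of the project) summarising the branch used in
# VideoGenerator: one encoder returns the latent tensor, the other a sequence.
def encode_prev_latents(vae_encoder, prev_frames, config, dtype):
    encoded = vae_encoder.encode(prev_frames, config)
    if config.model_cls == "wan2.2_audio":
        # wan2.2_audio path: encode() already yields the latent tensor.
        return encoded.to(dtype)
    # Other paths: encode() yields a sequence; its first element is the latent.
    return encoded[0].to(dtype)
```
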
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from einops import rearrange

 from lightx2v.utils.utils import load_weights
-from loguru import logger

 __all__ = [
     "Wan2_2_VAE",
...