Commit 9e3680b7 authored by helloyongyang

fix ci

parent 7367d6c8
@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner, Wan22AudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
@@ -39,7 +39,20 @@ def main():
         "--model_cls",
         type=str,
         required=True,
-        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2", "wan2.2_moe_distill"],
+        choices=[
+            "wan2.1",
+            "hunyuan",
+            "wan2.1_distill",
+            "wan2.1_causvid",
+            "wan2.1_skyreels_v2_df",
+            "cogvideox",
+            "wan2.1_audio",
+            "wan2.2_moe",
+            "wan2.2",
+            "wan2.2_moe_audio",
+            "wan2.2_audio",
+            "wan2.2_moe_distill",
+        ],
         default="wan2.1",
     )
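The reformatted `choices` list is behavior-neutral; for reference, a minimal, self-contained sketch (with an abbreviated list) of how argparse validates `--model_cls`:

```python
import argparse

# Minimal sketch of the --model_cls validation above (choices abbreviated).
# argparse rejects any value outside `choices` before a runner is resolved.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_cls",
    type=str,
    required=True,
    choices=["wan2.1", "wan2.2", "wan2.2_audio", "wan2.2_moe_audio"],
    default="wan2.1",
)

args = parser.parse_args(["--model_cls", "wan2.2_audio"])
print(args.model_cls)  # wan2.2_audio
```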
@@ -4,8 +4,8 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 from ..module_io import WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
 from loguru import logger
+from ..utils import masks_like, rope_params, sinusoidal_embedding_1d


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
@@ -29,12 +29,11 @@ class WanAudioPreInfer(WanPreInfer):
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         if self.config.model_cls == "wan2.2_audio":
             hidden_states = self.scheduler.latents
             mask1, mask2 = masks_like([hidden_states], zero=True, prev_length=hidden_states.shape[1])
-            hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * hidden_states
+            hidden_states = (1.0 - mask2[0]) * prev_latents + mask2[0] * hidden_states
         else:
             prev_latents = prev_latents.unsqueeze(0)
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
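The `wan2.2_audio` branch blends previously generated latents into the scheduler's latents through a mask; a toy sketch of that convex blend (shapes and mask pattern are illustrative, not the model's real ones):

```python
import torch

# Illustrative masked blend: mask2 is 0 where prev_latents should overwrite
# the fresh latents and 1 where the scheduler's latents pass through.
prev_latents = torch.randn(16, 21, 32, 32)   # assumed (C, F, H, W) layout
hidden_states = torch.randn(16, 21, 32, 32)

mask2 = torch.ones_like(hidden_states)
mask2[:, :1] = 0.0                           # take frame 0 from prev_latents

hidden_states = (1.0 - mask2) * prev_latents + mask2 * hidden_states
assert torch.allclose(hidden_states[:, 0], prev_latents[:, 0])
```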
@@ -53,7 +52,7 @@ class WanAudioPreInfer(WanPreInfer):
                 "timestep": t,
             }
             audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
-        audio_dit_blocks = None##Debug Drop Audio
+        audio_dit_blocks = None  ##Debug Drop Audio
         if positive:
             context = inputs["text_encoder_output"]["context"]
@@ -66,7 +65,7 @@ class WanAudioPreInfer(WanPreInfer):
         batch_size = len(x)
         num_channels, _, height, width = x[0].shape
         _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
                 (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
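The hunk is truncated, but the visible lines zero-pad the reference-image latent so its channel count matches the denoised latent; a hedged sketch (the final `torch.cat` is an assumption about how the padding is used):

```python
import torch

# Hedged sketch of the channel-padding step. Shapes are illustrative.
batch_size, height, width = 1, 32, 32
num_channels = 36                                    # channels of the latent x
ref = torch.randn(batch_size, 16, 1, height, width)  # (B, C_ref, F_ref, H, W)

_, ref_num_channels, ref_num_frames, _, _ = ref.shape
if ref_num_channels != num_channels:
    zero_padding = torch.zeros(
        (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
        dtype=ref.dtype,
        device=ref.device,
    )
    ref = torch.cat([ref, zero_padding], dim=1)  # assumed use of the padding

print(ref.shape)  # torch.Size([1, 36, 1, 32, 32])
```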
@@ -15,15 +15,9 @@ def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
     if zero:
         if generator is not None:
             for u, v in zip(out1, out2):
-                random_num = torch.rand(
-                    1, generator=generator, device=generator.device).item()
+                random_num = torch.rand(1, generator=generator, device=generator.device).item()
                 if random_num < p:
-                    u[:, :prev_length] = torch.normal(
-                        mean=-3.5,
-                        std=0.5,
-                        size=(1,),
-                        device=u.device,
-                        generator=generator).expand_as(u[:, :prev_length]).exp()
+                    u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
                     v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
                 else:
                     u[:, :prev_length] = u[:, :prev_length]
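For context, the two reformatted statements implement the random prev-frame noising in `masks_like`; a self-contained reconstruction of just that branch:

```python
import torch

# With probability p, fill the first `prev_length` frames of the u-mask with
# a single log-normal sample, exp(N(mean=-3.5, std=0.5)), broadcast over the
# slice, and zero the matching slice of the v-mask.
generator = torch.Generator().manual_seed(0)
u = torch.ones(16, 21, 32, 32)
v = torch.ones_like(u)
p, prev_length = 0.2, 1

random_num = torch.rand(1, generator=generator).item()
if random_num < p:
    sample = torch.normal(mean=-3.5, std=0.5, size=(1,), generator=generator)
    u[:, :prev_length] = sample.expand_as(u[:, :prev_length]).exp()
    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
```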
@@ -21,11 +21,12 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, find_torch_model_path
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
+from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image


 @contextmanager
 def memory_efficient_inference():
@@ -322,7 +323,7 @@ class VideoGenerator:
         if segment_idx == 0:
             # First segment - create zero frames
             prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-            if self.config.model_cls == 'wan2.2_audio':
+            if self.config.model_cls == "wan2.2_audio":
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
             else:
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
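The two branches differ only in how the encoder output is unwrapped; a hypothetical helper (the name and the isinstance check are assumptions, not the repo's API) that captures the distinction:

```python
# Hypothetical helper: the Wan2.2 VAE's encode() is assumed to return a
# latent tensor directly, while the Wan2.1-style encoder returns a list of
# per-sample latents, hence the [0] indexing in the else branch above.
def encode_prev_frames(vae_encoder, prev_frames, config, vae_dtype, dtype):
    encoded = vae_encoder.encode(prev_frames.to(vae_dtype), config)
    if isinstance(encoded, (list, tuple)):  # wan2.1-style output
        encoded = encoded[0]
    return encoded.to(dtype)
```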
@@ -337,7 +338,7 @@ class VideoGenerator:
             else:
                 # Fallback to zeros if prepare_prev_latents fails
                 prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                if self.config.model_cls == 'wan2.2_audio':
+                if self.config.model_cls == "wan2.2_audio":
                     prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
                 else:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
@@ -695,7 +696,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         num_channels_latents = 16
         if self.config.model_cls == "wan2.2_audio":
             num_channels_latents = self.config.num_channels_latents

         if self.config.task == "i2v":
             self.config.target_shape = (
                 num_channels_latents,
@@ -813,6 +814,7 @@ class Wan22AudioRunner(WanAudioRunner):
         vae_decoder = self.load_vae_decoder()
         return vae_encoder, vae_decoder

+
 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
     def __init__(self, config):
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from einops import rearrange

 from lightx2v.utils.utils import load_weights
-from loguru import logger

 __all__ = [
     "Wan2_2_VAE",