Commit d91e8d68 authored by sandy, committed by GitHub

Audio/r2v 5b deb (#251)



* [Feat] audio-driven generation for wan2.2-r2v-5b

* refactor: reduce VAE args dim && timestep init

* update audio pre_infer

* update runner

* update scripts

* [Fix] move ref timestep padding to Scheduler

* fix wan i2v

---------
Co-authored-by: helloyongyang <yongyang1030@163.com>
parent d71f936d
 {
     "infer_steps": 4,
-    "target_fps": 16,
+    "target_fps": 24,
     "video_duration": 12,
     "audio_sr": 16000,
     "target_video_length": 121,
@@ -20,9 +20,11 @@
     "offload_granularity": "model",
     "fps": 24,
     "use_image_encoder": false,
+    "adaptive_resize": true,
+    "use_31_block": false,
     "lora_configs": [
         {
-            "path": "/data/nvme0/models/wan_ti2v_5b_ref/20250812/model_ema.safetensors",
+            "path": "/mnt/aigc/rtxiang/pretrain/qianhai_weights/lora_model.safetensors",
             "strength": 0.125
         }
     ]
...
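For orientation, the edited config bumps target_fps from 16 to 24 while keeping target_video_length at 121 frames per segment. A minimal sanity-check sketch, assuming the temporal stride of 4 implied by the `(n - 1) * 4 + 1` pattern used elsewhere in this diff:

```python
# Rough numbers for the config above (illustrative only).
target_video_length = 121   # pixel-space frames generated per segment
target_fps = 24             # playback rate after this change (was 16)

segment_seconds = target_video_length / target_fps     # ~5.0 s of video per segment
latent_frames = (target_video_length - 1) // 4 + 1     # 31 latent frames, assuming temporal stride 4

print(f"{segment_seconds:.2f}s per segment, {latent_frames} latent frames")
```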
@@ -8,7 +8,7 @@ from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
...
@@ -11,6 +11,7 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, x, pre_infer_out):
         x = x[: pre_infer_out.seq_lens[0]]
+        pre_infer_out.grid_sizes[:, 0] -= 1
         x = self.unpatchify(x, pre_infer_out.grid_sizes)
         if self.clean_cuda_cache:
...
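The added `grid_sizes[:, 0] -= 1` undoes the temporal slot that the pre-infer step reserves for the reference-image tokens (see the `grid_sizes[:, 0] += 1` hunk below); as far as the diff shows, the `seq_lens[0]` cut leaves only the video tokens, so `unpatchify` then sees a grid that matches them exactly. A small bookkeeping sketch with made-up sizes:

```python
# Token bookkeeping for the r2v reference frame (hypothetical latent grid, 1x2x2 patching assumed).
lat_f, h_p, w_p = 31, 30, 52          # patched grid: frames, height, width
video_tokens = lat_f * h_p * w_p      # tokens coming from the video latents
ref_tokens = 1 * h_p * w_p            # one extra "frame" of tokens for the reference image

# pre-infer: concatenate ref tokens and do grid_sizes[:, 0] += 1 -> (lat_f + 1, h_p, w_p)
# post-infer: keep only the first seq_lens[0] (video) tokens, then grid_sizes[:, 0] -= 1
assert video_tokens + ref_tokens == (lat_f + 1) * h_p * w_p
assert video_tokens == lat_f * h_p * w_p   # exactly what unpatchify expects after the -= 1
```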
@@ -24,19 +24,17 @@ class WanAudioPreInfer(WanPreInfer):
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.text_len = config["text_len"]
+        self.rope_t_dim = d // 2 - 2 * (d // 6)
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
-        if config.parallel:
-            self.sp_size = config.parallel.get("seq_p_size", 1)
-        else:
-            self.sp_size = 1

     def infer(self, weights, inputs):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
-        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-        hidden_states = torch.cat([self.scheduler.latents, prev_mask, prev_latents], dim=0)
+        hidden_states = self.scheduler.latents
+        if self.config.model_cls != "wan2.2_audio":
+            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+            hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
         x = hidden_states
         t = self.scheduler.timestep_input
@@ -45,11 +43,10 @@ class WanAudioPreInfer(WanPreInfer):
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
-        # seq_len = self.scheduler.seq_len
         clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
         ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
-        # batch_size = len(x)
         num_channels, _, height, width = x.shape
         ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
@@ -60,7 +57,7 @@ class WanAudioPreInfer(WanPreInfer):
                 device=self.scheduler.latents.device,
             )
             ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0)
-        y = ref_image_encoder  # the first batch dimension becomes a list
+        y = ref_image_encoder
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
@@ -70,8 +67,11 @@ class WanAudioPreInfer(WanPreInfer):
         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()
-        x = torch.cat([x, y], dim=1)  # for r2v
+        x = torch.cat([x, y], dim=1).squeeze(0)
+        # zero temporal component corresponding to ref embeddings
+        self.freqs[grid_sizes[0][0] :, : self.rope_t_dim] = 0
+        grid_sizes[:, 0] += 1

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.sensitive_layer_dtype != self.infer_dtype:
@@ -117,7 +117,7 @@ class WanAudioPreInfer(WanPreInfer):
         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
-            x=x.squeeze(0),
+            x=x,
             embed0=embed0.squeeze(0),
             seq_lens=seq_lens,
             freqs=self.freqs,
...
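Instead of the dedicated compute_freqs_audio helpers (deleted further down), the pre-infer step now edits the shared RoPE table directly: rows at and beyond the original temporal extent, i.e. the positions used by the appended reference tokens, get their temporal frequency channels zeroed. A standalone sketch of that masking, with made-up sizes and a float stand-in for the real freqs table:

```python
import torch

d = 128                                 # hypothetical per-head dimension
rope_t_dim = d // 2 - 2 * (d // 6)      # temporal channels of the RoPE split, as in __init__ above (22 here)
max_positions = 1024
freqs = torch.randn(max_positions, d // 2)   # stand-in for the precomputed self.freqs table

video_frames = 31                       # grid_sizes[0][0] before the += 1 for the reference slot
freqs[video_frames:, :rope_t_dim] = 0   # reference positions receive no temporal rotation

assert freqs[video_frames:, :rope_t_dim].abs().sum() == 0
assert freqs[:video_frames, rope_t_dim:].abs().sum() != 0   # spatial channels stay untouched
```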
@@ -3,7 +3,6 @@ import torch.distributed as dist
 from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
-from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist

 class WanAudioTransformerInfer(WanOffloadTransformerInfer):
@@ -15,21 +14,15 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     def set_audio_adapter(self, audio_adapter):
         self.audio_adapter = audio_adapter

-    @torch.no_grad()
-    def compute_freqs(self, q, grid_sizes, freqs):
-        if self.config["seq_parallel"]:
-            freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
-        else:
-            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i

     @torch.no_grad()
     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
+        audio_grid_sizes = [row.clone() for row in pre_infer_out.grid_sizes]
+        audio_grid_sizes[0][0] -= 1
         x = self.modify_hidden_states(
             hidden_states=x.to(self.infer_dtype),
-            grid_sizes=pre_infer_out.grid_sizes,
+            grid_sizes=audio_grid_sizes,
             ca_block=self.audio_adapter.ca[self.block_idx],
             audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
             t_emb=self.scheduler.audio_adapter_t_emb,
...
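One detail worth calling out: `post_process` runs once per transformer block, and `pre_infer_out.grid_sizes` is shared across blocks and still needed by the post-infer `-= 1` above, so the reference slot is removed on a per-row clone rather than in place. A tiny sketch of why the clone matters:

```python
import torch

grid_sizes = torch.tensor([[32, 30, 52]])      # (frames incl. ref slot, h, w), made-up values

# what the new code does: clone each row before adjusting it for the audio adapter
audio_grid_sizes = [row.clone() for row in grid_sizes]
audio_grid_sizes[0][0] -= 1                    # audio cross-attention covers only the 31 video frames

print(audio_grid_sizes[0].tolist())   # [31, 30, 52]
print(grid_sizes.tolist())            # [[32, 30, 52]] -- the shared tensor is left untouched
```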
@@ -20,26 +20,6 @@ def compute_freqs(c, grid_sizes, freqs):
     return freqs_i

-def compute_freqs_audio(c, grid_sizes, freqs):
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
-    valid_token_length = f * h * w
-    f = f + 1  # for r2v add 1 channel
-    seq_len = f * h * w
-    freqs_i = torch.cat(
-        [
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # temporal (frame) encoding
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # spatial (height) encoding
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # spatial (width) encoding
-        ],
-        dim=-1,
-    ).reshape(seq_len, 1, -1)
-    freqs_i[valid_token_length:, :, :f] = 0  # for r2v: zero temporal component corresponding to ref embeddings
-    return freqs_i

 def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
     world_size = dist.get_world_size(seq_p_group)
     cur_rank = dist.get_rank(seq_p_group)
@@ -61,31 +41,6 @@ def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
     return freqs_i_rank

-def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
-    world_size = dist.get_world_size(seq_p_group)
-    cur_rank = dist.get_rank(seq_p_group)
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
-    valid_token_length = f * h * w
-    f = f + 1
-    seq_len = f * h * w
-    freqs_i = torch.cat(
-        [
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
-        ],
-        dim=-1,
-    ).reshape(seq_len, 1, -1)
-    freqs_i[valid_token_length:, :, :f] = 0
-    freqs_i = pad_freqs(freqs_i, s * world_size)
-    s_per_rank = s
-    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
-    return freqs_i_rank

 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0]
...
@@ -21,10 +21,11 @@ from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import load_weights, save_to_video, vae_to_comfyui_image
+from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image

 def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
@@ -300,7 +301,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.vae_encoder = self.load_vae_encoder()
         img = rearrange(img, "1 C H W -> 1 C 1 H W")
-        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))[0]
+        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float)).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
             torch.cuda.empty_cache()
@@ -339,30 +341,32 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if prev_video is not None:
             # Extract and process last frames
             last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+            if self.config.model_cls != "wan2.2_audio":
                 last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
             prev_frames[:, :, :prev_frame_length] = last_frames
+            prev_len = (prev_frame_length - 1) // 4 + 1
+        else:
+            prev_len = 0

         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_encoder = self.load_vae_encoder()
         _, nframe, height, width = self.model.scheduler.latents.shape
         if self.config.model_cls == "wan2.2_audio":
+            if prev_video is not None:
                 prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
-            _, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
+            else:
+                prev_latents = None
+            prev_mask = self.model.scheduler.mask
         else:
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype))[0].to(dtype)
-            if prev_video is not None:
-                prev_token_length = (prev_frame_length - 1) // 4 + 1
-                prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
-            else:
-                prev_frame_len = 0
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
             frames_n = (nframe - 1) * 4 + 1
             prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-            prev_mask[:, prev_frame_len:] = 0
+            prev_mask[:, prev_len:] = 0
             prev_mask = self._wan_mask_rearrange(prev_mask)
+        if prev_latents is not None:
             if prev_latents.shape[-2:] != (height, width):
                 logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
                 prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
@@ -372,7 +376,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         torch.cuda.empty_cache()
         gc.collect()
-        return {"prev_latents": prev_latents, "prev_mask": prev_mask}
+        return {"prev_latents": prev_latents, "prev_mask": prev_mask, "prev_len": prev_len}

     def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
         """Rearrange mask for WAN model"""
@@ -413,12 +417,11 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1])
         self.inputs["audio_encoder_output"] = audio_features
+        self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)

         # Reset scheduler for non-first segments
         if segment_idx > 0:
-            self.model.scheduler.reset()
-        self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)
+            self.model.scheduler.reset(self.inputs["previmg_encoder_output"])

     @ProfilingContext4Debug("End run segment")
     def end_run_segment(self):
@@ -600,3 +603,48 @@ class WanAudioRunner(WanRunner):  # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret
+
+
+@RUNNER_REGISTER("wan2.2_audio")
+class Wan22AudioRunner(WanAudioRunner):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def load_vae_decoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
+        vae_config = {
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
+        }
+        vae_decoder = Wan2_2_VAE(**vae_config)
+        return vae_decoder
+
+    def load_vae_encoder(self):
+        # offload config
+        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
+        if vae_offload:
+            vae_device = torch.device("cpu")
+        else:
+            vae_device = torch.device("cuda")
+        vae_config = {
+            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
+            "device": vae_device,
+            "cpu_offload": vae_offload,
+            "offload_cache": self.config.get("vae_offload_cache", False),
+        }
+        if self.config.task != "i2v":
+            return None
+        else:
+            return Wan2_2_VAE(**vae_config)
+
+    def load_vae(self):
+        vae_encoder = self.load_vae_encoder()
+        vae_decoder = self.load_vae_decoder()
+        return vae_encoder, vae_decoder
@@ -309,7 +309,7 @@ class WanRunner(DefaultRunner):
                 dim=1,
             ).cuda()
-            vae_encoder_out = self.vae_encoder.encode([vae_input])[0]
+            vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0))
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder
...
 import gc
+import math

 import numpy as np
 import torch
@@ -6,12 +7,22 @@ from loguru import logger
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.utils.envs import *
+from lightx2v.utils.utils import masks_like


 class ConsistencyModelScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
+        if self.config.parallel:
+            self.sp_size = self.config.parallel.get("seq_p_size", 1)
+        else:
+            self.sp_size = 1
+        if self.config["model_cls"] == "wan2.2_audio":
+            self.prev_latents = None
+            self.prev_len = 0

     def set_audio_adapter(self, audio_adapter):
         self.audio_adapter = audio_adapter
@@ -23,7 +34,45 @@ class ConsistencyModelScheduler(WanScheduler):
         if self.audio_adapter.cpu_offload:
             self.audio_adapter.time_embedding.to("cpu")
+        if self.config.model_cls == "wan2.2_audio":
+            _, lat_f, lat_h, lat_w = self.latents.shape
+            F = (lat_f - 1) * self.config.vae_stride[0] + 1
+            per_latent_token_len = lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
+            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * per_latent_token_len
+            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+            temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
+            self.timestep_input = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * self.timestep_input]).unsqueeze(0)
+            self.timestep_input = torch.cat(
+                [
+                    self.timestep_input,
+                    torch.zeros(
+                        (1, per_latent_token_len),  # padding for reference frame latent
+                        dtype=self.timestep_input.dtype,
+                        device=self.timestep_input.device,
+                    ),
+                ],
+                dim=1,
+            )
+
+    def prepare_latents(self, target_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
+        self.latents = torch.randn(
+            target_shape[0],
+            target_shape[1],
+            target_shape[2],
+            target_shape[3],
+            dtype=dtype,
+            device=self.device,
+            generator=self.generator,
+        )
+        if self.config["model_cls"] == "wan2.2_audio":
+            self.mask = masks_like(self.latents, zero=True, prev_len=self.prev_len)
+            if self.prev_latents is not None:
+                self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

-    def prepare(self, image_encoder_output=None):
+    def prepare(self, previmg_encoder_output=None):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
         timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
@@ -43,8 +92,13 @@ class ConsistencyModelScheduler(WanScheduler):
         x0 = sample - model_output * sigma
         x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
         self.latents = x_t_next
+        if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
+            self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

-    def reset(self):
+    def reset(self, previmg_encoder_output=None):
+        if self.config["model_cls"] == "wan2.2_audio":
+            self.prev_latents = previmg_encoder_output["prev_latents"]
+            self.prev_len = previmg_encoder_output["prev_len"]
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
         gc.collect()
         torch.cuda.empty_cache()
...
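The block added to the scheduler is what the commit message calls moving the ref timestep padding into the scheduler: per-token timesteps are built from the (2x2-subsampled) mask, padded up to a multiple of sp_size, and then extended with zeros for the reference-frame tokens so they are always treated as already denoised. A self-contained numeric sketch with assumed sizes (latent grid 31x60x104, 2x2 patching, temporal VAE stride 4, no sequence parallelism):

```python
import math
import torch

lat_f, lat_h, lat_w = 31, 60, 104      # assumed latent grid
patch_h = patch_w = 2                  # assumed config.patch_size[1:] == (2, 2)
vae_stride_t = 4                       # assumed config.vae_stride[0]
sp_size = 1
timestep = 1000.0
prev_len = 2                           # latent frames reused from the previous segment

per_latent_token_len = lat_h * lat_w // (patch_h * patch_w)            # 1560 tokens per latent frame
F = (lat_f - 1) * vae_stride_t + 1                                     # 121 pixel frames
max_seq_len = ((F - 1) // vae_stride_t + 1) * per_latent_token_len     # 31 * 1560 = 48360 video tokens
max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size

mask = torch.ones(1, lat_f, lat_h, lat_w)
mask[:, :prev_len] = 0                                                 # same effect as masks_like(..., prev_len=2)
temp_ts = (mask[0][:, ::2, ::2] * timestep).flatten()                  # one timestep value per video token
ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * timestep]).unsqueeze(0)
ts = torch.cat([ts, torch.zeros(1, per_latent_token_len)], dim=1)      # ref-frame tokens get t = 0

print(ts.shape)                                           # torch.Size([1, 49920]) == (31 + 1) * 1560
print(ts[0, :prev_len * per_latent_token_len].unique())   # tensor([0.]) -- reused frames also carry t = 0
```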
@@ -892,17 +892,17 @@ class WanVAE:
             self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]

-    def encode(self, videos):
+    def encode(self, video):
         """
-        videos: A list of videos each with shape [C, T, H, W].
+        video: one video with shape [1, C, T, H, W].
         """
         if self.cpu_offload:
             self.to_cuda()
         if self.use_tiling:
-            out = [self.model.tiled_encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
+            out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
         else:
-            out = [self.model.encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
+            out = self.model.encode(video, self.scale).float().squeeze(0)
         if self.cpu_offload:
             self.to_cpu()
...
@@ -985,10 +985,10 @@ class Wan2_2_VAE:
             self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]

-    def encode(self, videos):
+    def encode(self, video):
         if self.cpu_offload:
             self.to_cuda()
-        out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
+        out = self.model.encode(video, self.scale).float().squeeze(0)
         if self.cpu_offload:
             self.to_cpu()
         return out
...
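Both VAE wrappers now take a single batched tensor of shape [1, C, T, H, W] and return one latent tensor, instead of a list of [C, T, H, W] videos; this is why the runner call sites above switch from `encode([x])[0]` to `encode(x.unsqueeze(0))`. A hedged sketch of the call-convention change using a stand-in class (DummyVAE is not the real WanVAE/Wan2_2_VAE):

```python
import torch

class DummyVAE:
    """Stand-in that mirrors only the new encode() signature."""

    def encode(self, video):
        # new contract: a single batched video, shape [1, C, T, H, W]
        assert video.dim() == 5 and video.shape[0] == 1
        return video.squeeze(0)[:, ::4, ::8, ::8]   # placeholder "latents"

vae = DummyVAE()
clip = torch.randn(3, 9, 64, 64)                    # [C, T, H, W]

# before this commit: latents = vae.encode([clip])[0]
latents = vae.encode(clip.unsqueeze(0))             # after: pass [1, C, T, H, W] directly
print(latents.shape)                                # torch.Size([3, 3, 8, 8])
```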
@@ -434,17 +434,16 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     return distributed_weight_dict

-def masks_like(tensor, zero=False, generator=None, p=0.2):
+def masks_like(tensor, zero=False, generator=None, p=0.2, prev_len=1):
     assert isinstance(tensor, torch.Tensor)
     out = torch.ones_like(tensor)
     if zero:
         if generator is not None:
             random_num = torch.rand(1, generator=generator, device=generator.device).item()
             if random_num < p:
-                out[:, 0] = torch.zeros_like(out[:, 0])
+                out[:, :prev_len] = torch.zeros_like(out[:, :prev_len])
         else:
-            out[:, 0] = torch.zeros_like(out[:, 0])
+            out[:, :prev_len] = torch.zeros_like(out[:, :prev_len])
     return out
...
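With the new `prev_len` argument, `masks_like` zeroes the first `prev_len` temporal slots instead of only the first one, which is what lets the scheduler pin several previous-segment latent frames at once. A minimal usage sketch (the deterministic branch of the function above, copied inline so the snippet runs on its own):

```python
import torch

def masks_like(tensor, zero=False, prev_len=1):
    # trimmed copy of the function above: deterministic branch only
    out = torch.ones_like(tensor)
    if zero:
        out[:, :prev_len] = torch.zeros_like(out[:, :prev_len])
    return out

latents = torch.randn(48, 31, 60, 104)             # (C, F, H, W) latent video, made-up sizes
mask = masks_like(latents, zero=True, prev_len=2)

print(mask[:, :2].sum().item())   # 0.0 -- the two reused latent frames are masked out
print(mask[:, 2:].min().item())   # 1.0 -- everything else is generated as usual
```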
@@ -9,6 +9,8 @@ export CUDA_VISIBLE_DEVICES=0
 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
+export ENABLE_GRAPH_MODE=false
+export SENSITIVE_LAYER_DTYPE=None

 python -m lightx2v.infer \
     --model_cls wan2.2_audio \
...