Commit b5bcbed7 authored by wangshankun

Refactor the audio prepare_prev_latents

parent 99a6f046
 {
     "infer_steps": 4,
     "target_fps": 16,
-    "video_duration": 5,
+    "video_duration": 12,
     "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 720,
...
@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner, Wan22AudioRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
@@ -39,7 +39,20 @@ def main():
         "--model_cls",
         type=str,
         required=True,
-        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2", "wan2.2_moe_distill"],
+        choices=[
+            "wan2.1",
+            "hunyuan",
+            "wan2.1_distill",
+            "wan2.1_causvid",
+            "wan2.1_skyreels_v2_df",
+            "cogvideox",
+            "wan2.1_audio",
+            "wan2.2_moe",
+            "wan2.2_moe_audio",
+            "wan2.2_audio",
+            "wan2.2",
+            "wan2.2_moe_distill",
+        ],
         default="wan2.1",
     )
...
@@ -145,7 +145,12 @@ class PerceiverAttentionCA(nn.Module):
         batchsize = len(x)
         x = self.norm_kv(x)
         shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
-        latents = self.norm_q(latents) * (1 + scale) + shift
+        norm_q = self.norm_q(latents)
+        if scale.shape[0] != norm_q.shape[0]:
+            scale = scale.transpose(0, 1)  # (1, 5070, 3072)
+            shift = shift.transpose(0, 1)
+            gate = gate.transpose(0, 1)
+        latents = norm_q * (1 + scale) + shift
         q = self.to_q(latents.to(GET_DTYPE()))
         k, v = self.to_kv(x).chunk(2, dim=-1)
         q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
@@ -222,16 +227,23 @@ class TimeEmbedding(nn.Module):
         self.act_fn = nn.SiLU()
         self.time_proj = nn.Linear(dim, time_proj_dim)
 
-    def forward(
-        self,
-        timestep: torch.Tensor,
-    ):
-        timestep = self.timesteps_proj(timestep)
-        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-        timestep = timestep.to(time_embedder_dtype)
+    def forward(self, timestep: torch.Tensor):
+        # Project timestep
+        if timestep.dim() == 2:
+            timestep = self.timesteps_proj(timestep.squeeze(0)).unsqueeze(0)
+        else:
+            timestep = self.timesteps_proj(timestep)
+
+        # Match dtype with time_embedder (except int8)
+        target_dtype = next(self.time_embedder.parameters()).dtype
+        if timestep.dtype != target_dtype and target_dtype != torch.int8:
+            timestep = timestep.to(target_dtype)
+
+        # Time embedding projection
         temb = self.time_embedder(timestep)
         timestep_proj = self.time_proj(self.act_fn(temb))
-        return timestep_proj
+
+        return timestep_proj.squeeze(0) if timestep_proj.dim() == 3 else timestep_proj
 
 class AudioAdapter(nn.Module):
...
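The PerceiverAttentionCA change above guards the adaLN-style modulation against a batch/token dimension mismatch: when the timestep embedding arrives per token rather than per batch, shift/scale/gate carry the token count in their leading dimension and must be transposed before broadcasting against the normalized query latents. A minimal sketch of that shape handling, assuming illustrative sizes (the 5070-token, 3072-channel figures come from the inline comment; everything else here is an assumption, not the adapter's real code):

import torch

tokens, dim = 5070, 3072
norm_q = torch.randn(1, tokens, dim)        # normalized query latents (B, L, C)
t_emb = torch.randn(tokens, 3, dim)         # per-token time embedding with 3 packed chunks
shift, scale, gate = t_emb.chunk(3, dim=1)  # each (tokens, 1, dim)

# If the modulation's leading dim is the token count instead of the batch size,
# swap the first two axes so broadcasting lines up, mirroring the transpose above.
if scale.shape[0] != norm_q.shape[0]:
    scale, shift, gate = (m.transpose(0, 1) for m in (scale, shift, gate))  # (1, tokens, dim)

latents = norm_q * (1 + scale) + shift      # modulated input to the Q projection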
+import math
+
 import torch
 
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 
 from ..module_io import WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
+from ..utils import rope_params, sinusoidal_embedding_1d
+from loguru import logger
 
 
 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
@@ -28,13 +30,17 @@ class WanAudioPreInfer(WanPreInfer):
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 
+        if config.parallel:
+            self.sp_size = config.parallel.get("seq_p_size", 1)
+        else:
+            self.sp_size = 1
+
     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         if self.config.model_cls == "wan2.2_audio":
             hidden_states = self.scheduler.latents
-            mask1, mask2 = masks_like([hidden_states], zero=True, prev_length=hidden_states.shape[1])
-            hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * hidden_states
+            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+            hidden_states = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * hidden_states
         else:
             prev_latents = prev_latents.unsqueeze(0)
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
@@ -45,6 +51,16 @@ class WanAudioPreInfer(WanPreInfer):
         x = [hidden_states]
         t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
 
+        if self.config.model_cls == "wan2.2_audio":
+            _, lat_f, lat_h, lat_w = self.scheduler.latents.shape
+            F = (lat_f - 1) * self.config.vae_stride[0] + 1
+            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
+            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+
+            temp_ts = (prev_mask[0][0][:, ::2, ::2] * t).flatten()
+            temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
+            t = temp_ts.unsqueeze(0)
+
         audio_dit_blocks = []
         audio_encoder_output = inputs["audio_encoder_output"]
         audio_model_input = {
@@ -53,7 +69,7 @@ class WanAudioPreInfer(WanPreInfer):
             "timestep": t,
         }
         audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
-        audio_dit_blocks = None##Debug Drop Audio
+        # audio_dit_blocks = None##Debug Drop Audio
 
         if positive:
             context = inputs["text_encoder_output"]["context"]
@@ -66,7 +82,7 @@ class WanAudioPreInfer(WanPreInfer):
         batch_size = len(x)
         num_channels, _, height, width = x[0].shape
         _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
 
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
                 (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
...
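The new wan2.2_audio branch in WanAudioPreInfer.infer expands the scalar denoising timestep into a per-token vector: spatial tokens whose prev_mask is 0 (i.e. pinned to previous-segment latents) effectively receive timestep 0, the remaining tokens keep the current timestep, and the vector is padded up to a length divisible by the sequence-parallel world size. A rough worked example of that arithmetic, using illustrative latent, stride, and patch values (assumptions, not config defaults):

import math
import torch

lat_f, lat_h, lat_w = 21, 90, 160            # assumed latent grid
vae_stride, patch_size, sp_size = (4, 8, 8), (1, 2, 2), 1

F = (lat_f - 1) * vae_stride[0] + 1          # pixel-space frame count: 81
max_seq_len = ((F - 1) // vae_stride[0] + 1) * lat_h * lat_w // (patch_size[1] * patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size    # round up to a multiple of sp_size

t = torch.tensor([1000.0])                   # current denoising timestep
prev_mask = [torch.ones(16, lat_f, lat_h, lat_w)]                # one (C, F, H, W) mask, 1 = noisy token
temp_ts = (prev_mask[0][0][:, ::2, ::2] * t).flatten()           # one timestep per 2x2 spatial patch
temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
t_per_token = temp_ts.unsqueeze(0)           # (1, max_seq_len), here (1, 75600)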
@@ -4,38 +4,6 @@ import torch.distributed as dist
 from lightx2v.utils.envs import *
 
-def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
-    assert isinstance(tensor, list)
-    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
-    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
-
-    if prev_length == 0:
-        return out1, out2
-
-    if zero:
-        if generator is not None:
-            for u, v in zip(out1, out2):
-                random_num = torch.rand(
-                    1, generator=generator, device=generator.device).item()
-                if random_num < p:
-                    u[:, :prev_length] = torch.normal(
-                        mean=-3.5,
-                        std=0.5,
-                        size=(1,),
-                        device=u.device,
-                        generator=generator).expand_as(u[:, :prev_length]).exp()
-                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
-                else:
-                    u[:, :prev_length] = u[:, :prev_length]
-                    v[:, :prev_length] = v[:, :prev_length]
-        else:
-            for u, v in zip(out1, out2):
-                u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
-                v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
-
-    return out1, out2
-
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0]
...
@@ -21,11 +21,12 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, find_torch_model_path
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
+from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image
 
 
 @contextmanager
 def memory_efficient_inference():
@@ -257,9 +258,6 @@ class VideoGenerator:
     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
-        if prev_video is None:
-            return None
-
         device = torch.device("cuda")
         dtype = GET_DTYPE()
         vae_dtype = torch.float
@@ -267,22 +265,29 @@ class VideoGenerator:
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
         prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
 
-        # Extract and process last frames
-        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
-        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
-
-        prev_frames[:, :, :prev_frame_length] = last_frames
-
-        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-
-        # Create mask
-        prev_token_length = (prev_frame_length - 1) // 4 + 1
+        if prev_video is not None:
+            # Extract and process last frames
+            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+            last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
+            prev_frames[:, :, :prev_frame_length] = last_frames
+
         _, nframe, height, width = self.model.scheduler.latents.shape
-        frames_n = (nframe - 1) * 4 + 1
-        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
-
-        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-        prev_mask[:, prev_frame_len:] = 0
-        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+        if self.config.model_cls == "wan2.2_audio":
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
+            _, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
+        else:
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            if prev_video is not None:
+                prev_token_length = (prev_frame_length - 1) // 4 + 1
+                prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
+            else:
+                prev_frame_len = 0
+
+            frames_n = (nframe - 1) * 4 + 1
+            prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+            prev_mask[:, prev_frame_len:] = 0
+            prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
 
         if prev_latents.shape[-2:] != (height, width):
             logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
@@ -290,6 +295,31 @@ class VideoGenerator:
 
         return {"prev_latents": prev_latents, "prev_mask": prev_mask}
 
+    def _wan22_masks_like(self, tensor, zero=False, generator=None, p=0.2, prev_length=1):
+        assert isinstance(tensor, list)
+        out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+        out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+
+        if prev_length == 0:
+            return out1, out2
+
+        if zero:
+            if generator is not None:
+                for u, v in zip(out1, out2):
+                    random_num = torch.rand(1, generator=generator, device=generator.device).item()
+                    if random_num < p:
+                        u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
+                        v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
+                    else:
+                        u[:, :prev_length] = u[:, :prev_length]
+                        v[:, :prev_length] = v[:, :prev_length]
+            else:
+                for u, v in zip(out1, out2):
+                    u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
+                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
+
+        return out1, out2
+
     def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
         """Rearrange mask for WAN model"""
         if mask.ndim == 3:
@@ -312,52 +342,7 @@ class VideoGenerator:
 
         if segment_idx > 0:
             self.model.scheduler.reset()
 
-        # Prepare previous latents - ALWAYS needed, even for first segment
-        device = torch.device("cuda")
-        dtype = GET_DTYPE()
-        vae_dtype = torch.float
-        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
-        max_num_frames = self.config.target_video_length
-
-        if segment_idx == 0:
-            # First segment - create zero frames
-            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-            if self.config.model_cls == 'wan2.2_audio':
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
-            else:
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-            prev_len = 0
-        else:
-            # Subsequent segments - use previous video
-            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
-            if previmg_encoder_output:
-                prev_latents = previmg_encoder_output["prev_latents"]
-                prev_len = (prev_frame_length - 1) // 4 + 1
-            else:
-                # Fallback to zeros if prepare_prev_latents fails
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                if self.config.model_cls == 'wan2.2_audio':
-                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
-                else:
-                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = 0
-
-        # Create mask for prev_latents
-        _, nframe, height, width = self.model.scheduler.latents.shape
-        frames_n = (nframe - 1) * 4 + 1
-        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
-
-        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-        prev_mask[:, prev_frame_len:] = 0
-        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
-
-        if prev_latents.shape[-2:] != (height, width):
-            logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
-            prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
-
-        # Always set previmg_encoder_output
-        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
+        inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)
 
         # Run inference loop
         if total_steps is None:
@@ -373,6 +358,10 @@ class VideoGenerator:
                 with ProfilingContext4Debug("step_post"):
                     self.model.scheduler.step_post()
 
+                if self.config.model_cls == "wan2.2_audio":
+                    prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+                    prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
+                    self.model.scheduler.latents = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * self.model.scheduler.latents
+
                 if self.progress_callback:
                     segment_progress = (segment_idx * total_steps + step_index + 1) / (self.total_segments * total_steps)
@@ -396,6 +385,11 @@ class WanAudioRunner(WanRunner): # type:ignore
         self._video_generator = None
         self._audio_preprocess = None
 
+        if self.seq_p_group is None:
+            self.sp_size = 1
+        else:
+            self.sp_size = dist.get_world_size(self.seq_p_group)
+
     def initialize(self):
         """Initialize all models once for multiple runs"""
@@ -620,7 +614,6 @@ class WanAudioRunner(WanRunner): # type:ignore
     def load_transformer(self):
         """Load transformer with LoRA support"""
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
-        logger.info(f"Loaded base model: {self.config.model_path}")
 
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
             lora_wrapper = WanLoraWrapper(base_model)
@@ -695,7 +688,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         num_channels_latents = 16
         if self.config.model_cls == "wan2.2_audio":
             num_channels_latents = self.config.num_channels_latents
 
         if self.config.task == "i2v":
             self.config.target_shape = (
                 num_channels_latents,
@@ -813,6 +806,7 @@ class Wan22AudioRunner(WanAudioRunner):
         vae_decoder = self.load_vae_decoder()
         return vae_encoder, vae_decoder
 
 
 @RUNNER_REGISTER("wan2.2_moe_audio")
 class Wan22MoeAudioRunner(WanAudioRunner):
     def __init__(self, config):
...
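Taken together, the runner changes above route all previous-segment conditioning through prepare_prev_latents and, for wan2.2_audio, re-apply the same mask blend after every scheduler step, so latent positions covered by the previous segment stay fixed while the rest keep being denoised. A minimal standalone sketch of that blend, with purely illustrative shapes (assumptions, not values read from any config):

import torch

latents = torch.randn(16, 21, 90, 160)       # scheduler latents (C, F, H, W), assumed shape
prev_latents = torch.randn(16, 21, 90, 160)  # VAE-encoded previous-segment frames
prev_mask = torch.ones_like(latents)
prev_mask[:, :1] = 0                         # 0 marks latent frames pinned to the previous segment

# mask == 1 keeps the freshly denoised latent, mask == 0 pins the position to prev_latents
latents = (1.0 - prev_mask) * prev_latents + prev_mask * latents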
@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from lightx2v.utils.utils import load_weights
-from loguru import logger
 
 __all__ = [
     "Wan2_2_VAE",
...
 #!/bin/bash
 
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/gushiqiao/models/Wan2.2-R2V812-Audio-5B"
 
 export CUDA_VISIBLE_DEVICES=0,1,2,3
...
 #!/bin/bash
 
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P"
 
 export CUDA_VISIBLE_DEVICES=0
@@ -11,6 +11,7 @@ source ${lightx2v_path}/scripts/base/base.sh
 export TORCH_CUDA_ARCH_LIST="9.0"
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+export ENABLE_GRAPH_MODE=false
 export ENABLE_GRAPH_MODE=false
 export SENSITIVE_LAYER_DTYPE=None
...
@@ -2,8 +2,8 @@
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P"
 
 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
...
 #!/bin/bash
 
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path="/home/wangshankun/code/LightX2V"
+model_path="/data/nvme0/gushiqiao/models/official_models/wan2.2/Wan2.2-TI2V-5B"
 
 export CUDA_VISIBLE_DEVICES=0
 
 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh
+export ENABLE_GRAPH_MODE=false
 
 python -m lightx2v.infer \
     --model_cls wan2.2 \
     --task i2v \
...