"mmdet3d/datasets/vscode:/vscode.git/clone" did not exist on "32a4328b16b85aae26d08d81157ab74b58edcdb1"
Commit d8454a2b authored by helloyongyang

Refactor runners

parent 2054eca3
@@ -18,5 +18,7 @@
"dit_quantized_ckpt": "/path/to/Wan2.1-R2V721-Audio-14B-720P/fp8",
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
- }
+ },
+ "adapter_quantized": true,
+ "adapter_quant_scheme": "fp8"
}
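For context, a minimal sketch of how the two new adapter keys could be consumed on the Python side, mirroring the constructor arguments added to AudioAdapter later in this commit (the config path and the three leading model dimensions are illustrative assumptions, not values from this diff):

import json

from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter

with open("/path/to/config.json") as f:  # illustrative path
    cfg = json.load(f)

# 128 / 40 / 40 stand in for attention head dim, head count and layer count;
# quantized / quant_scheme mirror the new config keys ("fp8" selects SglQuantLinearFp8).
adapter = AudioAdapter(
    128,
    40,
    40,
    quantized=cfg.get("adapter_quantized", False),
    quant_scheme=cfg.get("adapter_quant_scheme"),
)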
@@ -6,6 +6,11 @@ try:
except ModuleNotFoundError:
ops = None
+ try:
+ import sgl_kernel
+ except ImportError:
+ sgl_kernel = None
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
@@ -117,6 +122,58 @@ class VllmQuantLinearFp8(nn.Module):
return self
class SglQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
return input_tensor_quant, input_tensor_scale
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.fp8_scaled_mm(
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale,
dtype,
bias=self.bias,
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
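A brief usage sketch for the new module, assuming a CUDA device and the sgl_kernel package are available; the ad-hoc per-channel quantization helper below is only illustrative, since the shipped checkpoints already contain fp8 weights and scales:

import torch
import torch.nn as nn

def to_sgl_fp8(linear: nn.Linear) -> SglQuantLinearFp8:
    # illustrative per-output-channel fp8 quantization of an existing nn.Linear
    q = SglQuantLinearFp8(linear.in_features, linear.out_features, bias=linear.bias is not None)
    w = linear.weight.detach().float()
    scale = w.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
    q.weight.copy_((w / scale).to(torch.float8_e4m3fn))
    q.weight_scale.copy_(scale)
    if linear.bias is not None:
        q.bias.copy_(linear.bias.detach().to(q.bias.dtype))
    return q

layer = to_sgl_fp8(nn.Linear(5120, 5120)).cuda()
x = torch.randn(1, 257, 5120, device="cuda", dtype=torch.bfloat16)
y = layer(x)  # (1, 257, 5120), computed with fp8_scaled_mm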
class TorchaoQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
...
@@ -13,9 +13,8 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from loguru import logger
- from transformers import AutoModel
- from lightx2v.utils.envs import *
+ from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8
def load_safetensors(in_path: str):
@@ -84,8 +83,6 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
for buffer in model.buffers():
dist.broadcast(buffer.data, src=0)
- return model.to(dtype=GET_DTYPE())
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
@@ -120,7 +117,7 @@ def get_q_lens_audio_range(
class PerceiverAttentionCA(nn.Module):
- def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
+ def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False, quantized=False, quant_scheme=None):
super().__init__()
self.dim_head = dim_head
self.heads = heads
@@ -129,9 +126,17 @@ class PerceiverAttentionCA(nn.Module):
self.norm_kv = nn.LayerNorm(kv_dim)
self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
- self.to_q = nn.Linear(inner_dim, inner_dim)
- self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
- self.to_out = nn.Linear(inner_dim, inner_dim)
+ if quantized:
+ if quant_scheme == "fp8":
+ self.to_q = SglQuantLinearFp8(inner_dim, inner_dim)
+ self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
+ self.to_out = SglQuantLinearFp8(inner_dim, inner_dim)
+ else:
+ raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
+ else:
+ self.to_q = nn.Linear(inner_dim, inner_dim)
+ self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
+ self.to_out = nn.Linear(inner_dim, inner_dim)
if adaLN:
self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
else:
@@ -151,7 +156,7 @@ class PerceiverAttentionCA(nn.Module):
shift = shift.transpose(0, 1)
gate = gate.transpose(0, 1)
latents = norm_q * (1 + scale) + shift
- q = self.to_q(latents.to(GET_DTYPE()))
+ q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
@@ -258,6 +263,8 @@ class AudioAdapter(nn.Module):
mlp_dims: tuple = (1024, 1024, 32 * 768),
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
+ quantized: bool = False,
+ quant_scheme: str = None,
):
super().__init__()
self.audio_proj = AudioProjection(
@@ -280,6 +287,8 @@ class AudioAdapter(nn.Module):
heads=num_attention_heads,
kv_dim=mlp_dims[-1] // num_tokens,
adaLN=time_freq_dim > 0,
+ quantized=quantized,
+ quant_scheme=quant_scheme,
)
for _ in range(ca_num)
]
@@ -298,181 +307,9 @@ class AudioAdapter(nn.Module):
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
- def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
- def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
+ @torch.no_grad()
+ def forward_audio_proj(self, audio_feat, latent_frame):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2: # 扩展batchsize dim
hidden_states = hidden_states.unsqueeze(0) # bs = 1
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
else:
sp_size = 1
sp_rank = 0
tail_length = n_tokens_per_rank * sp_size - n_tokens
n_unused_ranks = tail_length // n_tokens_per_rank
if sp_rank > sp_size - n_unused_ranks - 1:
n_query_tokens = 0
elif sp_rank == sp_size - n_unused_ranks - 1:
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
else:
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
"""
x = x[:, t0:t1]
x = x.to(dtype)
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention module
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:
hidden_states = hidden_states.squeeze(0) # bs = 1
return hidden_states
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe
- if self.time_embedding is not None:
+ return x
t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
else:
t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
ret_dict = {}
for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
block_dict = {
"kwargs": {
"ca_block": self.ca[block_idx],
"x": x,
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
"seq_p_group": seq_p_group,
},
"modify_func": modify_hidden_states,
}
ret_dict[base_idx] = block_dict
return ret_dict
@classmethod
def from_transformer(
cls,
transformer,
audio_feature_dim: int = 1024,
interval: int = 1,
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
num_attention_heads = transformer.config["num_heads"]
base_num_layers = transformer.config["num_layers"]
attention_head_dim = transformer.config["dim"] // num_attention_heads
audio_adapter = AudioAdapter(
attention_head_dim,
num_attention_heads,
base_num_layers,
interval=interval,
audio_feature_dim=audio_feature_dim,
time_freq_dim=time_freq_dim,
projection_transformer_layers=projection_transformer_layers,
mlp_dims=(1024, 1024, 32 * audio_feature_dim),
)
return audio_adapter
def get_fsdp_wrap_module_list(
self,
):
ret_list = list(self.ca)
return ret_list
def enable_gradient_checkpointing(
self,
):
pass
class AudioAdapterPipe:
def __init__(
self,
audio_adapter: AudioAdapter,
audio_encoder_repo: str = "microsoft/wavlm-base-plus",
dtype=torch.float32,
device="cuda",
tgt_fps: int = 15,
weight: float = 1.0,
cpu_offload: bool = False,
seq_p_group=None,
) -> None:
self.seq_p_group = seq_p_group
self.audio_adapter = audio_adapter
self.dtype = dtype
self.audio_encoder_dtype = torch.float16
self.cpu_offload = cpu_offload
# audio encoder
self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
self.audio_encoder.eval()
self.audio_encoder.to(device, self.audio_encoder_dtype)
self.tgt_fps = tgt_fps
self.weight = weight
if "base" in audio_encoder_repo:
self.audio_feature_dim = 768
else:
self.audio_feature_dim = 1024
def update_model(self, audio_adapter):
self.audio_adapter = audio_adapter
def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
# audio_input_feat is from AudioPreprocessor
latent_frame = latent_shape[2]
if len(audio_input_feat.shape) == 1:  # add a batch dim (batch size = 1)
audio_input_feat = audio_input_feat.unsqueeze(0)
latent_frame = latent_shape[1]
video_frame = (latent_frame - 1) * 4 + 1
audio_length = int(50 / self.tgt_fps * video_frame)
with torch.no_grad():
try:
if self.cpu_offload:
self.audio_encoder = self.audio_encoder.to("cuda")
audio_feat = self.audio_encoder(audio_input_feat.to(self.audio_encoder_dtype), return_dict=True).last_hidden_state
if self.cpu_offload:
self.audio_encoder = self.audio_encoder.to("cpu")
except Exception as err:
audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to("cuda")
print(err)
audio_feat = audio_feat.to(self.dtype)
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)
import torch
from transformers import AutoFeatureExtractor, AutoModel
from lightx2v.utils.envs import *
class SekoAudioEncoderModel:
def __init__(self, model_path, audio_sr):
self.model_path = model_path
self.audio_sr = audio_sr
self.load()
def load(self):
self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path)
self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
self.audio_feature_encoder.eval()
self.audio_feature_encoder.to(GET_DTYPE())
def to_cpu(self):
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
def to_cuda(self):
self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
@torch.no_grad()
def infer(self, audio_segment):
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.audio_feature_encoder.device).to(dtype=GET_DTYPE())
audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
return audio_feat
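A short usage sketch for the new encoder wrapper (the model path and the one second of silence are illustrative):

import numpy as np

encoder = SekoAudioEncoderModel(model_path="/path/to/audio_encoder", audio_sr=16000)
encoder.to_cuda()
segment = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
audio_feat = encoder.infer(segment)          # last_hidden_state of shape (1, T, C)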
@@ -26,6 +26,11 @@ class WanAudioModel(WanModel):
self.post_infer_class = WanAudioPostInfer
self.transformer_infer_class = WanAudioTransformerInfer
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
self.pre_infer.set_audio_adapter(self.audio_adapter)
self.transformer_infer.set_audio_adapter(self.audio_adapter)
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, unified_dtype, sensitive_layer):
...
import math
import torch
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
@@ -8,32 +6,14 @@ from lightx2v.utils.envs import *
class WanAudioPostInfer(WanPostInfer):
def __init__(self, config):
- self.out_dim = config["out_dim"]
- self.patch_size = (1, 2, 2)
- self.clean_cuda_cache = config.get("clean_cuda_cache", False)
- self.infer_dtype = GET_DTYPE()
- self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
+ super().__init__(config)
- def set_scheduler(self, scheduler):
- self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, x, pre_infer_out):
- x = x[:, : pre_infer_out.valid_patch_length]
+ x = x[: pre_infer_out.seq_lens[0]]
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
torch.cuda.empty_cache()
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
x = x.unsqueeze(0)
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
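For reference, a standalone sketch of the shape flow inside this unpatchify (the grid sizes are illustrative):

import math

import torch

f, h, w = 21, 30, 52  # grid_sizes for one sample
p = (1, 2, 2)         # patch_size
c = 16                # out_dim
u = torch.randn(f * h * w, math.prod(p) * c)
u = u.view(f, h, w, *p, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, f * p[0], h * p[1], w * p[2])  # -> (16, 21, 60, 104)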
@@ -35,6 +35,9 @@ class WanAudioPreInfer(WanPreInfer):
else:
self.sp_size = 1
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
def infer(self, weights, inputs):
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
if self.config.model_cls == "wan2.2_audio":
@@ -48,7 +51,7 @@ class WanAudioPreInfer(WanPreInfer):
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
hidden_states = hidden_states.squeeze(0)
- x = [hidden_states]
+ x = hidden_states
t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
if self.config.model_cls == "wan2.2_audio":
@@ -61,31 +64,23 @@ class WanAudioPreInfer(WanPreInfer):
temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
t = temp_ts.unsqueeze(0)
- audio_dit_blocks = []
+ t_emb = self.audio_adapter.time_embedding(t).unflatten(1, (3, -1))
audio_encoder_output = inputs["audio_encoder_output"]
audio_model_input = {
"audio_input_feat": audio_encoder_output.to(hidden_states.device),
"latent_shape": hidden_states.shape,
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
# audio_dit_blocks = None##Debug Drop Audio
if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
- seq_len = self.scheduler.seq_len
+ # seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
- batch_size = len(x)
- num_channels, _, height, width = x[0].shape
+ # batch_size = len(x)
+ num_channels, _, height, width = x.shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
- (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
+ (1, num_channels - ref_num_channels, ref_num_frames, height, width),
dtype=self.scheduler.latents.dtype,
device=self.scheduler.latents.device,
)
@@ -93,13 +88,10 @@ class WanAudioPreInfer(WanPreInfer):
y = list(torch.unbind(ref_image_encoder, dim=0))  # unbind the leading batch dim into a list
# embeddings
- x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
- x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
- x = [u.flatten(2).transpose(1, 2) for u in x]
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
- assert seq_lens.max() <= seq_len
- x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
- valid_patch_length = x[0].size(0)
+ x = weights.patch_embedding.apply(x.unsqueeze(0))
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+ x = x.flatten(2).transpose(1, 2).contiguous()
+ seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
# y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
@@ -169,12 +161,11 @@ class WanAudioPreInfer(WanPreInfer):
return WanPreInferModuleOutput(
embed=embed,
- grid_sizes=x_grid_sizes,
+ grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
- audio_dit_blocks=audio_dit_blocks,
- valid_patch_length=valid_patch_length,
+ adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "t_emb": t_emb},
)
import torch
import torch.distributed as dist
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist
@@ -5,7 +9,13 @@ from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist
class WanAudioTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.num_tokens = 32
self.num_tokens_x4 = self.num_tokens * 4
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
@torch.no_grad()
def compute_freqs(self, q, grid_sizes, freqs):
if self.config["seq_parallel"]:
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
@@ -13,13 +23,77 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
@torch.no_grad()
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
- # Apply audio_dit if available
- if pre_infer_out.audio_dit_blocks is not None and hasattr(self, "block_idx"):
- for ipa_out in pre_infer_out.audio_dit_blocks:
- if self.block_idx in ipa_out:
- cur_modify = ipa_out[self.block_idx]
- x = cur_modify["modify_func"](x, pre_infer_out.grid_sizes, **cur_modify["kwargs"])
+ x = self.modify_hidden_states(
+ hidden_states=x,
+ grid_sizes=pre_infer_out.grid_sizes,
+ ca_block=self.audio_adapter.ca[self.block_idx],
+ audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
+ t_emb=pre_infer_out.adapter_output["t_emb"],
+ weight=1.0,
+ seq_p_group=self.seq_p_group,
+ )
return x
@torch.no_grad()
def modify_hidden_states(self, hidden_states, grid_sizes, ca_block, audio_encoder_output, t_emb, weight, seq_p_group):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2: # 扩展batchsize dim
hidden_states = hidden_states.unsqueeze(0) # bs = 1
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
else:
sp_size = 1
sp_rank = 0
tail_length = n_tokens_per_rank * sp_size - n_tokens
n_unused_ranks = tail_length // n_tokens_per_rank
if sp_rank > sp_size - n_unused_ranks - 1:
n_query_tokens = 0
elif sp_rank == sp_size - n_unused_ranks - 1:
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
else:
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
"""
audio_encoder_output = audio_encoder_output[:, t0:t1]
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention module
residual = ca_block(audio_encoder_output, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:
hidden_states = hidden_states.squeeze(0) # bs = 1
return hidden_states
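The rank bookkeeping above is pure arithmetic; a standalone sketch (no process group needed, sizes illustrative):

def n_query_tokens_for_rank(n_tokens, n_tokens_per_rank, sp_size, sp_rank):
    # padding tokens are appended so the sequence divides evenly across ranks
    tail_length = n_tokens_per_rank * sp_size - n_tokens
    n_unused_ranks = tail_length // n_tokens_per_rank
    if sp_rank > sp_size - n_unused_ranks - 1:
        return 0  # padding-only rank: runs the fake cross-attn path
    if sp_rank == sp_size - n_unused_ranks - 1:
        return n_tokens_per_rank - tail_length % n_tokens_per_rank
    return n_tokens_per_rank

# 32760 video tokens padded onto 4 sequence-parallel ranks of 8192 tokens each
print([n_query_tokens_for_rank(32760, 8192, 4, r) for r in range(4)])  # [8192, 8192, 8192, 8184]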
from dataclasses import dataclass
- from typing import Any, List, Optional
+ from typing import Any, Dict
import torch
@dataclass
class WanPreInferModuleOutput:
# wan base model
embed: torch.Tensor
grid_sizes: torch.Tensor
x: torch.Tensor
@@ -13,7 +14,6 @@ class WanPreInferModuleOutput:
seq_lens: torch.Tensor
freqs: torch.Tensor
context: torch.Tensor
- audio_dit_blocks: List[Any] = None
- valid_patch_length: Optional[int] = None
- hints: List[Any] = None
- context_scale: float = 1.0
+ # wan adapter model
+ adapter_output: Dict[str, Any] = None
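A small self-contained sketch of how the consolidated adapter_output dict is used after this change (the stand-in dataclass and tensor shapes are illustrative):

from dataclasses import dataclass
from typing import Any, Dict

import torch

@dataclass
class _PreInferOut:  # stand-in for WanPreInferModuleOutput
    adapter_output: Dict[str, Any] = None

out = _PreInferOut(
    adapter_output={
        "audio_encoder_output": torch.zeros(1, 21, 128, 1536),  # projected audio tokens
        "t_emb": torch.zeros(1, 3, 5120),                       # time embedding, unflattened to (B, 3, dim)
    }
)
out.adapter_output["hints"] = [torch.zeros(1)]        # the VACE adapter stores its hints in the same dict
scale = out.adapter_output.get("context_scale", 1.0)  # falls back to 1.0 when absent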
@@ -9,7 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
def infer(self, weights, pre_infer_out):
- pre_infer_out.hints = self.infer_vace(weights, pre_infer_out)
+ pre_infer_out.adapter_output["hints"] = self.infer_vace(weights, pre_infer_out)
x = self.infer_main_blocks(weights, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
@@ -40,6 +40,6 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
hint_idx = self.vace_blocks_mapping[self.block_idx]
- x = x + pre_infer_out.hints[hint_idx] * pre_infer_out.context_scale
+ x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
return x
- from abc import ABC, abstractmethod
- from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+ from abc import ABC
from lightx2v.utils.utils import save_videos_grid
class TransformerModel(Protocol):
"""Protocol for transformer models"""
def set_scheduler(self, scheduler: Any) -> None: ...
def scheduler(self) -> Any: ...
class TextEncoderModel(Protocol):
"""Protocol for text encoder models"""
def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...
class ImageEncoderModel(Protocol):
"""Protocol for image encoder models"""
def encode(self, image: Any) -> Any: ...
class VAEModel(Protocol):
"""Protocol for VAE models"""
def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...
def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...
class BaseRunner(ABC):
"""Abstract base class for all Runners
Defines interface methods that all subclasses must implement
"""
- def __init__(self, config: Dict[str, Any]):
+ def __init__(self, config):
self.config = config
- @abstractmethod
- def load_transformer(self) -> TransformerModel:
+ def load_transformer(self):
"""Load transformer model
Returns:
@@ -48,8 +20,7 @@ class BaseRunner(ABC):
"""
pass
- @abstractmethod
- def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
+ def load_text_encoder(self):
"""Load text encoder
Returns:
@@ -57,8 +28,7 @@
"""
pass
- @abstractmethod
- def load_image_encoder(self) -> Optional[ImageEncoderModel]:
+ def load_image_encoder(self):
"""Load image encoder
Returns:
@@ -66,8 +36,7 @@
"""
pass
- @abstractmethod
- def load_vae(self) -> Tuple[VAEModel, VAEModel]:
+ def load_vae(self):
"""Load VAE encoder and decoder
Returns:
@@ -75,8 +44,7 @@
"""
pass
- @abstractmethod
- def run_image_encoder(self, img: Any) -> Any:
+ def run_image_encoder(self, img):
"""Run image encoder
Args:
@@ -87,8 +55,7 @@
"""
pass
- @abstractmethod
- def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
+ def run_vae_encoder(self, img):
"""Run VAE encoder
Args:
@@ -99,8 +66,7 @@
"""
pass
- @abstractmethod
- def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
+ def run_text_encoder(self, prompt, img):
"""Run text encoder
Args:
@@ -112,8 +78,7 @@
"""
pass
- @abstractmethod
- def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encoder_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
+ def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
"""Combine encoder outputs for i2v task
Args:
@@ -127,12 +92,11 @@
"""
pass
- @abstractmethod
- def init_scheduler(self) -> None:
+ def init_scheduler(self):
"""Initialize scheduler"""
pass
- def set_target_shape(self) -> Dict[str, Any]:
+ def set_target_shape(self):
"""Set target shape
Subclasses can override this method to provide specific implementation
@@ -142,7 +106,7 @@
"""
return {}
- def save_video_func(self, images: Any) -> None:
+ def save_video_func(self, images):
"""Save video implementation
Subclasses can override this method to customize save logic
@@ -152,7 +116,7 @@
"""
save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
- def load_vae_decoder(self) -> VAEModel:
+ def load_vae_decoder(self):
"""Load VAE decoder
Default implementation: get decoder from load_vae method
@@ -164,3 +128,21 @@
if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
_, self.vae_decoder = self.load_vae()
return self.vae_decoder
def get_video_segment_num(self):
self.video_segment_num = 1
def init_run(self):
pass
def init_run_segment(self, segment_idx):
self.segment_idx = segment_idx
def run_segment(self, total_steps=None):
pass
def end_run_segment(self):
pass
def end_run(self):
pass
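A hedged sketch of how a concrete runner can override the new per-segment hooks; the subclass and its chunking logic are illustrative, only the hook names come from this commit:

from lightx2v.models.runners.wan.wan_runner import WanRunner

class ChunkedRunner(WanRunner):
    def get_video_segment_num(self):
        # e.g. one segment per pre-cut audio chunk
        self.video_segment_num = len(self.inputs["audio_segments"])

    def init_run_segment(self, segment_idx):
        super().init_run_segment(segment_idx)  # records self.segment_idx
        # select the segment-specific conditioning here

    def end_run_segment(self):
        # stash self.gen_video for later concatenation
        pass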
@@ -3,6 +3,7 @@ import gc
import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
@@ -35,8 +36,6 @@ class DefaultRunner(BaseRunner):
self.load_model()
elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False)
- self.run_dit = self._run_dit_local
- self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "i2v": if self.config["task"] == "i2v":
self.run_input_encoder = self._run_input_encoder_local_i2v self.run_input_encoder = self._run_input_encoder_local_i2v
elif self.config["task"] == "flf2v": elif self.config["task"] == "flf2v":
...@@ -108,7 +107,7 @@ class DefaultRunner(BaseRunner): ...@@ -108,7 +107,7 @@ class DefaultRunner(BaseRunner):
def set_progress_callback(self, callback): def set_progress_callback(self, callback):
self.progress_callback = callback self.progress_callback = callback
def run(self, total_steps=None): def run_segment(self, total_steps=None):
if total_steps is None: if total_steps is None:
total_steps = self.model.scheduler.infer_steps total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps): for step_index in range(total_steps):
...@@ -130,8 +129,7 @@ class DefaultRunner(BaseRunner): ...@@ -130,8 +129,7 @@ class DefaultRunner(BaseRunner):
def run_step(self): def run_step(self):
self.inputs = self.run_input_encoder() self.inputs = self.run_input_encoder()
self.set_target_shape() self.run_main(total_steps=1)
self.run_dit(total_steps=1)
def end_run(self): def end_run(self):
self.model.scheduler.clear() self.model.scheduler.clear()
...@@ -147,10 +145,15 @@ class DefaultRunner(BaseRunner): ...@@ -147,10 +145,15 @@ class DefaultRunner(BaseRunner):
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
def read_image_input(self, img_path):
img = Image.open(img_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
return img
@ProfilingContext("Run Encoders") @ProfilingContext("Run Encoders")
def _run_input_encoder_local_i2v(self): def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB") img = self.read_image_input(self.config["image_path"])
clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(img) vae_encode_out = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img) text_encoder_output = self.run_text_encoder(prompt, img)
...@@ -172,8 +175,8 @@ class DefaultRunner(BaseRunner): ...@@ -172,8 +175,8 @@ class DefaultRunner(BaseRunner):
@ProfilingContext("Run Encoders") @ProfilingContext("Run Encoders")
def _run_input_encoder_local_flf2v(self): def _run_input_encoder_local_flf2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
first_frame = Image.open(self.config["image_path"]).convert("RGB") first_frame = self.read_image_input(self.config["image_path"])
last_frame = Image.open(self.config["last_frame_path"]).convert("RGB") last_frame = self.read_image_input(self.config["last_frame_path"])
clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(first_frame, last_frame) vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
text_encoder_output = self.run_text_encoder(prompt, first_frame) text_encoder_output = self.run_text_encoder(prompt, first_frame)
...@@ -201,20 +204,32 @@ class DefaultRunner(BaseRunner): ...@@ -201,20 +204,32 @@ class DefaultRunner(BaseRunner):
gc.collect() gc.collect()
return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output) return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)
@ProfilingContext("Run DiT") def init_run(self):
def _run_dit_local(self, total_steps=None): self.set_target_shape()
self.get_video_segment_num()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.init_scheduler() self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.model.scheduler.prepare(self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v": if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
self.inputs["image_encoder_output"]["vae_encoder_out"] = None self.inputs["image_encoder_output"]["vae_encoder_out"] = None
latents, generator = self.run(total_steps)
@ProfilingContext("Run DiT")
def run_main(self, total_steps=None):
self.init_run()
for segment_idx in range(self.video_segment_num):
# 1. default do nothing
self.init_run_segment(segment_idx)
# 2. main inference loop
latents, generator = self.run_segment(total_steps=total_steps)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents, generator)
# 4. default do nothing
self.end_run_segment()
self.end_run()
- return latents, generator
@ProfilingContext("Run VAE Decoder") @ProfilingContext("Run VAE Decoder")
def _run_vae_decoder_local(self, latents, generator): def run_vae_decoder(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder() self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents, generator=generator, config=self.config) images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
...@@ -240,15 +255,15 @@ class DefaultRunner(BaseRunner): ...@@ -240,15 +255,15 @@ class DefaultRunner(BaseRunner):
logger.info(f"Enhanced prompt: {enhanced_prompt}") logger.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt return enhanced_prompt
def process_images_after_vae_decoder(self, images, save_video=True): def process_images_after_vae_decoder(self, save_video=True):
images = vae_to_comfyui_image(images) self.gen_video = vae_to_comfyui_image(self.gen_video)
if "video_frame_interpolation" in self.config: if "video_frame_interpolation" in self.config:
assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
target_fps = self.config["video_frame_interpolation"]["target_fps"] target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}") logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
images = self.vfi_model.interpolate_frames( self.gen_video = self.vfi_model.interpolate_frames(
images, self.gen_video,
source_fps=self.config.get("fps", 16), source_fps=self.config.get("fps", 16),
target_fps=target_fps, target_fps=target_fps,
) )
...@@ -262,24 +277,21 @@ class DefaultRunner(BaseRunner): ...@@ -262,24 +277,21 @@ class DefaultRunner(BaseRunner):
if not dist.is_initialized() or dist.get_rank() == 0: if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"🎬 Start to save video 🎬") logger.info(f"🎬 Start to save video 🎬")
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅") logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
return {"video": self.gen_video}
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
- self.set_target_shape()
- latents, generator = self.run_dit()
- images = self.run_vae_decoder(latents, generator)
- self.process_images_after_vae_decoder(images, save_video=save_video)
- del latents, generator
+ self.run_main()
+ gen_video = self.process_images_after_vae_decoder(save_video=save_video)
torch.cuda.empty_cache()
gc.collect()
- # Return (images, audio) - audio is None for default runner
- return images, None
+ return gen_video
import gc
import os
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -9,6 +8,7 @@ import numpy as np
import torch
import torch.distributed as dist
import torchaudio as ta
import torchvision.transforms.functional as TF
from PIL import Image
from einops import rearrange
from loguru import logger
@@ -16,29 +16,19 @@ from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from transformers import AutoFeatureExtractor
- from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
+ from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter, rank0_load_state_dict_from_path
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import *
- from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+ from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image
@contextmanager
def memory_efficient_inference():
"""Context manager for memory-efficient inference"""
try:
yield
finally:
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
assert sp_size > 0 and (sp_size & (sp_size - 1)) == 0, "sp_size must be a power of 2"
@@ -244,17 +234,91 @@ class AudioProcessor:
return segments
class VideoGenerator:
"""Handles video generation for each segment"""
def __init__(self, model, vae_encoder, vae_decoder, config, progress_callback=None):
self.model = model
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder
self.config = config
self.frame_preprocessor = FramePreprocessor()
self.progress_callback = progress_callback
self.total_segments = 1

@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):  # type:ignore
def __init__(self, config):
super().__init__(config)
self._audio_processor = None
self._video_generator = None
self._audio_preprocess = None
self.frame_preprocessor = FramePreprocessor()
def init_scheduler(self):
"""Initialize consistency model scheduler"""
scheduler = ConsistencyModelScheduler(self.config)
self.model.set_scheduler(scheduler)
def read_audio_input(self):
"""Read audio input"""
audio_sr = self.config.get("audio_sr", 16000)
target_fps = self.config.get("target_fps", 16)
self._audio_processor = AudioProcessor(audio_sr, target_fps)
audio_array = self._audio_processor.load_audio(self.config["audio_path"])
video_duration = self.config.get("video_duration", 5)
audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, self.config.get("target_video_length", 81))
return audio_segments, expected_frames
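The frame budget computed above is simple arithmetic; a worked example with illustrative values:

audio_sr, target_fps, video_duration = 16000, 16, 5
audio_samples = 8 * audio_sr                                                 # 8 s of input audio
audio_len = int(audio_samples / audio_sr * target_fps)                       # 128 frames covered by audio
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)   # -> 80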
def read_image_input(self, img_path):
ref_img = Image.open(img_path).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
ref_img, h, w = adaptive_resize(ref_img)
patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
self.config.lat_h = patched_h * self.config.patch_size[1]
self.config.lat_w = patched_w * self.config.patch_size[2]
self.config.tgt_h = self.config.lat_h * self.config.vae_stride[1]
self.config.tgt_w = self.config.lat_w * self.config.vae_stride[2]
logger.info(f"[wan_audio] tgt_h: {self.config.tgt_h}, tgt_w: {self.config.tgt_w}, lat_h: {self.config.lat_h}, lat_w: {self.config.lat_w}")
ref_img = torch.nn.functional.interpolate(ref_img, size=(self.config.tgt_h, self.config.tgt_w), mode="bicubic")
return ref_img
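The target-size bookkeeping above reduces to a few integer operations; a worked example assuming a 720x1280 reference, vae_stride=(4, 8, 8), patch_size=(1, 2, 2) and sp_size=1 (these values are assumptions, not from this diff):

h, w = 720, 1280
vae_stride, patch_size = (4, 8, 8), (1, 2, 2)
patched_h = h // vae_stride[1] // patch_size[1]  # 45
patched_w = w // vae_stride[2] // patch_size[2]  # 80
# assume get_optimal_patched_size_with_sp(45, 80, 1) returns them unchanged for sp_size=1
lat_h, lat_w = patched_h * patch_size[1], patched_w * patch_size[2]  # 90, 160
tgt_h, tgt_w = lat_h * vae_stride[1], lat_w * vae_stride[2]          # 720, 1280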
def run_image_encoder(self, first_frame, last_frame=None):
clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
return clip_encoder_out
def run_vae_encoder(self, img):
img = rearrange(img, "1 C H W -> 1 C 1 H W")
vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))
if self.config.model_cls == "wan2.2_audio":
vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
else:
if isinstance(vae_encoder_out, list):
vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
return vae_encoder_out
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_r2v_audio(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = self.read_image_input(self.config["image_path"])
clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(img)
audio_segments, expected_frames = self.read_audio_input()
text_encoder_output = self.run_text_encoder(prompt, None)
torch.cuda.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": {
"clip_encoder_out": clip_encoder_out,
"vae_encoder_out": vae_encode_out,
},
"audio_segments": audio_segments,
"expected_frames": expected_frames,
}
def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
"""Prepare previous latents for conditioning"""
@@ -295,31 +359,6 @@
return {"prev_latents": prev_latents, "prev_mask": prev_mask}
def _wan22_masks_like(self, tensor, zero=False, generator=None, p=0.2, prev_length=1):
assert isinstance(tensor, list)
out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
if prev_length == 0:
return out1, out2
if zero:
if generator is not None:
for u, v in zip(out1, out2):
random_num = torch.rand(1, generator=generator, device=generator.device).item()
if random_num < p:
u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
else:
u[:, :prev_length] = u[:, :prev_length]
v[:, :prev_length] = v[:, :prev_length]
else:
for u, v in zip(out1, out2):
u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
return out1, out2
def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
"""Rearrange mask for WAN model"""
if mask.ndim == 3:
@@ -332,250 +371,99 @@
mask = mask.view(mask.shape[1] // 4, 4, h, w)
return mask.transpose(0, 1)
@torch.no_grad() def get_video_segment_num(self):
def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None): self.video_segment_num = len(self.inputs["audio_segments"])
"""Generate video segment"""
# Update inputs with audio features
inputs["audio_encoder_output"] = audio_features
# Reset scheduler for non-first segments
if segment_idx > 0:
self.model.scheduler.reset()
inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)
# Run inference loop def init_run(self):
if total_steps is None: super().init_run()
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")
with ProfilingContext4Debug("step_pre"): self.gen_video_list = []
self.model.scheduler.step_pre(step_index=step_index) self.cut_audio_list = []
self.prev_video = None
with ProfilingContext4Debug("🚀 infer_main"): def init_run_segment(self, segment_idx):
self.model.infer(inputs) self.segment_idx = segment_idx
with ProfilingContext4Debug("step_post"): self.segment = self.inputs["audio_segments"][segment_idx]
self.model.scheduler.step_post()
if self.config.model_cls == "wan2.2_audio":
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
self.model.scheduler.latents = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * self.model.scheduler.latents
if self.progress_callback: self.config.seed = self.config.seed + segment_idx
segment_progress = (segment_idx * total_steps + step_index + 1) / (self.total_segments * total_steps) torch.manual_seed(self.config.seed)
self.progress_callback(int(segment_progress * 100), 100) logger.info(f"Processing segment {segment_idx + 1}/{self.video_segment_num}, seed: {self.config.seed}")
# Decode latents
latents = self.model.scheduler.latents
generator = self.model.scheduler.generator
with ProfilingContext("Run VAE Decoder"):
gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
return gen_video
@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner): # type:ignore
def __init__(self, config):
super().__init__(config)
self._audio_adapter_pipe = None
self._audio_processor = None
self._video_generator = None
self._audio_preprocess = None
def initialize(self):
"""Initialize all models once for multiple runs"""
# Initialize audio processor audio_features = self.audio_encoder.infer(self.segment.audio_array).to(self.model.device)
audio_sr = self.config.get("audio_sr", 16000) audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1])
target_fps = self.config.get("target_fps", 16)
self._audio_processor = AudioProcessor(audio_sr, target_fps)
# Initialize scheduler self.inputs["audio_encoder_output"] = audio_features
self.init_scheduler()
def init_scheduler(self): # Reset scheduler for non-first segments
"""Initialize consistency model scheduler""" if segment_idx > 0:
scheduler = ConsistencyModelScheduler(self.config) self.model.scheduler.reset()
self.model.set_scheduler(scheduler)
def load_audio_adapter_lazy(self): self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)
"""Lazy load audio adapter when needed"""
if self._audio_adapter_pipe is not None:
return self._audio_adapter_pipe
# Audio adapter def end_run_segment(self):
audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors" self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
audio_adapter = AudioAdapter.from_transformer(
self.model,
audio_feature_dim=1024,
interval=1,
time_freq_dim=256,
projection_transformer_layers=4,
)
# Audio encoder # Extract relevant frames
cpu_offload = self.config.get("cpu_offload", False) start_frame = 0 if self.segment_idx == 0 else 5
if cpu_offload: start_audio_frame = 0 if self.segment_idx == 0 else int(6 * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
device = torch.device("cpu")
else:
device = torch.device("cuda")
audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
if self.model.transformer_infer.seq_p_group is not None: if self.segment.is_last and self.segment.useful_length:
seq_p_group = self.model.transformer_infer.seq_p_group end_frame = self.segment.end_frame - self.segment.start_frame
self.gen_video_list.append(self.gen_video[:, :, start_frame:end_frame].cpu())
self.cut_audio_list.append(self.segment.audio_array[start_audio_frame : self.segment.useful_length])
elif self.segment.useful_length and self.inputs["expected_frames"] < self.config.get("target_video_length", 81):
self.gen_video_list.append(self.gen_video[:, :, start_frame : self.inputs["expected_frames"]].cpu())
self.cut_audio_list.append(self.segment.audio_array[start_audio_frame : self.segment.useful_length])
else: else:
seq_p_group = None self.gen_video_list.append(self.gen_video[:, :, start_frame:].cpu())
self.cut_audio_list.append(self.segment.audio_array[start_audio_frame:])
audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
# Update prev_video for next iteration
self._audio_adapter_pipe = AudioAdapterPipe( self.prev_video = self.gen_video
audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
) # Clean up GPU memory after each segment
del self.gen_video
return self._audio_adapter_pipe torch.cuda.empty_cache()
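Each segment after the first drops its first 5 video frames and roughly the matching slice of audio before being appended, which lines up with the prev_frame_length=5 overlap used above. A standalone sketch of that bookkeeping, assuming the default 16 kHz audio and 16 fps video seen in this runner (illustrative values only):

# Hypothetical illustration of the overlap trimming in end_run_segment.
audio_sr = 16000
target_fps = 16

for segment_idx in range(3):
    start_frame = 0 if segment_idx == 0 else 5
    start_audio_frame = 0 if segment_idx == 0 else int(6 * audio_sr / target_fps)
    print(f"segment {segment_idx}: drop first {start_frame} frames, first {start_audio_frame} audio samples")
# segment 0: drop first 0 frames, first 0 audio samples
# segment 1: drop first 5 frames, first 6000 audio samples
# segment 2: drop first 5 frames, first 6000 audio samples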
    def process_images_after_vae_decoder(self, save_video=True):
        # Merge results
        gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()
        merge_audio = np.concatenate(self.cut_audio_list, axis=0).astype(np.float32)

        comfyui_images = vae_to_comfyui_image(gen_lvideo)

        # Apply frame interpolation if configured
        if "video_frame_interpolation" in self.config and self.vfi_model is not None:
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            comfyui_images = self.vfi_model.interpolate_frames(
                comfyui_images,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)
            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")
                self._save_video_with_audio(comfyui_images, merge_audio, fps)
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")

        # Convert audio to ComfyUI format
        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
        comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
        return {"video": comfyui_images, "audio": comfyui_audio}

    def prepare_inputs(self):
        """Prepare inputs for the model"""
        image_encoder_output = None

        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encoder_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
            image_encoder_output = {
                "clip_encoder_out": clip_encoder_out,
                "vae_encoder_out": vae_encoder_out,
            }

        with ProfilingContext("Run Text Encoder"):
            img = Image.open(self.config["image_path"]).convert("RGB")
            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)

        self.set_target_shape()

        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}

    def run_pipeline(self, save_video=True):
        """Optimized pipeline with modular components"""
        try:
self.initialize()
assert self._audio_processor is not None
assert self._audio_preprocess is not None
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
with memory_efficient_inference():
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.prepare_inputs()
# Re-initialize scheduler after image encoding sets correct dimensions
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
# Re-create video generator with updated model/scheduler
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
# Process audio
audio_array = self._audio_processor.load_audio(self.config["audio_path"])
video_duration = self.config.get("video_duration", 5)
target_fps = self.config.get("target_fps", 16)
max_num_frames = self.config.get("target_video_length", 81)
audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
self._video_generator.total_segments = len(audio_segments)
# Generate video segments
gen_video_list = []
cut_audio_list = []
prev_video = None
for idx, segment in enumerate(audio_segments):
self.config.seed = self.config.seed + idx
torch.manual_seed(self.config.seed)
logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")
# Process audio features
audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
# Generate video segment
with memory_efficient_inference():
gen_video = self._video_generator.generate_segment(
self.inputs.copy(), # Copy to avoid modifying original
audio_features,
prev_video=prev_video,
prev_frame_length=5,
segment_idx=idx,
)
# Extract relevant frames
start_frame = 0 if idx == 0 else 5
start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)
if segment.is_last and segment.useful_length:
end_frame = segment.end_frame - segment.start_frame
gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
elif segment.useful_length and expected_frames < max_num_frames:
gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
else:
gen_video_list.append(gen_video[:, :, start_frame:].cpu())
cut_audio_list.append(segment.audio_array[start_audio_frame:])
# Update prev_video for next iteration
prev_video = gen_video
# Clean up GPU memory after each segment
del gen_video
torch.cuda.empty_cache()
# Merge results
with memory_efficient_inference():
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
comfyui_images = vae_to_comfyui_image(gen_lvideo)
# Apply frame interpolation if configured
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
comfyui_images = self.vfi_model.interpolate_frames(
comfyui_images,
source_fps=target_fps,
target_fps=interpolation_target_fps,
)
target_fps = interpolation_target_fps
# Convert audio to ComfyUI format
audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
# Save video if requested
if (self.config.get("device_mesh") is not None and dist.get_rank() == 0) or self.config.get("device_mesh") is None:
if save_video and self.config.get("save_video_path", None):
self._save_video_with_audio(comfyui_images, merge_audio, target_fps)
# Final cleanup
self.end_run()
return comfyui_images, comfyui_audio
        finally:
            self._video_generator = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def init_modules(self):
        super().init_modules()
        self.run_input_encoder = self._run_input_encoder_local_r2v_audio

    def _save_video_with_audio(self, images, audio_array, fps):
        """Save video with audio"""
...@@ -620,63 +508,43 @@ class WanAudioRunner(WanRunner):  # type:ignore
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        # XXX: trick
        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
        return base_model
    def run_image_encoder(self, config, vae_model):
        """Run image encoder"""
        ref_img = Image.open(config.image_path)
        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
        ref_img = torch.from_numpy(ref_img).cuda()
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
        ref_img = ref_img[:, :3]

        adaptive = config.get("adaptive_resize", False)
        if adaptive:
            # Use adaptive_resize to modify aspect ratio
            ref_img, h, w = adaptive_resize(ref_img)
            patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
            patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
        else:
            h, w = ref_img.shape[2:]
            aspect_ratio = h / w
            max_area = config.target_height * config.target_width
            patched_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1])
            patched_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2])
        patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)

        config.lat_h = patched_h * self.config.patch_size[1]
        config.lat_w = patched_w * self.config.patch_size[2]
        config.tgt_h = config.lat_h * self.config.vae_stride[1]
        config.tgt_w = config.lat_w * self.config.vae_stride[2]

        logger.info(f"[wan_audio] adaptive_resize: {adaptive}, tgt_h: {config.tgt_h}, tgt_w: {config.tgt_w}, lat_h: {config.lat_h}, lat_w: {config.lat_w}")

        cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")

        # clip encoder
        clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None

        # vae encode
        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float))
        if self.config.model_cls == "wan2.2_audio":
            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
        else:
            if isinstance(vae_encoder_out, list):
                vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())

        return vae_encoder_out, clip_encoder_out

    def load_audio_encoder(self):
        model = SekoAudioEncoderModel(os.path.join(self.config["model_path"], "audio_encoder"), self.config["audio_sr"])
        return model

    def load_audio_adapter(self):
        audio_adapter = AudioAdapter(
            attention_head_dim=5120 // self.config["num_heads"],
            num_attention_heads=self.config["num_heads"],
            base_num_layers=self.config["num_layers"],
            interval=1,
            audio_feature_dim=1024,
            time_freq_dim=256,
            projection_transformer_layers=4,
            mlp_dims=(1024, 1024, 32 * 1024),
            quantized=self.config.get("adapter_quantized", False),
            quant_scheme=self.config.get("adapter_quant_scheme", None),
        )
        if self.config.get("adapter_quantized", False):
            if self.config.get("adapter_quant_scheme", None) == "fp8":
                model_name = "audio_adapter_fp8.safetensors"
            elif self.config.get("adapter_quant_scheme", None) == "int8":
                model_name = "audio_adapter_int8.safetensors"
            else:
                raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
        else:
            model_name = "audio_adapter.safetensors"
        rank0_load_state_dict_from_path(audio_adapter, os.path.join(self.config["model_path"], model_name), strict=False)
        return audio_adapter.to(dtype=GET_DTYPE())

    @ProfilingContext("Load models")
    def load_model(self):
        super().load_model()
        self.audio_encoder = self.load_audio_encoder()
        self.audio_adapter = self.load_audio_adapter()
        self.model.set_audio_adapter(self.audio_adapter)

    def set_target_shape(self):
        """Set target shape for generation"""
...@@ -701,62 +569,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
        ret["target_shape"] = self.config.target_shape
        return ret
def run_step(self):
"""Optimized pipeline with modular components"""
self.initialize()
assert self._audio_processor is not None
assert self._audio_preprocess is not None
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
with memory_efficient_inference():
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.prepare_inputs()
# Re-initialize scheduler after image encoding sets correct dimensions
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
# Re-create video generator with updated model/scheduler
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
# Process audio
audio_array = self._audio_processor.load_audio(self.config["audio_path"])
video_duration = self.config.get("video_duration", 5)
target_fps = self.config.get("target_fps", 16)
max_num_frames = self.config.get("target_video_length", 81)
audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
self._video_generator.total_segments = len(audio_segments)
# Generate video segments
prev_video = None
torch.manual_seed(self.config.seed)
# Process audio features
audio_features = self._audio_preprocess(audio_segments[0].audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
# Generate video segment
with memory_efficient_inference():
self._video_generator.generate_segment(
self.inputs.copy(), # Copy to avoid modifying original
audio_features,
prev_video=prev_video,
prev_frame_length=5,
segment_idx=0,
total_steps=1,
)
# Final cleanup
self.end_run()
@RUNNER_REGISTER("wan2.2_audio") @RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner(WanAudioRunner): class Wan22AudioRunner(WanAudioRunner):
......
...@@ -225,12 +225,10 @@ class WanRunner(DefaultRunner):
    def run_image_encoder(self, first_frame, last_frame=None):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
        if last_frame is None:
            clip_encoder_out = self.image_encoder.visual([first_frame[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
            clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
        else:
            last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
            clip_encoder_out = self.image_encoder.visual([first_frame[:, None, :, :].transpose(0, 1), last_frame[:, None, :, :].transpose(0, 1)]).squeeze(0).to(GET_DTYPE())
            clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
...@@ -238,9 +236,7 @@ class WanRunner(DefaultRunner):
        return clip_encoder_out

    def run_vae_encoder(self, first_frame, last_frame=None):
        first_frame_size = first_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
        h, w = first_frame.shape[1:]
        h, w = first_frame.shape[2:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
...@@ -260,8 +256,8 @@ class WanRunner(DefaultRunner):
            return vae_encode_out_list
        else:
            if last_frame is not None:
                last_frame_size = last_frame.size
                last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
                first_frame_size = first_frame.shape[2:]
                last_frame_size = last_frame.shape[2:]
                if first_frame_size != last_frame_size:
                    last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
                    last_frame_size = [
...@@ -298,16 +294,16 @@ class WanRunner(DefaultRunner):
            if last_frame is not None:
                vae_input = torch.concat(
                    [
                        torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 2, h, w),
                        torch.nn.functional.interpolate(last_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    ],
                    dim=1,
                ).cuda()
            else:
                vae_input = torch.concat(
                    [
                        torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
...
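The rounding above fixes the latent grid and, through vae_stride, the pixel size the frames are resized to. A small self-contained numeric check of that arithmetic (the config values here are chosen for illustration, not read from a shipped config):

# Hypothetical numbers to illustrate the lat_h / lat_w rounding above.
import numpy as np

target_height, target_width = 480, 832  # config.target_height / config.target_width
vae_stride = (4, 8, 8)                  # config.vae_stride (assumed for the example)
patch_size = (1, 2, 2)                  # config.patch_size (assumed for the example)
h, w = 720, 720                         # first_frame spatial size -> aspect_ratio = 1.0

aspect_ratio = h / w
max_area = target_height * target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])
print(lat_h, lat_w)                                   # 78 78
print(lat_h * vae_stride[1], lat_w * vae_stride[2])   # 624 624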
import gc
import math
import numpy as np
import torch
from loguru import logger

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *


def unsqueeze_to_ndim(in_tensor, tgt_n_dim):
    if in_tensor.ndim > tgt_n_dim:
        warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
        return in_tensor
    if in_tensor.ndim < tgt_n_dim:
        in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
    return in_tensor


class EulerSchedulerTimestepFix(BaseScheduler):
    def __init__(self, config, **kwargs):
        # super().__init__(**kwargs)
        self.init_noise_sigma = 1.0
        self.config = config
        self.latents = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.step_index = None

    def step_pre(self, step_index):
        self.step_index = step_index
...@@ -37,12 +19,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
    def prepare(self, image_encoder_output=None):
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
...@@ -53,29 +29,13 @@ class EulerSchedulerTimestepFix(BaseScheduler):
        self.timesteps = self.sigmas * self.num_train_timesteps

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x_t_next = sample + (sigma_next - sigma) * model_output
        self.latents = x_t_next

    def reset(self):
...@@ -83,13 +43,10 @@ class EulerSchedulerTimestepFix(BaseScheduler):
        gc.collect()
        torch.cuda.empty_cache()


class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next


class ConsistencyModelScheduler(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next

    def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
        if in_tensor.ndim > tgt_n_dim:
            logger.warning(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
            return in_tensor
        if in_tensor.ndim < tgt_n_dim:
            in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
        return in_tensor
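The two step_post variants above turn the same model prediction into different updates: the Euler step moves the sample along the predicted direction from sigma to sigma_next, while the consistency-style step first reconstructs x0 and then re-noises it to the sigma_next level. A minimal side-by-side sketch with toy tensors (not the scheduler API):

import torch

# Hypothetical values, not scheduler state.
torch.manual_seed(0)
sample = torch.randn(4)        # x_t
model_output = torch.randn(4)  # predicted velocity / flow direction
sigma, sigma_next = 0.8, 0.6
generator = torch.Generator().manual_seed(42)

# Euler step: follow the predicted direction from sigma to sigma_next.
x_next_euler = sample + (sigma_next - sigma) * model_output

# Consistency-style step: estimate x0, then re-noise it to the sigma_next level.
x0 = sample - model_output * sigma
noise = torch.randn(x0.shape, generator=generator)
x_next_consistency = x0 * (1 - sigma_next) + sigma_next * noise

print(x_next_euler, x_next_consistency, sep="\n")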
...@@ -20,6 +20,7 @@ class WanScheduler4ChangingResolution:
        assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
        self.latents_list = []
        for i in range(len(self.config["resolution_rate"])):
            self.latents_list.append(
...
...@@ -26,8 +26,6 @@ class WanScheduler(BaseScheduler):
    def prepare(self, image_encoder_output=None):
        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
            self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
...@@ -51,6 +49,7 @@ class WanScheduler(BaseScheduler):
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
        self.latents = torch.randn(
            target_shape[0],
            target_shape[1],
...
import safetensors
import torch
from safetensors.torch import save_file
from lightx2v.utils.quant_utils import FloatQuantizer
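# Offline conversion script: the loop below reads the audio_adapter checkpoint,
# quantizes the cross-attention projection weights (keys starting with "ca" that
# contain ".to" and "weight" but not "to_kv") to fp8 e4m3 with per-channel scales,
# copies every other tensor unchanged, and writes audio_adapter_fp8.safetensors.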
model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter.safetensors"
state_dict = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
new_state_dict = {}
new_model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_fp8.safetensors"
for key in state_dict.keys():
if key.startswith("ca") and ".to" in key and "weight" in key and "to_kv" not in key:
print(key, state_dict[key].dtype)
weight = state_dict[key].to(torch.float32).cuda()
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
weight = weight.to(torch.float8_e4m3fn)
weight_scale = weight_scale.to(torch.float32)
new_state_dict[key] = weight.cpu()
new_state_dict[key + "_scale"] = weight_scale.cpu()
for key in state_dict.keys():
if key not in new_state_dict.keys():
new_state_dict[key] = state_dict[key]
save_file(new_state_dict, new_model_path)
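A quick way to sanity-check the converted checkpoint is to dequantize one fp8 weight with its stored scale and compare it against the original tensor. A minimal sketch, assuming the two safetensors paths used above and per-output-channel scales stored under the "_scale" suffix:

import safetensors
import torch

# Hypothetical verification of the fp8 conversion; paths mirror the script above.
orig_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter.safetensors"
fp8_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_fp8.safetensors"

with safetensors.safe_open(orig_path, framework="pt", device="cpu") as f_orig, safetensors.safe_open(fp8_path, framework="pt", device="cpu") as f_fp8:
    # Pick any key that has a companion "_scale" tensor, i.e. a quantized weight.
    key = next(k for k in f_fp8.keys() if k + "_scale" in f_fp8.keys())
    original = f_orig.get_tensor(key).to(torch.float32)
    scale = f_fp8.get_tensor(key + "_scale").to(torch.float32).reshape(-1, 1)  # assumes per-output-channel scales
    dequant = f_fp8.get_tensor(key).to(torch.float32) * scale
    rel_err = (original - dequant).abs().max() / original.abs().max()
    print(f"{key}: max relative error {rel_err:.4f}")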