Commit d8454a2b authored by helloyongyang

Refactor runners

parent 2054eca3
......@@ -18,5 +18,7 @@
"dit_quantized_ckpt": "/path/to/Wan2.1-R2V721-Audio-14B-720P/fp8",
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
}
},
"adapter_quantized": true,
"adapter_quant_scheme": "fp8"
}
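The diff itself does not show where these two new keys are read. Below is a minimal sketch of the expected wiring, assuming the runner simply forwards them to AudioAdapter (which now accepts quantized / quant_scheme); the config accessor calls are assumptions, not code from this commit.
# Hypothetical wiring (not shown in this diff): forward the new adapter flags so that
# PerceiverAttentionCA builds SglQuantLinearFp8 layers for to_q and to_out.
audio_adapter = AudioAdapter(
    attention_head_dim,
    num_attention_heads,
    base_num_layers,
    quantized=config.get("adapter_quantized", False),
    quant_scheme=config.get("adapter_quant_scheme", None),
)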
......@@ -6,6 +6,11 @@ try:
except ModuleNotFoundError:
ops = None
try:
import sgl_kernel
except ImportError:
sgl_kernel = None
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
......@@ -117,6 +122,58 @@ class VllmQuantLinearFp8(nn.Module):
return self
class SglQuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
def act_quant_func(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
return input_tensor_quant, input_tensor_scale
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.fp8_scaled_mm(
input_tensor_quant,
self.weight.t(),
input_tensor_scale,
self.weight_scale,
dtype,
bias=self.bias,
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
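For reference, a minimal usage sketch of SglQuantLinearFp8 (assumes sgl_kernel is installed and a CUDA GPU with FP8 support; the dummy weight values only exist to exercise the forward path and would normally be loaded from an FP8 checkpoint):
layer = SglQuantLinearFp8(in_features=1024, out_features=1024).cuda()
# weight / weight_scale normally come from an FP8 checkpoint; fill them with dummies here.
layer.weight.copy_(torch.randn(1024, 1024, device="cuda").to(torch.float8_e4m3fn))
layer.weight_scale.fill_(1.0)
layer.bias.zero_()
x = torch.randn(1, 4096, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)  # per-token FP8 activation quant + fp8_scaled_mm -> (1, 4096, 1024) in bfloat16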
class TorchaoQuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
......
......@@ -13,9 +13,8 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from loguru import logger
from transformers import AutoModel
from lightx2v.utils.envs import *
from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8
def load_safetensors(in_path: str):
......@@ -84,8 +83,6 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
for buffer in model.buffers():
dist.broadcast(buffer.data, src=0)
return model.to(dtype=GET_DTYPE())
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
......@@ -120,7 +117,7 @@ def get_q_lens_audio_range(
class PerceiverAttentionCA(nn.Module):
def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False, quantized=False, quant_scheme=None):
super().__init__()
self.dim_head = dim_head
self.heads = heads
......@@ -129,9 +126,17 @@ class PerceiverAttentionCA(nn.Module):
self.norm_kv = nn.LayerNorm(kv_dim)
self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
self.to_q = nn.Linear(inner_dim, inner_dim)
self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
self.to_out = nn.Linear(inner_dim, inner_dim)
if quantized:
if quant_scheme == "fp8":
self.to_q = SglQuantLinearFp8(inner_dim, inner_dim)
self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
self.to_out = SglQuantLinearFp8(inner_dim, inner_dim)
else:
raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
else:
self.to_q = nn.Linear(inner_dim, inner_dim)
self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
self.to_out = nn.Linear(inner_dim, inner_dim)
if adaLN:
self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
else:
......@@ -151,7 +156,7 @@ class PerceiverAttentionCA(nn.Module):
shift = shift.transpose(0, 1)
gate = gate.transpose(0, 1)
latents = norm_q * (1 + scale) + shift
q = self.to_q(latents.to(GET_DTYPE()))
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
......@@ -258,6 +263,8 @@ class AudioAdapter(nn.Module):
mlp_dims: tuple = (1024, 1024, 32 * 768),
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
quantized: bool = False,
quant_scheme: str = None,
):
super().__init__()
self.audio_proj = AudioProjection(
......@@ -280,6 +287,8 @@ class AudioAdapter(nn.Module):
heads=num_attention_heads,
kv_dim=mlp_dims[-1] // num_tokens,
adaLN=time_freq_dim > 0,
quantized=quantized,
quant_scheme=quant_scheme,
)
for _ in range(ca_num)
]
......@@ -298,181 +307,9 @@ class AudioAdapter(nn.Module):
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2:  # add the batch dim
hidden_states = hidden_states.unsqueeze(0)  # bs = 1
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
else:
sp_size = 1
sp_rank = 0
tail_length = n_tokens_per_rank * sp_size - n_tokens
n_unused_ranks = tail_length // n_tokens_per_rank
if sp_rank > sp_size - n_unused_ranks - 1:
n_query_tokens = 0
elif sp_rank == sp_size - n_unused_ranks - 1:
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
else:
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks excluded from cross-attention, a fake cross-attention is applied so that FSDP still works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
"""
x = x[:, t0:t1]
x = x.to(dtype)
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention module
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:  # squeeze the batch dim back out
hidden_states = hidden_states.squeeze(0)  # bs = 1
return hidden_states
@torch.no_grad()
def forward_audio_proj(self, audio_feat, latent_frame):
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe
if self.time_embedding is not None:
t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
else:
t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
ret_dict = {}
for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
block_dict = {
"kwargs": {
"ca_block": self.ca[block_idx],
"x": x,
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
"seq_p_group": seq_p_group,
},
"modify_func": modify_hidden_states,
}
ret_dict[base_idx] = block_dict
return ret_dict
@classmethod
def from_transformer(
cls,
transformer,
audio_feature_dim: int = 1024,
interval: int = 1,
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
num_attention_heads = transformer.config["num_heads"]
base_num_layers = transformer.config["num_layers"]
attention_head_dim = transformer.config["dim"] // num_attention_heads
audio_adapter = AudioAdapter(
attention_head_dim,
num_attention_heads,
base_num_layers,
interval=interval,
audio_feature_dim=audio_feature_dim,
time_freq_dim=time_freq_dim,
projection_transformer_layers=projection_transformer_layers,
mlp_dims=(1024, 1024, 32 * audio_feature_dim),
)
return audio_adapter
def get_fsdp_wrap_module_list(
self,
):
ret_list = list(self.ca)
return ret_list
def enable_gradient_checkpointing(
self,
):
pass
class AudioAdapterPipe:
def __init__(
self,
audio_adapter: AudioAdapter,
audio_encoder_repo: str = "microsoft/wavlm-base-plus",
dtype=torch.float32,
device="cuda",
tgt_fps: int = 15,
weight: float = 1.0,
cpu_offload: bool = False,
seq_p_group=None,
) -> None:
self.seq_p_group = seq_p_group
self.audio_adapter = audio_adapter
self.dtype = dtype
self.audio_encoder_dtype = torch.float16
self.cpu_offload = cpu_offload
# audio encoder
self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
self.audio_encoder.eval()
self.audio_encoder.to(device, self.audio_encoder_dtype)
self.tgt_fps = tgt_fps
self.weight = weight
if "base" in audio_encoder_repo:
self.audio_feature_dim = 768
else:
self.audio_feature_dim = 1024
def update_model(self, audio_adapter):
self.audio_adapter = audio_adapter
def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
# audio_input_feat is from AudioPreprocessor
latent_frame = latent_shape[2]
if len(audio_input_feat.shape) == 1:  # add the batch dim (batch size = 1)
audio_input_feat = audio_input_feat.unsqueeze(0)
latent_frame = latent_shape[1]
video_frame = (latent_frame - 1) * 4 + 1
audio_length = int(50 / self.tgt_fps * video_frame)
with torch.no_grad():
try:
if self.cpu_offload:
self.audio_encoder = self.audio_encoder.to("cuda")
audio_feat = self.audio_encoder(audio_input_feat.to(self.audio_encoder_dtype), return_dict=True).last_hidden_state
if self.cpu_offload:
self.audio_encoder = self.audio_encoder.to("cpu")
except Exception as err:
audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to("cuda")
print(err)
audio_feat = audio_feat.to(self.dtype)
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)
return x
import torch
from transformers import AutoFeatureExtractor, AutoModel
from lightx2v.utils.envs import *
class SekoAudioEncoderModel:
def __init__(self, model_path, audio_sr):
self.model_path = model_path
self.audio_sr = audio_sr
self.load()
def load(self):
self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path)
self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
self.audio_feature_encoder.eval()
self.audio_feature_encoder.to(GET_DTYPE())
def to_cpu(self):
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
def to_cuda(self):
self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
@torch.no_grad()
def infer(self, audio_segment):
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.audio_feature_encoder.device).to(dtype=GET_DTYPE())
audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
return audio_feat
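A hedged usage sketch of SekoAudioEncoderModel; the model path is a placeholder and a 16 kHz sampling rate is an assumption:
import numpy as np

encoder = SekoAudioEncoderModel("/path/to/audio-encoder", audio_sr=16000)
encoder.to_cuda()
silence = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
features = encoder.infer(silence)  # last_hidden_state of shape (1, T, hidden_dim)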
......@@ -26,6 +26,11 @@ class WanAudioModel(WanModel):
self.post_infer_class = WanAudioPostInfer
self.transformer_infer_class = WanAudioTransformerInfer
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
self.pre_infer.set_audio_adapter(self.audio_adapter)
self.transformer_infer.set_audio_adapter(self.audio_adapter)
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, unified_dtype, sensitive_layer):
......
import math
import torch
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
......@@ -8,32 +6,14 @@ from lightx2v.utils.envs import *
class WanAudioPostInfer(WanPostInfer):
def __init__(self, config):
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
super().__init__(config)
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, x, pre_infer_out):
x = x[:, : pre_infer_out.valid_patch_length]
x = x[: pre_infer_out.seq_lens[0]]
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
torch.cuda.empty_cache()
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
x = x.unsqueeze(0)
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
......@@ -35,6 +35,9 @@ class WanAudioPreInfer(WanPreInfer):
else:
self.sp_size = 1
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
def infer(self, weights, inputs):
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
if self.config.model_cls == "wan2.2_audio":
......@@ -48,7 +51,7 @@ class WanAudioPreInfer(WanPreInfer):
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
hidden_states = hidden_states.squeeze(0)
x = [hidden_states]
x = hidden_states
t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
if self.config.model_cls == "wan2.2_audio":
......@@ -61,31 +64,23 @@ class WanAudioPreInfer(WanPreInfer):
temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
t = temp_ts.unsqueeze(0)
audio_dit_blocks = []
audio_encoder_output = inputs["audio_encoder_output"]
audio_model_input = {
"audio_input_feat": audio_encoder_output.to(hidden_states.device),
"latent_shape": hidden_states.shape,
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
# audio_dit_blocks = None##Debug Drop Audio
t_emb = self.audio_adapter.time_embedding(t).unflatten(1, (3, -1))
if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
# seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
batch_size = len(x)
num_channels, _, height, width = x[0].shape
# batch_size = len(x)
num_channels, _, height, width = x.shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
(batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
(1, num_channels - ref_num_channels, ref_num_frames, height, width),
dtype=self.scheduler.latents.dtype,
device=self.scheduler.latents.device,
)
......@@ -93,13 +88,10 @@ class WanAudioPreInfer(WanPreInfer):
y = list(torch.unbind(ref_image_encoder, dim=0))  # unbind the leading batch dim into a list
# embeddings
x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
valid_patch_length = x[0].size(0)
x = weights.patch_embedding.apply(x.unsqueeze(0))
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
x = x.flatten(2).transpose(1, 2).contiguous()
seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
# y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
......@@ -169,12 +161,11 @@ class WanAudioPreInfer(WanPreInfer):
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=x_grid_sizes,
grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
audio_dit_blocks=audio_dit_blocks,
valid_patch_length=valid_patch_length,
adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "t_emb": t_emb},
)
import torch
import torch.distributed as dist
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist
......@@ -5,7 +9,13 @@ from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, comput
class WanAudioTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.num_tokens = 32
self.num_tokens_x4 = self.num_tokens * 4
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
@torch.no_grad()
def compute_freqs(self, q, grid_sizes, freqs):
if self.config["seq_parallel"]:
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
......@@ -13,13 +23,77 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
@torch.no_grad()
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
# Apply audio_dit if available
if pre_infer_out.audio_dit_blocks is not None and hasattr(self, "block_idx"):
for ipa_out in pre_infer_out.audio_dit_blocks:
if self.block_idx in ipa_out:
cur_modify = ipa_out[self.block_idx]
x = cur_modify["modify_func"](x, pre_infer_out.grid_sizes, **cur_modify["kwargs"])
x = self.modify_hidden_states(
hidden_states=x,
grid_sizes=pre_infer_out.grid_sizes,
ca_block=self.audio_adapter.ca[self.block_idx],
audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
t_emb=pre_infer_out.adapter_output["t_emb"],
weight=1.0,
seq_p_group=self.seq_p_group,
)
return x
@torch.no_grad()
def modify_hidden_states(self, hidden_states, grid_sizes, ca_block, audio_encoder_output, t_emb, weight, seq_p_group):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2:  # add the batch dim
hidden_states = hidden_states.unsqueeze(0)  # bs = 1
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
else:
sp_size = 1
sp_rank = 0
tail_length = n_tokens_per_rank * sp_size - n_tokens
n_unused_ranks = tail_length // n_tokens_per_rank
if sp_rank > sp_size - n_unused_ranks - 1:
n_query_tokens = 0
elif sp_rank == sp_size - n_unused_ranks - 1:
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
else:
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks excluded from cross-attention, a fake cross-attention is applied so that FSDP still works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
"""
audio_encoder_output = audio_encoder_output[:, t0:t1]
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention module
residual = ca_block(audio_encoder_output, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:  # squeeze the batch dim back out
hidden_states = hidden_states.squeeze(0)  # bs = 1
return hidden_states
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Dict
import torch
@dataclass
class WanPreInferModuleOutput:
# wan base model
embed: torch.Tensor
grid_sizes: torch.Tensor
x: torch.Tensor
......@@ -13,7 +14,6 @@ class WanPreInferModuleOutput:
seq_lens: torch.Tensor
freqs: torch.Tensor
context: torch.Tensor
audio_dit_blocks: List[Any] = None
valid_patch_length: Optional[int] = None
hints: List[Any] = None
context_scale: float = 1.0
# wan adapter model
adapter_output: Dict[str, Any] = None
......@@ -9,7 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
def infer(self, weights, pre_infer_out):
pre_infer_out.hints = self.infer_vace(weights, pre_infer_out)
pre_infer_out.adapter_output["hints"] = self.infer_vace(weights, pre_infer_out)
x = self.infer_main_blocks(weights, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
......@@ -40,6 +40,6 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
hint_idx = self.vace_blocks_mapping[self.block_idx]
x = x + pre_infer_out.hints[hint_idx] * pre_infer_out.context_scale
x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
return x
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
from abc import ABC
from lightx2v.utils.utils import save_videos_grid
class TransformerModel(Protocol):
"""Protocol for transformer models"""
def set_scheduler(self, scheduler: Any) -> None: ...
def scheduler(self) -> Any: ...
class TextEncoderModel(Protocol):
"""Protocol for text encoder models"""
def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...
class ImageEncoderModel(Protocol):
"""Protocol for image encoder models"""
def encode(self, image: Any) -> Any: ...
class VAEModel(Protocol):
"""Protocol for VAE models"""
def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...
def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...
class BaseRunner(ABC):
"""Abstract base class for all Runners
Defines interface methods that all subclasses must implement
"""
def __init__(self, config: Dict[str, Any]):
def __init__(self, config):
self.config = config
@abstractmethod
def load_transformer(self) -> TransformerModel:
def load_transformer(self):
"""Load transformer model
Returns:
......@@ -48,8 +20,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
def load_text_encoder(self):
"""Load text encoder
Returns:
......@@ -57,8 +28,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def load_image_encoder(self) -> Optional[ImageEncoderModel]:
def load_image_encoder(self):
"""Load image encoder
Returns:
......@@ -66,8 +36,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def load_vae(self) -> Tuple[VAEModel, VAEModel]:
def load_vae(self):
"""Load VAE encoder and decoder
Returns:
......@@ -75,8 +44,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def run_image_encoder(self, img: Any) -> Any:
def run_image_encoder(self, img):
"""Run image encoder
Args:
......@@ -87,8 +55,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
def run_vae_encoder(self, img):
"""Run VAE encoder
Args:
......@@ -99,8 +66,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
def run_text_encoder(self, prompt, img):
"""Run text encoder
Args:
......@@ -112,8 +78,7 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encoder_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
"""Combine encoder outputs for i2v task
Args:
......@@ -127,12 +92,11 @@ class BaseRunner(ABC):
"""
pass
@abstractmethod
def init_scheduler(self) -> None:
def init_scheduler(self):
"""Initialize scheduler"""
pass
def set_target_shape(self) -> Dict[str, Any]:
def set_target_shape(self):
"""Set target shape
Subclasses can override this method to provide specific implementation
......@@ -142,7 +106,7 @@ class BaseRunner(ABC):
"""
return {}
def save_video_func(self, images: Any) -> None:
def save_video_func(self, images):
"""Save video implementation
Subclasses can override this method to customize save logic
......@@ -152,7 +116,7 @@ class BaseRunner(ABC):
"""
save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))
def load_vae_decoder(self) -> VAEModel:
def load_vae_decoder(self):
"""Load VAE decoder
Default implementation: get decoder from load_vae method
......@@ -164,3 +128,21 @@ class BaseRunner(ABC):
if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
_, self.vae_decoder = self.load_vae()
return self.vae_decoder
def get_video_segment_num(self):
self.video_segment_num = 1
def init_run(self):
pass
def init_run_segment(self, segment_idx):
self.segment_idx = segment_idx
def run_segment(self, total_steps=None):
pass
def end_run_segment(self):
pass
def end_run(self):
pass
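These hooks default to single-segment inference. A hedged sketch of how a subclass might override them for multi-segment generation (the class name and the num_segments config key are hypothetical):
class MultiSegmentRunner(DefaultRunner):
    def get_video_segment_num(self):
        # hypothetical: derive the segment count from the config instead of the default of 1
        self.video_segment_num = self.config.get("num_segments", 1)

    def init_run_segment(self, segment_idx):
        super().init_run_segment(segment_idx)
        # per-segment setup, e.g. conditioning on the previous segment's tail latents

    def end_run_segment(self):
        # per-segment teardown, e.g. appending the decoded frames to a running list
        pass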
......@@ -3,6 +3,7 @@ import gc
import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
......@@ -35,8 +36,6 @@ class DefaultRunner(BaseRunner):
self.load_model()
elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False)
self.run_dit = self._run_dit_local
self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "i2v":
self.run_input_encoder = self._run_input_encoder_local_i2v
elif self.config["task"] == "flf2v":
......@@ -108,7 +107,7 @@ class DefaultRunner(BaseRunner):
def set_progress_callback(self, callback):
self.progress_callback = callback
def run(self, total_steps=None):
def run_segment(self, total_steps=None):
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
......@@ -130,8 +129,7 @@ class DefaultRunner(BaseRunner):
def run_step(self):
self.inputs = self.run_input_encoder()
self.set_target_shape()
self.run_dit(total_steps=1)
self.run_main(total_steps=1)
def end_run(self):
self.model.scheduler.clear()
......@@ -147,10 +145,15 @@ class DefaultRunner(BaseRunner):
torch.cuda.empty_cache()
gc.collect()
def read_image_input(self, img_path):
img = Image.open(img_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
return img
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB")
img = self.read_image_input(self.config["image_path"])
clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img)
......@@ -172,8 +175,8 @@ class DefaultRunner(BaseRunner):
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_flf2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
first_frame = Image.open(self.config["image_path"]).convert("RGB")
last_frame = Image.open(self.config["last_frame_path"]).convert("RGB")
first_frame = self.read_image_input(self.config["image_path"])
last_frame = self.read_image_input(self.config["last_frame_path"])
clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
text_encoder_output = self.run_text_encoder(prompt, first_frame)
......@@ -201,20 +204,32 @@ class DefaultRunner(BaseRunner):
gc.collect()
return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)
@ProfilingContext("Run DiT")
def _run_dit_local(self, total_steps=None):
def init_run(self):
self.set_target_shape()
self.get_video_segment_num()
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
self.inputs["image_encoder_output"]["vae_encoder_out"] = None
latents, generator = self.run(total_steps)
@ProfilingContext("Run DiT")
def run_main(self, total_steps=None):
self.init_run()
for segment_idx in range(self.video_segment_num):
# 1. per-segment setup (no-op by default)
self.init_run_segment(segment_idx)
# 2. main inference loop
latents, generator = self.run_segment(total_steps=total_steps)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents, generator)
# 4. per-segment teardown (no-op by default)
self.end_run_segment()
self.end_run()
return latents, generator
@ProfilingContext("Run VAE Decoder")
def _run_vae_decoder_local(self, latents, generator):
def run_vae_decoder(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
......@@ -240,15 +255,15 @@ class DefaultRunner(BaseRunner):
logger.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def process_images_after_vae_decoder(self, images, save_video=True):
images = vae_to_comfyui_image(images)
def process_images_after_vae_decoder(self, save_video=True):
self.gen_video = vae_to_comfyui_image(self.gen_video)
if "video_frame_interpolation" in self.config:
assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
images = self.vfi_model.interpolate_frames(
images,
self.gen_video = self.vfi_model.interpolate_frames(
self.gen_video,
source_fps=self.config.get("fps", 16),
target_fps=target_fps,
)
......@@ -262,24 +277,21 @@ class DefaultRunner(BaseRunner):
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"🎬 Start to save video 🎬")
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")
save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
return {"video": self.gen_video}
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
self.set_target_shape()
latents, generator = self.run_dit()
self.run_main()
images = self.run_vae_decoder(latents, generator)
self.process_images_after_vae_decoder(images, save_video=save_video)
gen_video = self.process_images_after_vae_decoder(save_video=save_video)
del latents, generator
torch.cuda.empty_cache()
gc.collect()
# Return (images, audio) - audio is None for default runner
return images, None
return gen_video
......@@ -225,12 +225,10 @@ class WanRunner(DefaultRunner):
def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
if last_frame is None:
clip_encoder_out = self.image_encoder.visual([first_frame[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
else:
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([first_frame[:, None, :, :].transpose(0, 1), last_frame[:, None, :, :].transpose(0, 1)]).squeeze(0).to(GET_DTYPE())
clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
......@@ -238,9 +236,7 @@ class WanRunner(DefaultRunner):
return clip_encoder_out
def run_vae_encoder(self, first_frame, last_frame=None):
first_frame_size = first_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
h, w = first_frame.shape[1:]
h, w = first_frame.shape[2:]
aspect_ratio = h / w
max_area = self.config.target_height * self.config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
......@@ -260,8 +256,8 @@ class WanRunner(DefaultRunner):
return vae_encode_out_list
else:
if last_frame is not None:
last_frame_size = last_frame.size
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
first_frame_size = first_frame.shape[2:]
last_frame_size = last_frame.shape[2:]
if first_frame_size != last_frame_size:
last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
last_frame_size = [
......@@ -298,16 +294,16 @@ class WanRunner(DefaultRunner):
if last_frame is not None:
vae_input = torch.concat(
[
torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 2, h, w),
torch.nn.functional.interpolate(last_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
],
dim=1,
).cuda()
else:
vae_input = torch.concat(
[
torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 1, h, w),
],
dim=1,
......
import gc
import math
import numpy as np
import torch
from loguru import logger
from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
def unsqueeze_to_ndim(in_tensor, tgt_n_dim):
if in_tensor.ndim > tgt_n_dim:
warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
return in_tensor
if in_tensor.ndim < tgt_n_dim:
in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
return in_tensor
class EulerSchedulerTimestepFix(BaseScheduler):
def __init__(self, config, **kwargs):
# super().__init__(**kwargs)
self.init_noise_sigma = 1.0
self.config = config
self.latents = None
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.num_train_timesteps = 1000
self.step_index = None
class ConsistencyModelScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
def step_pre(self, step_index):
self.step_index = step_index
......@@ -37,12 +19,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
def prepare(self, image_encoder_output=None):
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if self.config.task in ["t2v"]:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
elif self.config.task in ["i2v"]:
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
......@@ -53,29 +29,13 @@ class EulerSchedulerTimestepFix(BaseScheduler):
self.timesteps = self.sigmas * self.num_train_timesteps
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
self.latents = (
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
* self.init_noise_sigma
)
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
x_t_next = sample + (sigma_next - sigma) * model_output
sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
x0 = sample - model_output * sigma
x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
self.latents = x_t_next
def reset(self):
......@@ -83,13 +43,10 @@ class EulerSchedulerTimestepFix(BaseScheduler):
gc.collect()
torch.cuda.empty_cache()
class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
x0 = sample - model_output * sigma
x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
self.latents = x_t_next
def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
if in_tensor.ndim > tgt_n_dim:
logger.warning(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
return in_tensor
if in_tensor.ndim < tgt_n_dim:
in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
return in_tensor
......@@ -20,6 +20,7 @@ class WanScheduler4ChangingResolution:
assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
self.latents_list = []
for i in range(len(self.config["resolution_rate"])):
self.latents_list.append(
......
......@@ -26,8 +26,6 @@ class WanScheduler(BaseScheduler):
def prepare(self, image_encoder_output=None):
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
......@@ -51,6 +49,7 @@ class WanScheduler(BaseScheduler):
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
self.latents = torch.randn(
target_shape[0],
target_shape[1],
......
import safetensors
import torch
from safetensors.torch import save_file
from lightx2v.utils.quant_utils import FloatQuantizer
model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter.safetensors"
state_dict = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
new_state_dict = {}
new_model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_fp8.safetensors"
for key in state_dict.keys():
if key.startswith("ca") and ".to" in key and "weight" in key and "to_kv" not in key:
print(key, state_dict[key].dtype)
weight = state_dict[key].to(torch.float32).cuda()
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
weight = weight.to(torch.float8_e4m3fn)
weight_scale = weight_scale.to(torch.float32)
new_state_dict[key] = weight.cpu()
new_state_dict[key + "_scale"] = weight_scale.cpu()
for key in state_dict.keys():
if key not in new_state_dict.keys():
new_state_dict[key] = state_dict[key]
save_file(new_state_dict, new_model_path)
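A quick, hedged sanity check on the converted file (reuses new_model_path from above): every quantized to_q / to_out weight should now be stored together with a per-channel _scale tensor.
with safetensors.safe_open(new_model_path, framework="pt", device="cpu") as f:
    keys = set(f.keys())
for key in sorted(keys):
    if key.endswith("_scale"):
        assert key[: -len("_scale")] in keys, f"missing FP8 weight for {key}"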