try: import flash_attn except ModuleNotFoundError: flash_attn = None import math import torch import torch.nn as nn import torch.nn.functional as F from diffusers.models.embeddings import TimestepEmbedding, Timesteps from einops import rearrange from transformers import AutoModel from loguru import logger import os import safetensors from typing import List, Optional, Tuple, Union def load_safetensors(in_path: str): if os.path.isdir(in_path): return load_safetensors_from_dir(in_path) elif os.path.isfile(in_path): return load_safetensors_from_path(in_path) else: raise ValueError(f"{in_path} does not exist") def load_safetensors_from_path(in_path: str): tensors = {} with safetensors.safe_open(in_path, framework="pt", device="cpu") as f: for key in f.keys(): tensors[key] = f.get_tensor(key) return tensors def load_safetensors_from_dir(in_dir: str): tensors = {} safetensors = os.listdir(in_dir) safetensors = [f for f in safetensors if f.endswith(".safetensors")] for f in safetensors: tensors.update(load_safetensors_from_path(os.path.join(in_dir, f))) return tensors def load_pt_safetensors(in_path: str): ext = os.path.splitext(in_path)[-1] if ext in (".pt", ".pth", ".tar"): state_dict = torch.load(in_path, map_location="cpu", weights_only=True) else: state_dict = load_safetensors(in_path) return state_dict def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True): import torch.distributed as dist if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()): state_dict = load_pt_safetensors(in_path) model.load_state_dict(state_dict, strict=strict) if dist.is_initialized(): dist.barrier() return model.to(dtype=torch.bfloat16, device="cuda") def linear_interpolation(features, output_len: int): features = features.transpose(1, 2) output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear") return output_features.transpose(1, 2) def get_q_lens_audio_range( batchsize, n_tokens_per_rank, n_query_tokens, n_tokens_per_frame, sp_rank, ): if n_query_tokens == 0: q_lens = [1] * batchsize return q_lens, 0, 1 idx0 = n_tokens_per_rank * sp_rank first_length = idx0 - idx0 // n_tokens_per_frame * n_tokens_per_frame n_frames = (n_query_tokens - first_length) // n_tokens_per_frame last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length q_lens = [] if first_length > 0: q_lens.append(first_length) q_lens += [n_tokens_per_frame] * n_frames if last_length > 0: q_lens.append(last_length) t0 = idx0 // n_tokens_per_frame idx1 = idx0 + n_query_tokens t1 = math.ceil(idx1 / n_tokens_per_frame) return q_lens * batchsize, t0, t1 class PerceiverAttentionCA(nn.Module): def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False): super().__init__() self.dim_head = dim_head self.heads = heads inner_dim = dim_head * heads kv_dim = inner_dim if kv_dim is None else kv_dim self.norm_kv = nn.LayerNorm(kv_dim) self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN) self.to_q = nn.Linear(inner_dim, inner_dim) self.to_kv = nn.Linear(kv_dim, inner_dim * 2) self.to_out = nn.Linear(inner_dim, inner_dim) if adaLN: self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5) else: shift_scale_gate = torch.zeros((1, 3, inner_dim)) shift_scale_gate[:, 2] = 1 self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False) def forward(self, x, latents, t_emb, q_lens, k_lens): """x shape (batchsize, latent_frame, audio_tokens_per_latent, model_dim) latents (batchsize, length, model_dim)""" batchsize = len(x) x = self.norm_kv(x) shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1) latents = self.norm_q(latents) * (1 + scale) + shift q = self.to_q(latents) k, v = self.to_kv(x).chunk(2, dim=-1) q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads) k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads) v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads) out = flash_attn.flash_attn_varlen_func( q=q, k=k, v=v, cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True), max_seqlen_q=q_lens.max(), max_seqlen_k=k_lens.max(), dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), deterministic=False, ) out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize) return self.to_out(out) * gate class AudioProjection(nn.Module): def __init__( self, audio_feature_dim: int = 768, n_neighbors: tuple = (2, 2), num_tokens: int = 32, mlp_dims: tuple = (1024, 1024, 32 * 768), transformer_layers: int = 4, ): super().__init__() mlp = [] self.left, self.right = n_neighbors self.audio_frames = sum(n_neighbors) + 1 in_dim = audio_feature_dim * self.audio_frames for i, out_dim in enumerate(mlp_dims): mlp.append(nn.Linear(in_dim, out_dim)) if i != len(mlp_dims) - 1: mlp.append(nn.ReLU()) in_dim = out_dim self.mlp = nn.Sequential(*mlp) self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens) self.num_tokens = num_tokens if transformer_layers > 0: decoder_layer = nn.TransformerDecoderLayer(d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True) self.transformer_decoder = nn.TransformerDecoder( decoder_layer, num_layers=transformer_layers, ) else: self.transformer_decoder = None def forward(self, audio_feature, latent_frame): video_frame = (latent_frame - 1) * 4 + 1 audio_feature_ori = audio_feature audio_feature = linear_interpolation(audio_feature_ori, video_frame) if self.transformer_decoder is not None: audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori) audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate") audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1) audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)") audio_feature = self.mlp(audio_feature) # (B, video_frame, C) audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens) # (B, video_frame, num_tokens, C) return self.norm(audio_feature) class TimeEmbedding(nn.Module): def __init__(self, dim, time_freq_dim, time_proj_dim): super().__init__() self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) def forward( self, timestep: torch.Tensor, ): timestep = self.timesteps_proj(timestep) time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype timestep = timestep.to(time_embedder_dtype) temb = self.time_embedder(timestep) timestep_proj = self.time_proj(self.act_fn(temb)) return timestep_proj class AudioAdapter(nn.Module): def __init__( self, attention_head_dim=64, num_attention_heads=40, base_num_layers=30, interval=1, audio_feature_dim: int = 768, num_tokens: int = 32, mlp_dims: tuple = (1024, 1024, 32 * 768), time_freq_dim: int = 256, projection_transformer_layers: int = 4, ): super().__init__() self.audio_proj = AudioProjection( audio_feature_dim=audio_feature_dim, n_neighbors=(2, 2), num_tokens=num_tokens, mlp_dims=mlp_dims, transformer_layers=projection_transformer_layers, ) # self.num_tokens = num_tokens * 4 self.num_tokens_x4 = num_tokens * 4 self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02) ca_num = math.ceil(base_num_layers / interval) self.base_num_layers = base_num_layers self.interval = interval self.ca = nn.ModuleList( [ PerceiverAttentionCA( dim_head=attention_head_dim, heads=num_attention_heads, kv_dim=mlp_dims[-1] // num_tokens, adaLN=time_freq_dim > 0, ) for _ in range(ca_num) ] ) self.dim = attention_head_dim * num_attention_heads if time_freq_dim > 0: self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3) else: self.time_embedding = None def rearange_audio_features(self, audio_feature: torch.Tensor): # audio_feature (B, video_frame, num_tokens, C) audio_feature_0 = audio_feature[:, :1] audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1) audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1) # (B, 4 * latent_frame, num_tokens, C) audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4) return audio_feature def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0): def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight): """thw specify the latent_frame, latent_height, latenf_width after hidden_states is patchified. latent_frame does not include the reference images so that the audios and hidden_states are strictly aligned """ if len(hidden_states.shape) == 2: # 扩展batchsize dim hidden_states = hidden_states.unsqueeze(0) # bs = 1 # print(weight) t, h, w = grid_sizes[0].tolist() n_tokens = t * h * w ori_dtype = hidden_states.dtype device = hidden_states.device bs, n_tokens_per_rank = hidden_states.shape[:2] tail_length = n_tokens_per_rank - n_tokens n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank if n_query_tokens > 0: hidden_states_aligned = hidden_states[:, :n_query_tokens] hidden_states_tail = hidden_states[:, n_query_tokens:] else: # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works. hidden_states_aligned = hidden_states[:, :1] hidden_states_tail = hidden_states[:, 1:] q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0) q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32) """ processing audio features in sp_state can be moved outside. """ x = x[:, t0:t1] x = x.to(dtype) k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32) assert q_lens.shape == k_lens.shape # ca_block:CrossAttention函数 residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight residual = residual.to(ori_dtype) # audio做了CrossAttention之后以Residual的方式注入 if n_query_tokens == 0: residual = residual * 0.0 hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1) if len(hidden_states.shape) == 3: # hidden_states = hidden_states.squeeze(0) # bs = 1 return hidden_states x = self.audio_proj(audio_feat, latent_frame) x = self.rearange_audio_features(x) x = x + self.audio_pe if self.time_embedding is not None: t_emb = self.time_embedding(timestep).unflatten(1, (3, -1)) else: t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype) ret_dict = {} for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)): block_dict = { "kwargs": { "ca_block": self.ca[block_idx], "x": x, "weight": weight, "t_emb": t_emb, "dtype": x.dtype, }, "modify_func": modify_hidden_states, } ret_dict[base_idx] = block_dict return ret_dict @classmethod def from_transformer( cls, transformer, audio_feature_dim: int = 1024, interval: int = 1, time_freq_dim: int = 256, projection_transformer_layers: int = 4, ): num_attention_heads = transformer.config["num_heads"] base_num_layers = transformer.config["num_layers"] attention_head_dim = transformer.config["dim"] // num_attention_heads audio_adapter = AudioAdapter( attention_head_dim, num_attention_heads, base_num_layers, interval=interval, audio_feature_dim=audio_feature_dim, time_freq_dim=time_freq_dim, projection_transformer_layers=projection_transformer_layers, mlp_dims=(1024, 1024, 32 * audio_feature_dim), ) return audio_adapter def get_fsdp_wrap_module_list( self, ): ret_list = list(self.ca) return ret_list def enable_gradient_checkpointing( self, ): pass class AudioAdapterPipe: def __init__( self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", generator=None, tgt_fps: int = 15, weight: float = 1.0 ) -> None: self.audio_adapter = audio_adapter self.dtype = dtype self.device = device self.generator = generator self.audio_encoder_dtype = torch.float16 ##音频编码器 self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo) self.audio_encoder.eval() self.audio_encoder.to(device, self.audio_encoder_dtype) self.tgt_fps = tgt_fps self.weight = weight if "base" in audio_encoder_repo: self.audio_feature_dim = 768 else: self.audio_feature_dim = 1024 def update_model(self, audio_adapter): self.audio_adapter = audio_adapter def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None): # audio_input_feat is from AudioPreprocessor latent_frame = latent_shape[2] if len(audio_input_feat.shape) == 1: # 扩展batchsize = 1 audio_input_feat = audio_input_feat.unsqueeze(0) latent_frame = latent_shape[1] video_frame = (latent_frame - 1) * 4 + 1 audio_length = int(50 / self.tgt_fps * video_frame) with torch.no_grad(): audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype) try: audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state except Exception as err: audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device) print(err) audio_feat = audio_feat.to(self.dtype) if dropout_cond is not None: audio_feat = dropout_cond(audio_feat) return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)