Commit dd958c79 authored by wangshankun

Support: audio r2v dist infer

parent 820b4450
......@@ -12,6 +12,9 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
......@@ -261,8 +264,8 @@ class AudioAdapter(nn.Module):
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
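For orientation, the rearrange above regroups every S=4 audio frames into the token set of a single latent frame. A minimal, self-contained sketch with toy shapes (not the model's real dimensions):

from einops import rearrange
import torch

x = torch.randn(1, 40, 5, 8)  # B=1, T*S=40 with S=4, N=5 tokens, C=8 channels
y = rearrange(x, "B (T S) N C -> B T (S N) C", S=4)
print(y.shape)  # torch.Size([1, 10, 20, 8]): 10 latent frames, 4*5 tokens each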
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
......@@ -271,15 +274,27 @@ class AudioAdapter(nn.Module):
"""
if len(hidden_states.shape) == 2:  # expand the batch-size dim
hidden_states = hidden_states.unsqueeze(0) # bs = 1
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
tail_length = n_tokens_per_rank - n_tokens
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
else:
sp_size = 1
sp_rank = 0
tail_length = n_tokens_per_rank * sp_size - n_tokens
n_unused_ranks = tail_length // n_tokens_per_rank
if sp_rank > sp_size - n_unused_ranks - 1:
n_query_tokens = 0
elif sp_rank == sp_size - n_unused_ranks - 1:
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
else:
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
......@@ -289,7 +304,7 @@ class AudioAdapter(nn.Module):
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
Processing audio features in sp_state could be moved outside.
......@@ -300,6 +315,7 @@ class AudioAdapter(nn.Module):
assert q_lens.shape == k_lens.shape
# ca_block: cross-attention function
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # after cross-attention, the audio features are injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
......@@ -325,6 +341,7 @@ class AudioAdapter(nn.Module):
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
"seq_p_group": seq_p_group,
},
"modify_func": modify_hidden_states,
}
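To make the sequence-parallel split above concrete, here is a standalone sketch of the n_query_tokens computation with hypothetical sizes: real tokens fill the leading ranks, and trailing ranks may hold only padding.

n_tokens, n_tokens_per_rank, sp_size = 26, 10, 4  # hypothetical sizes
tail_length = n_tokens_per_rank * sp_size - n_tokens  # 14 padding tokens
n_unused_ranks = tail_length // n_tokens_per_rank     # last rank is all padding
for sp_rank in range(sp_size):
    if sp_rank > sp_size - n_unused_ranks - 1:
        n_query_tokens = 0  # this rank holds padding only
    elif sp_rank == sp_size - n_unused_ranks - 1:
        n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
    else:
        n_query_tokens = n_tokens_per_rank
    print(sp_rank, n_query_tokens)  # -> 0:10, 1:10, 2:6, 3:0 (sums to 26)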
......@@ -370,8 +387,17 @@ class AudioAdapter(nn.Module):
class AudioAdapterPipe:
def __init__(
self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", tgt_fps: int = 15, weight: float = 1.0, cpu_offload: bool = False
self,
audio_adapter: AudioAdapter,
audio_encoder_repo: str = "microsoft/wavlm-base-plus",
dtype=torch.float32,
device="cuda",
tgt_fps: int = 15,
weight: float = 1.0,
cpu_offload: bool = False,
seq_p_group=None,
) -> None:
self.seq_p_group = seq_p_group
self.audio_adapter = audio_adapter
self.dtype = dtype
self.audio_encoder_dtype = torch.float16
......@@ -415,4 +441,4 @@ class AudioAdapterPipe:
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)
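A hypothetical construction of the pipe under sequence parallelism; the adapter instance and the rank list below are placeholders, and dist.new_group is assumed to run after the usual init_process_group setup:

import torch
import torch.distributed as dist

seq_p_group = dist.new_group(ranks=[0, 1, 2, 3])  # sequence-parallel group
pipe = AudioAdapterPipe(
    audio_adapter,  # placeholder AudioAdapter instance
    audio_encoder_repo="microsoft/wavlm-base-plus",
    dtype=torch.bfloat16,
    tgt_fps=15,
    weight=1.0,
    seq_p_group=seq_p_group,  # forwarded to AudioAdapter.forward on every call
)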
......@@ -13,6 +13,8 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from loguru import logger
class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights
......@@ -65,6 +67,49 @@ class WanAudioModel(WanModel):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
@torch.no_grad()
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, unified_dtype, sensitive_layer):
......
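The guidance branch in infer_wo_cfg_parallel applies the standard classifier-free guidance combination; a minimal numeric sketch with toy tensors:

import torch

cond = torch.randn(1, 16, 8, 8)    # toy conditional noise prediction
uncond = torch.randn(1, 16, 8, 8)  # toy unconditional noise prediction
guide_scale = 5.0                  # stands in for scheduler.sample_guide_scale
noise_pred = uncond + guide_scale * (cond - uncond)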
......@@ -76,6 +76,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
......@@ -87,6 +88,8 @@ class WanTransformerDistInfer(WanTransformerInfer):
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
......
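The per-rank slice above takes a contiguous, equal-length chunk of the padded rotary table (pad_freqs is assumed to right-pad freqs_i up to s * world_size rows). A toy version of the index arithmetic:

seq_len, world_size = 35, 4
s_per_rank = (seq_len + world_size - 1) // world_size  # 9; table padded to 36 rows
for cur_rank in range(world_size):
    lo, hi = cur_rank * s_per_rank, (cur_rank + 1) * s_per_rank
    print(cur_rank, (lo, hi))  # the last rank's slice covers the padded tail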
......@@ -11,7 +11,6 @@ from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
......@@ -33,6 +32,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.seq_p_group = None
if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0):
assert self.config["self_attn_1_type"] != "sage_attn2"
......@@ -360,8 +360,6 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs_i = self.compute_freqs(q, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
......
......@@ -22,6 +22,7 @@ def compute_freqs(c, grid_sizes, freqs):
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1  # for r2v: add one frame for the ref embedding
seq_len = f * h * w
freqs_i = torch.cat(
......@@ -33,6 +34,8 @@ def compute_freqs_audio(c, grid_sizes, freqs):
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0  # for r2v: zero the temporal component corresponding to the ref embeddings
return freqs_i
......
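A small sketch of the r2v indexing in compute_freqs_audio with toy sizes; it mirrors the index arithmetic above (one extra frame appended for the reference embedding, whose temporal rotary band is zeroed), not the model's actual channel split:

import torch

f, h, w = 2, 2, 2
valid_token_length = f * h * w  # 8 tokens carry real video content
f = f + 1                       # for r2v: one appended ref frame
seq_len = f * h * w             # 12 tokens including the ref frame
freqs_i = torch.ones(seq_len, 1, 6)
freqs_i[valid_token_length:, :, :f] = 0   # ref tokens: temporal band zeroed
print(freqs_i[valid_token_length:, 0, :])  # first f entries are now zero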
......@@ -426,7 +426,15 @@ class WanAudioRunner(WanRunner): # type:ignore
else:
device = torch.device("cuda")
audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
if self.model.transformer_infer.seq_p_group is not None:
seq_p_group = self.model.transformer_infer.seq_p_group
else:
seq_p_group = None
self._audio_adapter_pipe = AudioAdapterPipe(
audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
)
return self._audio_adapter_pipe
......