"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "17165f937252a5c00e35c9f93e63ab63bc21690a"
Commit dd958c79 authored by wangshankun's avatar wangshankun
Browse files

Support: audio r2v dist infer

parent 820b4450
...@@ -12,6 +12,9 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
...@@ -261,8 +264,8 @@ class AudioAdapter(nn.Module):
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
"""t, h, w specify the latent_frame, latent_height, latent_width after
hidden_states is patchified.
...@@ -271,15 +274,27 @@ class AudioAdapter(nn.Module):
"""
if len(hidden_states.shape) == 2:  # expand the batch-size dim
hidden_states = hidden_states.unsqueeze(0)  # bs = 1
# print(weight)
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
tail_length = n_tokens_per_rank - n_tokens
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
else:
sp_size = 1
sp_rank = 0
tail_length = n_tokens_per_rank * sp_size - n_tokens
n_unused_ranks = tail_length // n_tokens_per_rank
if sp_rank > sp_size - n_unused_ranks - 1:
n_query_tokens = 0
elif sp_rank == sp_size - n_unused_ranks - 1:
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
else:
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
...@@ -289,7 +304,7 @@ class AudioAdapter(nn.Module):
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
...@@ -300,6 +315,7 @@ class AudioAdapter(nn.Module):
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention function
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio features are injected as a residual after cross-attention
if n_query_tokens == 0:
residual = residual * 0.0
...@@ -325,6 +341,7 @@ class AudioAdapter(nn.Module):
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
"seq_p_group": seq_p_group,
},
"modify_func": modify_hidden_states,
}
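For readers following the sequence-parallel branch added above, here is a minimal, self-contained sketch (not part of the commit; the helper name is invented) of how each rank works out how many of its padded tokens are real query tokens:

# Illustrative helper: how modify_hidden_states splits the globally padded
# token sequence across sequence-parallel ranks.
def query_tokens_for_rank(n_tokens: int, n_tokens_per_rank: int, sp_size: int, sp_rank: int) -> int:
    # Total padding appended so that every rank holds n_tokens_per_rank tokens.
    tail_length = n_tokens_per_rank * sp_size - n_tokens
    # Trailing ranks that hold only padding.
    n_unused_ranks = tail_length // n_tokens_per_rank
    if sp_rank > sp_size - n_unused_ranks - 1:
        return 0  # rank holds padding only
    elif sp_rank == sp_size - n_unused_ranks - 1:
        return n_tokens_per_rank - tail_length % n_tokens_per_rank  # partially padded rank
    else:
        return n_tokens_per_rank  # fully valid rank

# e.g. 10 real tokens split over 4 ranks of 3 tokens each -> [3, 3, 3, 1]
assert [query_tokens_for_rank(10, 3, 4, r) for r in range(4)] == [3, 3, 3, 1]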
...@@ -370,8 +387,17 @@ class AudioAdapter(nn.Module):
class AudioAdapterPipe:
def __init__(
self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", tgt_fps: int = 15, weight: float = 1.0, cpu_offload: bool = False
self,
audio_adapter: AudioAdapter,
audio_encoder_repo: str = "microsoft/wavlm-base-plus",
dtype=torch.float32,
device="cuda",
tgt_fps: int = 15,
weight: float = 1.0,
cpu_offload: bool = False,
seq_p_group=None,
) -> None:
self.seq_p_group = seq_p_group
self.audio_adapter = audio_adapter
self.dtype = dtype
self.audio_encoder_dtype = torch.float16
...@@ -415,4 +441,4 @@ class AudioAdapterPipe:
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)
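A minimal usage sketch of the new seq_p_group argument (assumes torch.distributed has already been initialized by the launcher and that audio_adapter has been loaded elsewhere):

import torch
import torch.distributed as dist

# Sequence-parallel process group spanning all ranks (example choice).
seq_p_group = dist.new_group(ranks=list(range(dist.get_world_size())))

pipe = AudioAdapterPipe(
    audio_adapter,                      # an AudioAdapter instance, loaded elsewhere
    audio_encoder_repo="microsoft/wavlm-base-plus",
    dtype=torch.float32,
    device="cuda",
    weight=1.0,
    cpu_offload=False,
    seq_p_group=seq_p_group,            # None keeps the original single-rank behaviour
)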
...@@ -13,6 +13,8 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from loguru import logger
class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights
...@@ -65,6 +67,49 @@ class WanAudioModel(WanModel):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
@torch.no_grad()
def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
self.scheduler.noise_pred = noise_pred_cond
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
torch.cuda.empty_cache()
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
if self.clean_cuda_cache:
del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
torch.cuda.empty_cache()
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
class Wan22MoeAudioModel(WanAudioModel):
def _load_ckpt(self, unified_dtype, sensitive_layer):
......
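For reference, the enable_cfg branch in infer_wo_cfg_parallel above applies standard classifier-free guidance; a minimal sketch of that combination (tensor contents invented for the example):

import torch

def apply_cfg(noise_pred_cond: torch.Tensor, noise_pred_uncond: torch.Tensor, guide_scale: float) -> torch.Tensor:
    # Same formula as in the diff: move from the unconditional prediction
    # toward the conditional one, scaled by sample_guide_scale.
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

# guide_scale == 1.0 reduces to the conditional prediction.
cond = torch.randn(2, 4)
uncond = torch.randn(2, 4)
assert torch.allclose(apply_cfg(cond, uncond, 1.0), cond)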
...@@ -76,6 +76,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
...@@ -87,6 +88,8 @@ class WanTransformerDistInfer(WanTransformerInfer):
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
......
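For context on the slicing above, a minimal sketch of padding the rotary table to a multiple of the per-rank length and taking one rank's contiguous slice (pad_to_length is a hypothetical stand-in for the library's pad_freqs):

import torch

def pad_to_length(freqs_i: torch.Tensor, target_len: int) -> torch.Tensor:
    # Hypothetical stand-in for pad_freqs: append zero rows up to target_len.
    pad = torch.zeros(target_len - freqs_i.shape[0], *freqs_i.shape[1:], dtype=freqs_i.dtype)
    return torch.cat([freqs_i, pad], dim=0)

seq_len, s_per_rank, world_size = 10, 3, 4
freqs_i = torch.randn(seq_len, 1, 6)
freqs_i = pad_to_length(freqs_i, s_per_rank * world_size)   # 10 rows -> 12 rows
cur_rank = 2
freqs_i_rank = freqs_i[cur_rank * s_per_rank : (cur_rank + 1) * s_per_rank]  # rows 6..8
assert freqs_i_rank.shape[0] == s_per_rank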
...@@ -11,7 +11,6 @@ from lightx2v.utils.envs import *
from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_audio
class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
...@@ -33,6 +32,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.seq_p_group = None
if self.config.get("cpu_offload", False): if self.config.get("cpu_offload", False):
if torch.cuda.get_device_capability(0) == (9, 0): if torch.cuda.get_device_capability(0) == (9, 0):
assert self.config["self_attn_1_type"] != "sage_attn2" assert self.config["self_attn_1_type"] != "sage_attn2"
...@@ -360,8 +360,6 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -360,8 +360,6 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs_i = self.compute_freqs(q, grid_sizes, freqs) freqs_i = self.compute_freqs(q, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
......
...@@ -22,6 +22,7 @@ def compute_freqs(c, grid_sizes, freqs):
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1  # for r2v, add 1 channel
seq_len = f * h * w
freqs_i = torch.cat(
...@@ -33,6 +34,8 @@ def compute_freqs_audio(c, grid_sizes, freqs):
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0  # for r2v: zero the temporal component corresponding to the ref embeddings
return freqs_i
......
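A toy sketch of the idea behind the new line in compute_freqs_audio (the dimensions and the explicit temporal-channel count are invented for illustration): the reference-frame tokens appended for r2v keep their spatial rotary frequencies but have their temporal slice zeroed, so they carry no temporal position:

import torch

# Toy dimensions: 2 latent frames of 2x2 tokens, plus 1 appended reference frame (r2v).
f, h, w = 2, 2, 2
valid_token_length = f * h * w          # tokens that belong to real video frames
seq_len = (f + 1) * h * w               # +1 frame of reference tokens
c_temporal, c_spatial = 4, 8            # invented channel split: temporal + spatial rotary dims

freqs_i = torch.randn(seq_len, 1, c_temporal + c_spatial)
# Zero only the temporal slice for the trailing reference tokens.
freqs_i[valid_token_length:, :, :c_temporal] = 0
assert freqs_i[valid_token_length:, :, :c_temporal].abs().sum() == 0
assert freqs_i[:valid_token_length].abs().sum() > 0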
...@@ -426,7 +426,15 @@ class WanAudioRunner(WanRunner): # type:ignore
else:
device = torch.device("cuda")
audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
if self.model.transformer_infer.seq_p_group is not None:
seq_p_group = self.model.transformer_infer.seq_p_group
else:
seq_p_group = None
self._audio_adapter_pipe = AudioAdapterPipe(
audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload, seq_p_group=seq_p_group
)
return self._audio_adapter_pipe
......