"vscode:/vscode.git/clone" did not exist on "02c6f1ac215f637507e4d48fe5c322ec9df0f43a"
Commit 3488b187 authored by Yang Yong(雍洋), committed by GitHub

update audio pre_infer (#241)

parent d8454a2b
@@ -28,7 +28,6 @@ class WanAudioModel(WanModel):
     def set_audio_adapter(self, audio_adapter):
         self.audio_adapter = audio_adapter
-        self.pre_infer.set_audio_adapter(self.audio_adapter)
         self.transformer_infer.set_audio_adapter(self.audio_adapter)
......
-import math
 import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
@@ -35,36 +33,13 @@ class WanAudioPreInfer(WanPreInfer):
         else:
             self.sp_size = 1
 
-    def set_audio_adapter(self, audio_adapter):
-        self.audio_adapter = audio_adapter
-
     def infer(self, weights, inputs):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
-        if self.config.model_cls == "wan2.2_audio":
-            hidden_states = self.scheduler.latents
-            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-            hidden_states = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * hidden_states
-        else:
-            prev_latents = prev_latents.unsqueeze(0)
-            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
-            hidden_states = self.scheduler.latents.unsqueeze(0)
-            hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
-            hidden_states = hidden_states.squeeze(0)
+        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
+        hidden_states = torch.cat([self.scheduler.latents, prev_mask, prev_latents], dim=0)
         x = hidden_states
-        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
-        if self.config.model_cls == "wan2.2_audio":
-            _, lat_f, lat_h, lat_w = self.scheduler.latents.shape
-            F = (lat_f - 1) * self.config.vae_stride[0] + 1
-            max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
-            temp_ts = (prev_mask[0][0][:, ::2, ::2] * t).flatten()
-            temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
-            t = temp_ts.unsqueeze(0)
-        t_emb = self.audio_adapter.time_embedding(t).unflatten(1, (3, -1))
+        t = self.scheduler.timestep_input
 
         if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
@@ -76,16 +51,16 @@ class WanAudioPreInfer(WanPreInfer):
         ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
         # batch_size = len(x)
         num_channels, _, height, width = x.shape
-        _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
+        ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
-                (1, num_channels - ref_num_channels, ref_num_frames, height, width),
+                (num_channels - ref_num_channels, ref_num_frames, height, width),
                 dtype=self.scheduler.latents.dtype,
                 device=self.scheduler.latents.device,
             )
-            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
-        y = list(torch.unbind(ref_image_encoder, dim=0))  # the leading batch dim becomes a list
+            ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0)
+        y = ref_image_encoder  # the leading batch dim becomes a list
 
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
@@ -93,29 +68,10 @@ class WanAudioPreInfer(WanPreInfer):
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
-        y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
-        # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
-        y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
-        ref_seq_lens = torch.tensor([u.size(0) for u in y], dtype=torch.long)
-        x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
-        x = torch.stack(x, dim=0)
-        seq_len = x[0].size(0)
-        if self.config.model_cls == "wan2.2_audio":
-            bt = t.size(0)
-            ref_seq_len = ref_seq_lens[0].item()
-            t = torch.cat(
-                [
-                    t,
-                    torch.zeros(
-                        (1, ref_seq_len),
-                        dtype=t.dtype,
-                        device=t.device,
-                    ),
-                ],
-                dim=1,
-            )
+        y = weights.patch_embedding.apply(y.unsqueeze(0))
+        y = y.flatten(2).transpose(1, 2).contiguous()
+        x = torch.cat([x, y], dim=1)
 
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.sensitive_layer_dtype != self.infer_dtype:
@@ -167,5 +123,5 @@ class WanAudioPreInfer(WanPreInfer):
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "t_emb": t_emb},
+            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"]},
         )
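For orientation, here is a minimal standalone sketch of the reference-latent path the updated pre_infer follows: pad the reference latents up to the video latent channel count, patch-embed both tensors, and append the reference tokens to the video token sequence. The function name, the patch_embed argument, and the shape comments are assumptions for illustration only; the real code uses weights.patch_embedding and the scheduler's latents.

import torch

def build_token_sequence(x, ref, patch_embed):
    # x:   (C, F, H, W) video latents
    # ref: (C_ref, F_ref, H, W) reference-image latents
    num_channels, _, height, width = x.shape
    ref_num_channels, ref_num_frames, _, _ = ref.shape
    if ref_num_channels != num_channels:
        # zero-pad the missing channels instead of adding a fake batch dimension
        pad = torch.zeros(
            (num_channels - ref_num_channels, ref_num_frames, height, width),
            dtype=x.dtype, device=x.device,
        )
        ref = torch.cat([ref, pad], dim=0)
    x_tok = patch_embed(x.unsqueeze(0)).flatten(2).transpose(1, 2)    # (1, N_x, D)
    y_tok = patch_embed(ref.unsqueeze(0)).flatten(2).transpose(1, 2)  # (1, N_ref, D)
    return torch.cat([x_tok, y_tok], dim=1)                           # (1, N_x + N_ref, D)

For a quick check, any torch.nn.Conv3d whose in_channels matches the padded channel count can stand in for patch_embed.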
@@ -32,7 +32,7 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
             grid_sizes=pre_infer_out.grid_sizes,
             ca_block=self.audio_adapter.ca[self.block_idx],
             audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
-            t_emb=pre_infer_out.adapter_output["t_emb"],
+            t_emb=self.scheduler.audio_adapter_t_emb,
            weight=1.0,
             seq_p_group=self.seq_p_group,
         )
......
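As a rough sketch of the consumer side of this change (an illustrative stub, not the repo's API): a transformer block now reads the per-step time embedding from the scheduler rather than from pre_infer_out.adapter_output, so the embedding is computed once per denoising step and shared by every block.

import torch

class AudioBlockStub:
    """Illustrative stand-in for a block that applies audio cross-attention."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def apply_audio_adapter(self, hidden_states, audio_encoder_output):
        # one shared (1, 3, dim) embedding per step, cached on the scheduler
        t_emb = self.scheduler.audio_adapter_t_emb
        gate = torch.tanh(t_emb.mean())  # placeholder for the adapter's real modulation
        return hidden_states + gate * audio_encoder_output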
@@ -34,15 +34,7 @@ class WanPreInfer:
     def infer(self, weights, inputs, kv_start=0, kv_end=0):
         x = self.scheduler.latents
-        if self.scheduler.flag_df:
-            t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
-            assert t.dim() == 2  # timestep is 2-D for df-mode inference
-        else:
-            timestep = self.scheduler.timesteps[self.scheduler.step_index]
-            t = torch.stack([timestep])
-        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
-            t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()
+        t = self.scheduler.timestep_input
 
         if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
@@ -91,15 +83,6 @@ class WanPreInfer:
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
-        if self.scheduler.flag_df:
-            b, f = t.shape
-            assert b == len(x)  # batch_size == 1
-            embed = embed.view(b, f, 1, 1, self.dim)
-            embed0 = embed0.view(b, f, 1, 1, 6, self.dim)
-            embed = embed.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1).flatten(1, 3)
-            embed0 = embed0.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1, 1).flatten(1, 3)
-            embed0 = embed0.transpose(1, 2).contiguous()
 
         # text embeddings
         stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
         if self.sensitive_layer_dtype != self.infer_dtype:
......
@@ -246,6 +246,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def init_scheduler(self):
         """Initialize consistency model scheduler"""
         scheduler = ConsistencyModelScheduler(self.config)
+        scheduler.set_audio_adapter(self.audio_adapter)
         self.model.set_scheduler(scheduler)
 
     def read_audio_input(self):
@@ -292,12 +293,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def run_vae_encoder(self, img):
         img = rearrange(img, "1 C H W -> 1 C 1 H W")
-        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))
-        if self.config.model_cls == "wan2.2_audio":
-            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
-        else:
-            if isinstance(vae_encoder_out, list):
-                vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
+        vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))[0]
         return vae_encoder_out
 
     @ProfilingContext("Run Encoders")
@@ -351,7 +347,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         frames_n = (nframe - 1) * 4 + 1
         prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
         prev_mask[:, prev_frame_len:] = 0
-        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+        prev_mask = self._wan_mask_rearrange(prev_mask)
         if prev_latents.shape[-2:] != (height, width):
             logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
......
@@ -12,10 +12,12 @@ class ConsistencyModelScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
 
+    def set_audio_adapter(self, audio_adapter):
+        self.audio_adapter = audio_adapter
+
     def step_pre(self, step_index):
-        self.step_index = step_index
-        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
-            self.latents = self.latents.to(GET_DTYPE())
+        super().step_pre(step_index)
+        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))
 
     def prepare(self, image_encoder_output=None):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
......
...@@ -320,6 +320,12 @@ class WanScheduler(BaseScheduler): ...@@ -320,6 +320,12 @@ class WanScheduler(BaseScheduler):
x_t = x_t.to(x.dtype) x_t = x_t.to(x.dtype)
return x_t return x_t
def step_pre(self, step_index):
super().step_pre(step_index)
self.timestep_input = torch.stack([self.timesteps[self.step_index]])
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
def step_post(self): def step_post(self):
model_output = self.noise_pred.to(torch.float32) model_output = self.noise_pred.to(torch.float32)
timestep = self.timesteps[self.step_index] timestep = self.timesteps[self.step_index]
......
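Putting the scheduler pieces together, the following is a minimal runnable sketch of the per-step flow implied by these hunks: the base Wan scheduler caches timestep_input in step_pre, and the consistency scheduler additionally precomputes the audio adapter time embedding there. The stub classes, the toy adapter, and the printed shape are assumptions for illustration; the wan2.2 i2v mask branch and the real time-embedding MLP are omitted.

import torch

class BaseSchedulerStub:
    def step_pre(self, step_index):
        self.step_index = step_index

class WanSchedulerStub(BaseSchedulerStub):
    def __init__(self, timesteps):
        self.timesteps = timesteps

    def step_pre(self, step_index):
        super().step_pre(step_index)
        # 0-d timestep -> 1-element tensor, ready for embedding
        self.timestep_input = torch.stack([self.timesteps[self.step_index]])

class ConsistencySchedulerStub(WanSchedulerStub):
    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def step_pre(self, step_index):
        super().step_pre(step_index)
        # computed once per denoising step, reused by every transformer block
        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(
            self.timestep_input
        ).unflatten(1, (3, -1))

class ToyAdapter:
    """Toy stand-in for the audio adapter's time-embedding module."""

    def __init__(self, dim=8):
        self.dim = dim

    def time_embedding(self, t):
        # project a scalar timestep to a (1, 3 * dim) vector
        return t.float().unsqueeze(-1).repeat(1, 3 * self.dim)

sched = ConsistencySchedulerStub(timesteps=torch.linspace(999, 0, 4))
sched.set_audio_adapter(ToyAdapter())
sched.step_pre(0)
print(sched.audio_adapter_t_emb.shape)  # torch.Size([1, 3, 8])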