Commit 1b0743a5 authored by sandy, committed by GitHub

Support sekotalk multiperson (#321)


Co-authored-by: PengGao <peng.gaoc@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 6b7a3cad
{
    "talk_objects": [
        {
            "audio": "p1.mp3",
            "mask": "p1_mask.png"
        },
        {
            "audio": "p2.mp3",
            "mask": "p2_mask.png"
        }
    ]
}

{
    "talk_objects": [
        {
            "audio": "p1.mp3",
            "mask": "p1_mask.png"
        },
        {
            "audio": "p2.mp3",
            "mask": "p2_mask.png"
        }
    ]
}

{
    "talk_objects": [
        {
            "audio": "p1.mp3",
            "mask": "p1_mask.png"
        }
    ]
}
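
Each talk_objects entry pairs one speaker's audio track with that speaker's mask image: the two-entry configs above describe a two-person scene, the single-entry config a one-person scene. A minimal loading sketch (the helper below is hypothetical, not code from this commit):

import json
from pathlib import Path

def load_talk_objects(config_path, base_dir="."):
    # Hypothetical helper: read a talk_objects config and return (audio, mask)
    # pairs, one per speaker, resolved against base_dir.
    with open(config_path) as f:
        cfg = json.load(f)
    pairs = []
    for obj in cfg["talk_objects"]:
        pairs.append((Path(base_dir) / obj["audio"], Path(base_dir) / obj["mask"]))
    return pairs

# e.g. load_talk_objects("talk_objects.json", "assets/inputs/audio/multi_person")
# -> [(.../p1.mp3, .../p1_mask.png), (.../p2.mp3, .../p2_mask.png)]
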
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false
}
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
    },
    "adapter_quantized": true,
    "adapter_quant_scheme": "fp8",
    "t5_quantized": true,
    "t5_quant_scheme": "fp8"
}
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "parallel": {
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses"
    },
    "person_mask_path": "assets/inputs/audio/multi_person"
}
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "compile": true,
    "compile_shapes": [[480, 832], [720, 1280]],
    "compile_max_audios": 3,
    "person_mask_path": "assets/inputs/audio/multi_person"
}
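
With "compile": true, the model changes further below build one graph per (resolution, audio count) pair, so the settings above yield 2 x 3 = 6 cached graphs. A small sketch of the naming scheme, mirroring get_graph_name in the WanAudioModel diff (the standalone function here is only an illustration):

compile_shapes = [[480, 832], [720, 1280]]
compile_max_audios = 3

def get_graph_name(shape, audio_num):
    # Mirrors WanAudioModel.get_graph_name in the diff below.
    return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"

graph_names = [
    get_graph_name(shape, n)
    for n in range(1, compile_max_audios + 1)
    for shape in compile_shapes
]
# ['graph_480x832_1audio', 'graph_720x1280_1audio', ...,
#  'graph_480x832_3audio', 'graph_720x1280_3audio']
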
@@ -72,7 +72,7 @@ class VAReader:
     def start_ffmpeg_process_rtmp(self):
         """Start ffmpeg process read audio from stream"""
         ffmpeg_cmd = [
-            "/opt/conda/bin/ffmpeg",
+            "ffmpeg",
             "-i",
             self.stream_url,
             "-vn",
@@ -97,6 +97,7 @@ class VAReader:
     def start_ffmpeg_process_whep(self):
         """Start gstream process read audio from stream"""
         ffmpeg_cmd = [
+            "ffmpeg",
             "gst-launch-1.0",
             "-q",
             "whepsrc",
...
@@ -73,8 +73,8 @@ class VARecorder:
                     logger.info("Audio thread received stop signal")
                     break
                 # Convert audio data to 16-bit integer format
-                audios = np.clip(np.round(data * 32767), -32768, 32767).astype(np.int16)
-                self.audio_conn.send(audios.tobytes())
+                audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16)
+                self.audio_conn.send(audios[None].cpu().numpy().tobytes())
                 fail_time = 0
             except:  # noqa
                 logger.error(f"Send audio data error: {traceback.format_exc()}")
@@ -119,8 +119,7 @@ class VARecorder:
     def start_ffmpeg_process_local(self):
         """Start ffmpeg process that connects to our TCP sockets"""
         ffmpeg_cmd = [
-            "/opt/conda/bin/ffmpeg",
-            "-re",
+            "ffmpeg",
             "-f",
             "s16le",
             "-ar",
@@ -131,7 +130,6 @@ class VARecorder:
             f"tcp://127.0.0.1:{self.audio_port}",
             "-f",
             "rawvideo",
-            "-re",
             "-pix_fmt",
             "rgb24",
             "-r",
@@ -171,7 +169,7 @@ class VARecorder:
     def start_ffmpeg_process_rtmp(self):
         """Start ffmpeg process that connects to our TCP sockets"""
         ffmpeg_cmd = [
-            "/opt/conda/bin/ffmpeg",
+            "ffmpeg",
             "-re",
             "-f",
             "s16le",
@@ -223,7 +221,7 @@ class VARecorder:
     def start_ffmpeg_process_whip(self):
         """Start ffmpeg process that connects to our TCP sockets"""
         ffmpeg_cmd = [
-            "/opt/conda/bin/ffmpeg",
+            "ffmpeg",
             "-re",
             "-f",
             "s16le",
@@ -299,7 +297,7 @@ class VARecorder:
         self.video_thread.start()

     # Publish ComfyUI Image tensor and audio tensor to livestream
-    def pub_livestream(self, images: torch.Tensor, audios: np.ndarray):
+    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
         N, height, width, C = images.shape
         M = audios.reshape(-1).shape[0]
         assert C == 3, "Input must be [N, H, W, C] with C=3"
@@ -414,7 +412,7 @@ if __name__ == "__main__":
     audio_path = "/path/to/test_b_2min.wav"
     audio_array, ori_sr = ta.load(audio_path)
     audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
-    audio_array = audio_array.numpy().reshape(-1)
+    audio_array = audio_array.reshape(-1)
     secs = audio_array.shape[0] // sample_rate
     interval = 1
...
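
The audio-thread change above swaps the numpy int16 conversion for a torch-based one before the samples are sent to the ffmpeg s16le pipe. A toy check (not part of the commit) that the two paths produce identical bytes for the same waveform, since both numpy and torch round halves to even:

import numpy as np
import torch

data_np = np.random.uniform(-1, 1, 16000).astype(np.float32)  # 1 s of fake mono audio
data_t = torch.from_numpy(data_np)

old_bytes = np.clip(np.round(data_np * 32767), -32768, 32767).astype(np.int16).tobytes()
new_bytes = torch.clamp(torch.round(data_t * 32767), -32768, 32767).to(torch.int16)[None].cpu().numpy().tobytes()
assert old_bytes == new_bytes  # the extra leading dim does not change the raw s16le byte stream
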
@@ -64,8 +64,9 @@ def main():
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
     parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
-    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")
+    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task")
+    # [Warning] For vace task, need refactor.
     parser.add_argument(
         "--src_ref_images",
         type=str,
...
@@ -68,7 +68,7 @@ def get_qk_lens_audio_range(
     return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1


-def calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens):
+def calculate_n_query_tokens(hidden_states, person_mask_latens, sp_rank, sp_size, n_tokens_per_rank, n_tokens):
     tail_length = n_tokens_per_rank * sp_size - n_tokens
     n_unused_ranks = tail_length // n_tokens_per_rank
@@ -83,12 +83,20 @@ def calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank,
     if n_query_tokens > 0:
         hidden_states_aligned = hidden_states[:n_query_tokens]
         hidden_states_tail = hidden_states[n_query_tokens:]
+        if person_mask_latens is not None:
+            person_mask_aligned = person_mask_latens[:, :n_query_tokens]
+        else:
+            person_mask_aligned = None
     else:
         # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
         hidden_states_aligned = hidden_states[:1]
         hidden_states_tail = hidden_states[1:]
+        if person_mask_latens is not None:
+            person_mask_aligned = person_mask_latens[:, :1]
+        else:
+            person_mask_aligned = None

-    return n_query_tokens, hidden_states_aligned, hidden_states_tail
+    return n_query_tokens, hidden_states_aligned, hidden_states_tail, person_mask_aligned


 '''
...
@@ -51,12 +51,12 @@ class WanAudioModel(WanModel):
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer

-    def get_graph_name(self, shape):
-        return f"graph_{shape[0]}x{shape[1]}"
+    def get_graph_name(self, shape, audio_num):
+        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"

-    def start_compile(self, shape):
-        graph_name = self.get_graph_name(shape)
-        logger.info(f"[Compile] Compile shape: {shape}, graph_name: {graph_name}")
+    def start_compile(self, shape, audio_num):
+        graph_name = self.get_graph_name(shape, audio_num)
+        logger.info(f"[Compile] Compile shape: {shape}, audio_num:{audio_num}, graph_name: {graph_name}")
         target_video_length = self.config.get("target_video_length", 81)
         latents_length = (target_video_length - 1) // 16 * 4 + 1
@@ -72,7 +72,8 @@ class WanAudioModel(WanModel):
         new_inputs["image_encoder_output"]["clip_encoder_out"] = torch.randn(257, 1280, dtype=torch.bfloat16).cuda()
         new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
-        new_inputs["audio_encoder_output"] = torch.randn(1, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
+        new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
+        new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
         new_inputs["previmg_encoder_output"] = {}
         new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
@@ -95,8 +96,10 @@ class WanAudioModel(WanModel):
         self.pre_weight.to_cuda()
         self.transformer_weights.non_block_weights_to_cuda()

-        for shape in compile_shapes:
-            self.start_compile(shape)
+        max_audio_num_num = self.config.get("compile_max_audios", 1)
+        for audio_num in range(1, max_audio_num_num + 1):
+            for shape in compile_shapes:
+                self.start_compile(shape, audio_num)

         if self.cpu_offload:
             if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
@@ -113,6 +116,34 @@ class WanAudioModel(WanModel):
             assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]

     def select_graph_for_compile(self):
-        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}")
-        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}")
+        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
+        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
         logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

+    @torch.no_grad()
+    def _seq_parallel_pre_process(self, pre_infer_out):
+        x = pre_infer_out.x
+        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]
+        world_size = dist.get_world_size(self.seq_p_group)
+        cur_rank = dist.get_rank(self.seq_p_group)
+        padding_size = (world_size - (x.shape[0] % world_size)) % world_size
+        if padding_size > 0:
+            x = F.pad(x, (0, 0, 0, padding_size))
+            if person_mask_latens is not None:
+                person_mask_latens = F.pad(person_mask_latens, (0, padding_size))
+        pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
+        if person_mask_latens is not None:
+            pre_infer_out.adapter_output["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
+
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+            embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
+            padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
+            if padding_size > 0:
+                embed = F.pad(embed, (0, 0, 0, padding_size))
+                embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))
+            pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank]
+            pre_infer_out.embed0 = torch.chunk(embed0, world_size, dim=0)[cur_rank]
+        return pre_infer_out
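
For orientation, here is a toy run of the padding-and-chunk logic in _seq_parallel_pre_process above (the numbers are illustrative only): the token dimension is padded up to a multiple of the sequence-parallel world size, each rank then takes its contiguous slice, and the person-mask latents are split along their token dimension the same way.

import torch
import torch.nn.functional as F

world_size, cur_rank = 4, 1                # illustrative values, not from the commit
x = torch.randn(10, 8)                     # [tokens, hidden], 10 tokens
padding_size = (world_size - (x.shape[0] % world_size)) % world_size  # -> 2
x = F.pad(x, (0, 0, 0, padding_size))                                 # token dim padded to 12
local_x = torch.chunk(x, world_size, dim=0)[cur_rank]                 # 3 tokens on this rank

person_mask_latens = torch.zeros(2, 10, dtype=torch.int8)             # [num_persons, tokens]
person_mask_latens = F.pad(person_mask_latens, (0, padding_size))     # last dim padded to 12
local_mask = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]  # shape [2, 3]
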
@@ -74,6 +74,11 @@ class WanAudioPreInfer(WanPreInfer):
             self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
             grid_sizes_t += 1

+        person_mask_latens = inputs["person_mask_latens"]
+        if person_mask_latens is not None:
+            person_mask_latens = person_mask_latens.expand(-1, grid_sizes_t, -1, -1)
+            person_mask_latens = person_mask_latens.reshape(person_mask_latens.shape[0], -1)
+
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.sensitive_layer_dtype != self.infer_dtype:
             embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
@@ -123,5 +128,5 @@ class WanAudioPreInfer(WanPreInfer):
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"]},
+            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )
@@ -23,7 +23,7 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     def infer_post_adapter(self, phase, x, pre_infer_out):
         grid_sizes = pre_infer_out.grid_sizes.tensor
         audio_encoder_output = pre_infer_out.adapter_output["audio_encoder_output"]
+        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]
         total_tokens = grid_sizes[0].prod()
         pre_frame_tokens = grid_sizes[0][1:].prod()
         n_tokens = total_tokens - pre_frame_tokens  # number of tokens after removing the ref-image tokens
@@ -39,19 +39,30 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
             sp_size = 1
             sp_rank = 0
-        n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(x, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
+        n_query_tokens, hidden_states_aligned, hidden_states_tail, person_mask_aligned = calculate_n_query_tokens(x, person_mask_latens, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
         q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 = get_qk_lens_audio_range(
             n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=128
         )
-        audio_encoder_output = audio_encoder_output[:, t0:t1].reshape(-1, audio_encoder_output.size(-1))
-        residual = self.perceiver_attention_ca(phase, audio_encoder_output, hidden_states_aligned, self.scheduler.audio_adapter_t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k)
-        residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back as a residual
-        if n_query_tokens == 0:
-            residual = residual * 0.0
-        x = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=0)
+        total_residual = None
+        for i in range(audio_encoder_output.shape[0]):
+            audio_encoder = audio_encoder_output[i]
+            audio_encoder = audio_encoder[t0:t1].reshape(-1, audio_encoder.size(-1))
+            residual = self.perceiver_attention_ca(phase, audio_encoder, hidden_states_aligned, self.scheduler.audio_adapter_t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k)
+            residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back as a residual
+            if n_query_tokens == 0:
+                residual = residual * 0.0
+            if person_mask_aligned is not None:
+                residual = residual * person_mask_aligned[i].unsqueeze(-1)
+            if total_residual is None:
+                total_residual = residual
+            else:
+                total_residual += residual
+        x = torch.cat([hidden_states_aligned + total_residual, hidden_states_tail], dim=0)

         return x

     @torch.no_grad()
...
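
The per-speaker gating in the loop above relies on broadcasting: person_mask_aligned[i] holds one value per token, and unsqueeze(-1) turns it into a [tokens, 1] column that multiplies the [tokens, dim] residual row-wise, so each speaker's audio residual only touches that speaker's masked region before the residuals are summed. A toy illustration with made-up shapes (not from the commit):

import torch

tokens, dim, num_persons = 6, 4, 2
residuals = [torch.randn(tokens, dim) for _ in range(num_persons)]    # one residual per speaker
person_mask_aligned = torch.tensor([[1, 1, 1, 0, 0, 0],
                                    [0, 0, 0, 1, 1, 1]], dtype=torch.int8)  # [num_persons, tokens]

total_residual = torch.zeros(tokens, dim)
for i in range(num_persons):
    # [tokens] -> [tokens, 1] broadcasts over the hidden dimension,
    # zeroing speaker i's contribution outside speaker i's region.
    total_residual += residuals[i] * person_mask_aligned[i].unsqueeze(-1)
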