Unverified Commit bc2828b0 authored by LiangLiu, committed by GitHub

Dev/omni (#577)

Tidy VAReader & OmniVAReader
Tidy VARecorder & X264VARecorder
Route VARecorder stream output through a stream buffer
Tidy the WORKER_RANK, READER_RANK, RECORDER_RANK environment variables
Support selecting the TTS voice type
parent f3b4ba24
import math
import os

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE


class NextControl:
    def __init__(self, action: str, data: any = None):
        # action: switch, data: prev_video tensor
        # action: wait,   data: None
        # action: fetch,  data: None
        self.action = action
        self.data = data


class VAController:
    def __init__(self, model_runner):
        self.reader = None
        self.recorder = None
        self.rank = 0
        self.world_size = 1
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
        self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
        self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
        self.init_recorder()
        self.init_reader(model_runner)

    def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
        self.audio_path = input_info.audio_path
        self.output_video_path = input_info.save_result_path
        if isinstance(self.output_video_path, dict):
            self.output_video_path = self.output_video_path["data"]
        self.audio_sr = config.get("audio_sr", 16000)
        self.target_fps = config.get("target_fps", 16)
        self.max_num_frames = config.get("target_video_length", 81)
        self.prev_frame_length = config.get("prev_frame_length", 5)
        self.record_fps = config.get("target_fps", 16)
        if "video_frame_interpolation" in config and has_vfi_model:
            self.record_fps = config["video_frame_interpolation"]["target_fps"]
        self.record_fps = config.get("record_fps", self.record_fps)
        self.tgt_h = input_info.target_shape[0]
        self.tgt_w = input_info.target_shape[1]
        self.record_h, self.record_w = self.tgt_h, self.tgt_w
        if "video_super_resolution" in config and has_vsr_model:
            _, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
                self.record_w,
                self.record_h,
                scale=config["video_super_resolution"]["scale"],
                multiple=128,
            )
        # how many frames to publish to the stream as one batch
        self.slice_frame = config.get("slice_frame", 1)
        # estimate the max inference seconds, used for immediate switch with local omni
        slice_interval = self.slice_frame / self.record_fps
        est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
        self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
        self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1

    def init_recorder(self):
        if not self.output_video_path or self.rank != self.target_recorder_rank:
            return
        logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
        whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
        if whip_shared_path and self.output_video_path.startswith("http"):
            self.recorder = X264VARecorder(
                whip_shared_path=whip_shared_path,
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )
        else:
            self.recorder = VARecorder(
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )

    def init_reader(self, model_runner=None):
        if not isinstance(self.audio_path, dict):
            return
        assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
        segment_duration = self.max_num_frames / self.target_fps
        prev_duration = self.prev_frame_length / self.target_fps
        omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
        if omni_work_dir:
            self.reader = OmniVAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
                model_runner=model_runner,
                huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
            )
        else:
            self.reader = VAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
            )

    def start(self):
        self.reader.start()
        if self.rank == self.target_recorder_rank:
            assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
            self.recorder.start(self.record_w, self.record_h)
        if self.world_size > 1:
            dist.barrier()

    def next_control(self):
        if isinstance(self.reader, OmniVAReader):
            return self.omni_reader_next_control()
        return NextControl(action="fetch")

    def before_control(self):
        if isinstance(self.reader, OmniVAReader):
            self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)

    def omni_reader_next_control(self):
        immediate_switch = self.reader.get_immediate_switch()
        if immediate_switch == 1:
            # truncate the stream buffer to keep at most the estimated max inference length
            # and broadcast the prev video tensor to all ranks
            if self.rank == self.target_recorder_rank:
                logger.warning(f"runner recv immediate switch, truncate stream buffer")
                video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
                if video_tensor is not None:
                    self.flag_tensor.fill_(1)
                    self.prev_tensor.copy_(video_tensor)
                else:
                    self.flag_tensor.fill_(0)
            dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
            if self.flag_tensor.item() == 1:
                dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
                return NextControl(action="switch", data=self.prev_tensor)
        else:
            # get the length of the stream buffer and broadcast it to all ranks
            if self.rank == self.target_recorder_rank:
                stream_buffer_length = self.recorder.get_buffer_stream_size()
                self.len_tensor.copy_(stream_buffer_length)
            dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
            buffer_length = self.len_tensor.item()
            # the stream buffer is long enough, skip inference for now
            if buffer_length >= self.min_stay_queue_num:
                return NextControl(action="wait")
        return NextControl(action="fetch")

    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
        if self.recorder.realtime:
            self.recorder.buffer_stream(images, audios, gen_video)
        else:
            self.recorder.pub_livestream(images, audios)

    def clear(self):
        self.len_tensor = None
        self.flag_tensor = None
        self.prev_tensor = None
        if self.reader is not None:
            self.reader.stop()
            self.reader = None
        if self.recorder is not None:
            self.recorder.stop()
            self.recorder = None

    def __del__(self):
        self.clear()
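For orientation, here is a minimal sketch of how a runner might drive this controller. The model_runner object, the run_one_segment callback, and their attributes are illustrative assumptions; only the VAController / NextControl API shown above is taken from the commit, and the action names follow the NextControl comments.

# Hypothetical driver loop for VAController (not part of this commit).
# READER_RANK / RECORDER_RANK decide which rank owns the reader / recorder.
import time

def stream_loop(model_runner, run_one_segment):
    controller = VAController(model_runner)
    if controller.reader is None:
        return  # fixed (non-stream) audio input, nothing to drive here
    controller.start()
    controller.before_control()
    try:
        while True:
            control = controller.next_control()
            if control.action == "switch":
                # resume generation from the frames kept after buffer truncation
                model_runner.prev_video = control.data
            elif control.action == "wait":
                # recorder buffer is long enough, skip inference for a bit
                time.sleep(0.01)
                continue
            audio = controller.reader.get_audio_segment()
            if audio is None:
                continue
            images, audios, gen_video = run_one_segment(audio)  # user-supplied callback
            if controller.recorder is not None:
                controller.pub_livestream(images, audios, gen_video)
    finally:
        controller.clear()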
@@ -24,6 +24,8 @@ class VARecorder:
         livestream_url: str,
         fps: float = 16.0,
         sample_rate: int = 16000,
+        slice_frame: int = 1,
+        prev_frame: int = 1,
     ):
         self.livestream_url = livestream_url
         self.fps = fps
@@ -36,7 +38,9 @@ class VARecorder:
         self.width = None
         self.height = None
         self.stoppable_t = None
-        self.realtime = True
+        self.realtime = False
+        if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
+            self.realtime = True
         # ffmpeg process for mix video and audio data and push to livestream
         self.ffmpeg_process = None
@@ -53,6 +57,16 @@ class VARecorder:
         self.audio_queue = queue.Queue()
         self.video_queue = queue.Queue()
+        # buffer for stream data
+        self.audio_samples_per_frame = round(self.sample_rate / self.fps)
+        self.stream_buffer = []
+        self.stream_buffer_lock = threading.Lock()
+        self.stop_schedule = False
+        self.schedule_thread = None
+        self.slice_frame = slice_frame
+        self.prev_frame = prev_frame
+        assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"

     def init_sockets(self):
         # TCP socket for send and recv video and audio data
         self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -128,7 +142,7 @@ class VARecorder:
                except (BrokenPipeError, OSError, ConnectionResetError) as e:
                    logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
                    return
-                if self.realtime:
+                if self.realtime and i < data.shape[0] - 1:
                    time.sleep(max(0, packet_secs - (time.time() - t0)))
                fail_time = 0
@@ -337,7 +351,9 @@ class VARecorder:
     def start(self, width: int, height: int):
         self.set_video_size(width, height)
         duration = 1.0
-        self.pub_livestream(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16), torch.zeros(int(self.sample_rate * duration), dtype=torch.float16))
+        frames = int(self.fps * duration)
+        samples = int(self.sample_rate * (frames / self.fps))
+        self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16))
         time.sleep(duration)

     def set_video_size(self, width: int, height: int):
@@ -353,11 +369,13 @@ class VARecorder:
             self.start_ffmpeg_process_whip()
         else:
             self.start_ffmpeg_process_local()
-        self.realtime = False
         self.audio_thread = threading.Thread(target=self.audio_worker)
         self.video_thread = threading.Thread(target=self.video_worker)
         self.audio_thread.start()
         self.video_thread.start()
+        if self.realtime:
+            self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
+            self.schedule_thread.start()

     # Publish ComfyUI Image tensor and audio tensor to livestream
     def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
@@ -377,6 +395,75 @@ class VARecorder:
         self.stoppable_t = time.time() + M / self.sample_rate + 3

+    def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
+        N, height, width, C = images.shape
+        M = audios.reshape(-1).shape[0]
+        assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
+        assert C == 3, "Input must be [N, H, W, C] with C=3"
+        audio_frames = round(M * self.fps / self.sample_rate)
+        if audio_frames != N:
+            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
+        self.set_video_size(width, height)
+        # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
+        rets = []
+        for i in range(0, N, self.slice_frame):
+            end_frame = i + self.slice_frame
+            img = images[i:end_frame]
+            aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
+            gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
+            rets.append((img, aud, gen))
+        with self.stream_buffer_lock:
+            origin_size = len(self.stream_buffer)
+            self.stream_buffer.extend(rets)
+            logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
+
+    def get_buffer_stream_size(self):
+        return len(self.stream_buffer)
+
+    def truncate_stream_buffer(self, size: int):
+        with self.stream_buffer_lock:
+            self.stream_buffer = self.stream_buffer[:size]
+            logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
+            if len(self.stream_buffer) > 0:
+                return self.stream_buffer[-1][2]  # return the last video tensor
+            else:
+                return None
+
+    def schedule_stream_buffer(self):
+        schedule_interval = self.slice_frame / self.fps
+        logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
+        t = None
+        while True:
+            try:
+                if self.stop_schedule:
+                    break
+                img, aud, gen = None, None, None
+                with self.stream_buffer_lock:
+                    if len(self.stream_buffer) > 0:
+                        img, aud, gen = self.stream_buffer.pop(0)
+                if t is not None:
+                    wait_secs = schedule_interval - (time.time() - t)
+                    if wait_secs > 0:
+                        time.sleep(wait_secs)
+                t = time.time()
+                if img is not None and aud is not None:
+                    self.audio_queue.put(aud)
+                    self.video_queue.put(img)
+                    # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
+                    del gen
+                    self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
+                else:
+                    logger.warning(f"No stream buffer to schedule")
+            except Exception:
+                logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
+                break
+        logger.info("Schedule stream buffer thread stopped")

     def stop(self, wait=True):
         if wait and self.stoppable_t:
             t = self.stoppable_t - time.time()
@@ -385,6 +472,12 @@ class VARecorder:
                time.sleep(t)
            self.stoppable_t = None
+        if self.schedule_thread:
+            self.stop_schedule = True
+            self.schedule_thread.join(timeout=5)
+            if self.schedule_thread and self.schedule_thread.is_alive():
+                logger.error(f"Schedule thread did not stop after 5s")
         # Send stop signals to queues
         if self.audio_queue:
             self.audio_queue.put(None)
...
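To make the index math in buffer_stream concrete, here is a small standalone example; fps, sample_rate, slice_frame, prev_frame, and N are assumed example values, not project defaults.

# Illustration of the buffer_stream slicing above (values are examples only).
fps, sample_rate = 16, 16000
slice_frame, prev_frame = 4, 2
audio_samples_per_frame = round(sample_rate / fps)  # 1000 audio samples per video frame

N = 16  # frames passed in one call; must be divisible by slice_frame
for i in range(0, N, slice_frame):
    end = i + slice_frame
    print((i, end),                                                      # images[i:end]
          (i * audio_samples_per_frame, end * audio_samples_per_frame),  # audio slice
          (end - prev_frame, end))                                       # gen_video tail kept for switching
# (0, 4) (0, 4000) (2, 4)
# (4, 8) (4000, 8000) (6, 8)
# (8, 12) (8000, 12000) (10, 12)
# (12, 16) (12000, 16000) (14, 16)

The schedule thread then pops one buffered segment every slice_frame / fps seconds (0.25 s in this example) into the audio and video queues, which is what paces the realtime output.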
@@ -18,6 +18,8 @@ class X264VARecorder:
         livestream_url: str,
         fps: float = 16.0,
         sample_rate: int = 16000,
+        slice_frame: int = 1,
+        prev_frame: int = 1,
     ):
         assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
         self.livestream_url = livestream_url
@@ -33,16 +35,29 @@ class X264VARecorder:
         self.whip_shared_lib = None
         self.whip_shared_handle = None
+        assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
+        self.realtime = True

         # queue for send data to whip shared api
         self.queue = queue.Queue()
         self.worker_thread = None
+        # buffer for stream data
+        self.target_sample_rate = 48000
+        self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
+        self.target_chunks_per_frame = self.target_samples_per_frame * 2
+        self.stream_buffer = []
+        self.stream_buffer_lock = threading.Lock()
+        self.stop_schedule = False
+        self.schedule_thread = None
+        self.slice_frame = slice_frame
+        self.prev_frame = prev_frame
+        assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"

     def worker(self):
         try:
             fail_time, max_fail_time = 0, 10
             packet_secs = 1.0 / self.fps
-            audio_chunk = round(48000 * 2 / self.fps)
-            audio_samples = round(48000 / self.fps)
             while True:
                 try:
                     if self.queue is None:
@@ -55,14 +70,16 @@ class X264VARecorder:
                     for i in range(images.shape[0]):
                         t0 = time.time()
-                        cur_audio = audios[i * audio_chunk : (i + 1) * audio_chunk].flatten()
+                        cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
                         audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
-                        self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples)
+                        self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
                         cur_video = images[i].flatten()
                         video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
                         self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
-                        time.sleep(max(0, packet_secs - (time.time() - t0)))
+                        if self.realtime and i < images.shape[0] - 1:
+                            time.sleep(max(0, packet_secs - (time.time() - t0)))
                     fail_time = 0
                 except:  # noqa
@@ -115,24 +132,78 @@ class X264VARecorder:
         self.start_libx264_whip_shared_api(width, height)
         self.worker_thread = threading.Thread(target=self.worker)
         self.worker_thread.start()
+        if self.realtime:
+            self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
+            self.schedule_thread.start()

-    # Publish ComfyUI Image tensor and audio tensor to livestream
-    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
+    def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
         N, height, width, C = images.shape
         M = audios.reshape(-1).shape[0]
+        assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
         assert C == 3, "Input must be [N, H, W, C] with C=3"
-        logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
         audio_frames = round(M * self.fps / self.sample_rate)
         if audio_frames != N:
             logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
         self.set_video_size(width, height)
         audio_datas, image_datas = self.convert_data(audios, images)
-        self.queue.put((audio_datas, image_datas))
-        logger.info(f"Published {N} frames and {M} audio samples")
-        self.stoppable_t = time.time() + M / self.sample_rate + 3
+        # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
+        rets = []
+        for i in range(0, N, self.slice_frame):
+            end_frame = i + self.slice_frame
+            img = image_datas[i:end_frame]
+            aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
+            gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
+            rets.append((img, aud, gen))
+        with self.stream_buffer_lock:
+            origin_size = len(self.stream_buffer)
+            self.stream_buffer.extend(rets)
+            logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
+
+    def get_buffer_stream_size(self):
+        return len(self.stream_buffer)
+
+    def truncate_stream_buffer(self, size: int):
+        with self.stream_buffer_lock:
+            self.stream_buffer = self.stream_buffer[:size]
+            logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
+            if len(self.stream_buffer) > 0:
+                return self.stream_buffer[-1][2]  # return the last video tensor
+            else:
+                return None
+
+    def schedule_stream_buffer(self):
+        schedule_interval = self.slice_frame / self.fps
+        logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
+        t = None
+        while True:
+            try:
+                if self.stop_schedule:
+                    break
+                img, aud, gen = None, None, None
+                with self.stream_buffer_lock:
+                    if len(self.stream_buffer) > 0:
+                        img, aud, gen = self.stream_buffer.pop(0)
+                if t is not None:
+                    wait_secs = schedule_interval - (time.time() - t)
+                    if wait_secs > 0:
+                        time.sleep(wait_secs)
+                t = time.time()
+                if img is not None and aud is not None:
+                    self.queue.put((aud, img))
+                    # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
+                    del gen
+                    self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
+                else:
+                    logger.warning(f"No stream buffer to schedule")
+            except Exception:
+                logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
+                break
+        logger.info("Schedule stream buffer thread stopped")

     def stop(self, wait=True):
         if wait and self.stoppable_t:
@@ -142,6 +213,12 @@ class X264VARecorder:
                time.sleep(t)
            self.stoppable_t = None
+        if self.schedule_thread:
+            self.stop_schedule = True
+            self.schedule_thread.join(timeout=5)
+            if self.schedule_thread and self.schedule_thread.is_alive():
+                logger.error(f"Schedule thread did not stop after 5s")
         # Send stop signals to queues
         if self.queue:
             self.queue.put(None)
@@ -219,7 +296,7 @@ if __name__ == "__main__":
         cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
         num_frames = int(interval * fps)
         images = create_simple_video(num_frames, height, width)
-        recorder.pub_livestream(images, torch.tensor(cur_audio_array, dtype=torch.float32))
+        recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
         i += interval
         time.sleep(interval - (time.time() - t0))
@@ -237,7 +314,7 @@ if __name__ == "__main__":
         if started:
             logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
             started = False
-        recorder.pub_livestream(images, cur_audio_array)
+        recorder.buffer_stream(images, cur_audio_array, images)
         i += interval
         time.sleep(interval - (time.time() - t0))
...
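The WHIP path always hands 48 kHz audio to the shared library. Assuming the factor of 2 in target_chunks_per_frame is the stereo channel count (two int16 values per 48 kHz sample), the per-frame sizes work out as below; the fps and frame index are illustrative values.

# Per-video-frame audio sizes for the 48 kHz WHIP path (illustration only;
# the stereo interpretation of the "* 2" factor is an assumption).
fps = 16.0
target_sample_rate = 48000
target_samples_per_frame = round(target_sample_rate / fps)  # 3000 samples per channel per frame
target_chunks_per_frame = target_samples_per_frame * 2      # 6000 int16 values per frame

i = 2  # third video frame of a batch
start, stop = i * target_chunks_per_frame, (i + 1) * target_chunks_per_frame
print(target_samples_per_frame, (start, stop))  # 3000 (12000, 18000)

pushWhipRawAudioFrame still receives the same per-frame sample count as the old audio_samples variable; the change only moves the constants into __init__ and reuses them in buffer_stream.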
 import argparse
 import asyncio
 import base64
+import copy
 import json
 import mimetypes
 import os
@@ -1030,13 +1031,17 @@ async def api_v1_share_get(share_id: str):
 @app.get("/api/v1/voices/list")
-async def api_v1_voices_list():
+async def api_v1_voices_list(request: Request):
     try:
+        version = request.query_params.get("version", "all")
         if volcengine_tts_client is None:
             return error_response("Volcengine TTS client not loaded", 500)
         voices = volcengine_tts_client.get_voice_list()
         if voices is None:
             return error_response("No voice list found", 404)
+        if version != "all":
+            voices = copy.deepcopy(voices)
+            voices["voices"] = [v for v in voices["voices"] if v["version"] == version]
         return voices
     except Exception as e:
         traceback.print_exc()
...
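A minimal client-side sketch of the new version filter on /api/v1/voices/list; the host, port, and the version value "v2" are placeholder assumptions, and only the route and the "version" / "voices" fields come from the diff above.

# Query the voice list and let the server filter by version (sketch only).
import json
import urllib.request

url = "http://localhost:8000/api/v1/voices/list?version=v2"  # host/port/version are placeholders
with urllib.request.urlopen(url) as resp:
    voices = json.loads(resp.read())
print(len(voices["voices"]), {v["version"] for v in voices["voices"]})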
@@ -34,7 +34,7 @@ HEADERS = {"Authorization": f"Bearer {WORKER_SECRET_KEY}", "Content-Type": "appl
 STOPPED = False
 WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
 RANK = int(os.environ.get("RANK", 0))
-TARGET_RANK = WORLD_SIZE - 1
+TARGET_RANK = int(os.getenv("WORKER_RANK", "0")) % WORLD_SIZE

 async def ping_life(server_url, worker_identity, keys):
@@ -251,14 +251,17 @@ async def main(args):
         logger.warning("Main loop cancelled, do not shut down")
     finally:
+        try:
+            if ping_task:
+                ping_task.cancel()
+            await sync_subtask()
+        except Exception:
+            logger.warning(f"Sync subtask failed: {traceback.format_exc()}")
         if RANK == TARGET_RANK and sub["task_id"] in RUNNING_SUBTASKS:
             try:
                 await report_task(status=status, **sub)
-            except:  # noqa
+            except Exception:
                 logger.warning(f"Report failed: {traceback.format_exc()}")
-        if ping_task:
-            ping_task.cancel()
-        await sync_subtask()

 async def shutdown(loop):
...
@@ -40,8 +40,8 @@ class BaseWorker:
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
         set_parallel_config(config)
-        # same as va_recorder rank and worker main ping rank
-        self.out_video_rank = self.world_size - 1
+        # same as va_recorder rank
+        self.out_video_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
         torch.set_grad_enabled(False)
         self.runner = RUNNER_REGISTER[config["model_cls"]](config)
         self.input_info = set_input_info(args)
...
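All three special ranks are now derived the same way from environment variables. A short sketch with assumed values (WORLD_SIZE=4, WORKER_RANK=3, READER_RANK=1, RECORDER_RANK=2) of how they map onto actual ranks; the role comments summarize the worker hub, base runner, and base worker changes in this commit.

# Illustration only: how the target ranks are computed from env vars.
import os

os.environ.update({"WORKER_RANK": "3", "READER_RANK": "1", "RECORDER_RANK": "2"})  # example values
world_size = 4

worker_rank = int(os.getenv("WORKER_RANK", "0")) % world_size      # task reporting and stop signal
reader_rank = int(os.getenv("READER_RANK", "0")) % world_size      # owns the VAReader, raises the pause signal
recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % world_size  # owns the VARecorder / output video
print(worker_rank, reader_rank, recorder_rank)  # 3 1 2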
+import os
 from abc import ABC

 import torch
@@ -139,19 +140,30 @@ class BaseRunner(ABC):
         if dist.is_initialized():
             rank = dist.get_rank()
             world_size = dist.get_world_size()
-            signal_rank = world_size - 1
-            stopped = 0
-            if rank == signal_rank and hasattr(self, "stop_signal") and self.stop_signal:
+            stop_rank = int(os.getenv("WORKER_RANK", "0")) % world_size  # same as worker hub target_rank
+            pause_rank = int(os.getenv("READER_RANK", "0")) % world_size  # same as va_reader target_rank
+            stopped, paused = 0, 0
+            if rank == stop_rank and hasattr(self, "stop_signal") and self.stop_signal:
                 stopped = 1
+            if rank == pause_rank and hasattr(self, "pause_signal") and self.pause_signal:
+                paused = 1
             if world_size > 1:
-                if rank == signal_rank:
-                    t = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
+                if rank == stop_rank:
+                    t1 = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
                 else:
-                    t = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
-                dist.broadcast(t, src=signal_rank)
-                stopped = t.item()
+                    t1 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
+                if rank == pause_rank:
+                    t2 = torch.tensor([paused], dtype=torch.int32).to(device=AI_DEVICE)
+                else:
+                    t2 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
+                dist.broadcast(t1, src=stop_rank)
+                dist.broadcast(t2, src=pause_rank)
+                stopped = t1.item()
+                paused = t2.item()
             if stopped == 1:
                 raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
+            if paused == 1:
+                raise Exception(f"find rank: {rank} pause_signal, pause running, it's an expected behavior")
@@ -18,9 +18,7 @@ from loguru import logger
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize

-from lightx2v.deploy.common.va_reader import VAReader
-from lightx2v.deploy.common.va_recorder import VARecorder
-from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
+from lightx2v.deploy.common.va_controller import VAController
 from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
 from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
 from lightx2v.models.networks.wan.audio_model import WanAudioModel
@@ -694,15 +692,15 @@ class WanAudioRunner(WanRunner):  # type:ignore
            )

        if "video_super_resolution" in self.config and self.vsr_model is not None:
-            logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
+            # logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
            video_seg = self.vsr_model.super_resolve_frames(
                video_seg,
                seed=self.config["video_super_resolution"]["seed"],
                scale=self.config["video_super_resolution"]["scale"],
            )

-        if self.va_recorder:
-            self.va_recorder.pub_livestream(video_seg, audio_seg)
+        if self.va_controller.recorder is not None:
+            self.va_controller.pub_livestream(video_seg, audio_seg, self.gen_video[:, :, :useful_length])
        elif self.input_info.return_result_tensor:
            self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
            self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg)
@@ -721,85 +719,35 @@ class WanAudioRunner(WanRunner):  # type:ignore
            world_size = dist.get_world_size()
        return rank, world_size

-    def init_va_recorder(self):
-        output_video_path = self.input_info.save_result_path
-        self.va_recorder = None
-        if isinstance(output_video_path, dict):
-            output_video_path = output_video_path["data"]
-        logger.info(f"init va_recorder with output_video_path: {output_video_path}")
-        rank, world_size = self.get_rank_and_world_size()
-        if output_video_path and rank == world_size - 1:
-            record_fps = self.config.get("target_fps", 16)
-            audio_sr = self.config.get("audio_sr", 16000)
-            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
-                record_fps = self.config["video_frame_interpolation"]["target_fps"]
-            whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
-            if whip_shared_path and output_video_path.startswith("http"):
-                self.va_recorder = X264VARecorder(
-                    whip_shared_path=whip_shared_path,
-                    livestream_url=output_video_path,
-                    fps=record_fps,
-                    sample_rate=audio_sr,
-                )
-            else:
-                self.va_recorder = VARecorder(
-                    livestream_url=output_video_path,
-                    fps=record_fps,
-                    sample_rate=audio_sr,
-                )
-
-    def init_va_reader(self):
-        audio_path = self.input_info.audio_path
-        self.va_reader = None
-        if isinstance(audio_path, dict):
-            assert audio_path["type"] == "stream", f"unexcept audio_path: {audio_path}"
-            rank, world_size = self.get_rank_and_world_size()
-            target_fps = self.config.get("target_fps", 16)
-            max_num_frames = self.config.get("target_video_length", 81)
-            audio_sr = self.config.get("audio_sr", 16000)
-            prev_frames = self.config.get("prev_frame_length", 5)
-            self.va_reader = VAReader(
-                rank=rank,
-                world_size=world_size,
-                stream_url=audio_path["data"],
-                sample_rate=audio_sr,
-                segment_duration=max_num_frames / target_fps,
-                prev_duration=prev_frames / target_fps,
-                target_rank=1,
-            )
-
     def run_main(self):
        try:
-            self.init_va_recorder()
-            self.init_va_reader()
-            logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
-            if self.va_reader is None:
+            self.va_controller = VAController(self)
+            logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
+            # fixed audio segments inputs
+            if self.va_controller.reader is None:
                return super().run_main()
-            self.va_reader.start()
-            rank, world_size = self.get_rank_and_world_size()
-            if rank == world_size - 1:
-                assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 2"
-                self.va_recorder.start(self.input_info.target_shape[1], self.input_info.target_shape[0])
-            if world_size > 1:
-                dist.barrier()
+            self.va_controller.start()

            self.init_run()
-            if self.config.get("compile", False):
+            if self.config.get("compile", False) and hasattr(self.model, "comple"):
                self.model.select_graph_for_compile(self.input_info)
-            self.video_segment_num = "unlimited"
-            fetch_timeout = self.va_reader.segment_duration + 1
+            # stream audio input, video segment num is unlimited
+            self.video_segment_num = 1000000
            segment_idx = 0
-            fail_count = 0
-            max_fail_count = 10
+            fail_count, max_fail_count = 0, 10
+            self.va_controller.before_control()
            while True:
                with ProfilingContext4DebugL1(f"stream segment get audio segment {segment_idx}"):
-                    self.check_stop()
-                    audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
+                    control = self.va_controller.next_control()
+                    if control.action == "switch":
+                        self.prev_video = control.data
+                    elif control.action == "wait":
+                        time.sleep(0.01)
+                        continue
+                    audio_array = self.va_controller.reader.get_audio_segment()
                    if audio_array is None:
                        fail_count += 1
                        logger.warning(f"Failed to get audio chunk {fail_count} times")
@@ -808,22 +756,27 @@ class WanAudioRunner(WanRunner):  # type:ignore
                        continue

                with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
-                    fail_count = 0
-                    self.init_run_segment(segment_idx, audio_array)
-                    latents = self.run_segment(segment_idx)
-                    self.gen_video = self.run_vae_decoder(latents)
-                    self.end_run_segment(segment_idx)
-                    segment_idx += 1
+                    try:
+                        # reset pause signal
+                        self.pause_signal = False
+                        self.init_run_segment(segment_idx, audio_array)
+                        self.check_stop()
+                        latents = self.run_segment(segment_idx)
+                        self.check_stop()
+                        self.gen_video = self.run_vae_decoder(latents)
+                        self.check_stop()
+                        self.end_run_segment(segment_idx)
+                        segment_idx += 1
+                        fail_count = 0
+                    except Exception as e:
+                        if "pause_signal, pause running" in str(e):
+                            logger.warning(f"model infer audio pause: {e}, should continue")
+                        else:
+                            raise
        finally:
            if hasattr(self.model, "inputs"):
                self.end_run()
-            if self.va_reader:
-                self.va_reader.stop()
-                self.va_reader = None
-            if self.va_recorder:
-                self.va_recorder.stop()
-                self.va_recorder = None
+            self.va_controller.clear()

     @ProfilingContext4DebugL1("Process after vae decoder")
     def process_images_after_vae_decoder(self):
...
-import os
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-os.environ["DTYPE"] = "BF16"
-os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
-os.environ["PROFILING_DEBUG_LEVEL"] = "2"
+# please do not set envs in this file, it will be imported by the __init__.py file
+# os.environ["TOKENIZERS_PARALLELISM"] = "false"
+# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+# os.environ["DTYPE"] = "BF16"
+# os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
+# os.environ["PROFILING_DEBUG_LEVEL"] = "2"

 import json
...