Unverified commit bc2828b0 authored by LiangLiu, committed by GitHub

Dev/omni (#577)

Tidy VAReader & OmniVAReader
Tidy VARecorder & X264VARecorder
VARecorder stream output now goes through a buffered stream
Tidy the WORKER_RANK, READER_RANK and RECORDER_RANK environment variables
Support voice type selection
parent f3b4ba24
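For orientation: the three rank environment variables introduced in this change are all resolved the same way (see VAController.__init__, the worker hub, and BaseWorker below), taking the env value modulo the world size with a default of "0". A minimal sketch of that resolution; the WORLD_SIZE read and the example launch values in the comment are assumptions for illustration, not part of the commit:

import os

# Illustrative launch values (assumed): WORLD_SIZE=4, WORKER_RANK=0, READER_RANK=1, RECORDER_RANK=3
world_size = int(os.environ.get("WORLD_SIZE", 1))
worker_rank = int(os.getenv("WORKER_RANK", "0")) % world_size      # reports subtask status (worker hub)
reader_rank = int(os.getenv("READER_RANK", "0")) % world_size      # owns the VAReader / OmniVAReader
recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % world_size  # owns the VARecorder / X264VARecorder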
import math
import os
from typing import Any
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE
class NextControl:
def __init__(self, action: str, data: Any = None):
# action: switch, data: prev_video tensor
# action: wait, data: None
# action: fetch, data: None
self.action = action
self.data = data
class VAController:
def __init__(self, model_runner):
self.reader = None
self.recorder = None
self.rank = 0
self.world_size = 1
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
self.init_recorder()
self.init_reader(model_runner)
def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
self.audio_path = input_info.audio_path
self.output_video_path = input_info.save_result_path
if isinstance(self.output_video_path, dict):
self.output_video_path = self.output_video_path["data"]
self.audio_sr = config.get("audio_sr", 16000)
self.target_fps = config.get("target_fps", 16)
self.max_num_frames = config.get("target_video_length", 81)
self.prev_frame_length = config.get("prev_frame_length", 5)
self.record_fps = config.get("target_fps", 16)
if "video_frame_interpolation" in config and has_vfi_model:
self.record_fps = config["video_frame_interpolation"]["target_fps"]
self.record_fps = config.get("record_fps", self.record_fps)
self.tgt_h = input_info.target_shape[0]
self.tgt_w = input_info.target_shape[1]
self.record_h, self.record_w = self.tgt_h, self.tgt_w
if "video_super_resolution" in config and has_vsr_model:
_, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
self.record_w,
self.record_h,
scale=config["video_super_resolution"]["scale"],
multiple=128,
)
# how many frames to publish stream as a batch
self.slice_frame = config.get("slice_frame", 1)
# estimate the max infer seconds, for immediate switch with local omni
slice_interval = self.slice_frame / self.record_fps
est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
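        # Worked example with the defaults above (illustrative, not part of the committed file):
        #   slice_frame=1, record_fps=16   -> slice_interval = 1 / 16 = 0.0625 s per published slice
        #   est_max_infer_secs=0.6         -> est_infer_end_idx = ceil(0.6 / 0.0625) = 10 slices
        #   min_stay_queue_num = 10 * 2 + 1 = 21 slices: once the recorder buffers at least this
        #   many slices, next_control() answers "wait" instead of "fetch" (see below).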
def init_recorder(self):
if not self.output_video_path or self.rank != self.target_recorder_rank:
return
logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and self.output_video_path.startswith("http"):
self.recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
else:
self.recorder = VARecorder(
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
def init_reader(self, model_runner=None):
if not isinstance(self.audio_path, dict):
return
assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
segment_duration = self.max_num_frames / self.target_fps
prev_duration = self.prev_frame_length / self.target_fps
omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
if omni_work_dir:
self.reader = OmniVAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
model_runner=model_runner,
huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
)
else:
self.reader = VAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
)
def start(self):
self.reader.start()
if self.rank == self.target_recorder_rank:
assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
self.recorder.start(self.record_w, self.record_h)
if self.world_size > 1:
dist.barrier()
def next_control(self):
if isinstance(self.reader, OmniVAReader):
return self.omni_reader_next_control()
return NextControl(action="fetch")
def before_control(self):
if isinstance(self.reader, OmniVAReader):
self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)
def omni_reader_next_control(self):
immediate_switch = self.reader.get_immediate_switch()
if immediate_switch == 1:
# truncate the stream buffer to keep the max infer time length
# and broadcast the prev video tensor to all ranks
if self.rank == self.target_recorder_rank:
logger.warning(f"runner recv immediate switch, truncate stream buffer")
video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
if video_tensor is not None:
self.flag_tensor.fill_(1)
self.prev_tensor.copy_(video_tensor)
else:
self.flag_tensor.fill_(0)
dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
if self.flag_tensor.item() == 1:
dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
return NextControl(action="switch", data=self.prev_tensor)
else:
# get the length of stream buffer, broadcast to all ranks
if self.rank == self.target_recorder_rank:
stream_buffer_length = self.recorder.get_buffer_stream_size()
self.len_tensor.fill_(stream_buffer_length)  # fill_ accepts the Python int returned by get_buffer_stream_size()
dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
buffer_length = self.len_tensor.item()
# stream buffer is enough, skip infer
if buffer_length >= self.min_stay_queue_num:
return NextControl(action="wait")
return NextControl(action="fetch")
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
if self.recorder.realtime:
self.recorder.buffer_stream(images, audios, gen_video)
else:
self.recorder.pub_livestream(images, audios)
def clear(self):
self.len_tensor = None
self.flag_tensor = None
self.prev_tensor = None
if self.reader is not None:
self.reader.stop()
self.reader = None
if self.recorder is not None:
self.recorder.stop()
self.recorder = None
def __del__(self):
self.clear()
This diff is collapsed.
......@@ -24,6 +24,8 @@ class VARecorder:
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
self.livestream_url = livestream_url
self.fps = fps
......@@ -36,7 +38,9 @@ class VARecorder:
self.width = None
self.height = None
self.stoppable_t = None
self.realtime = True
self.realtime = False
if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
self.realtime = True
# ffmpeg process for mix video and audio data and push to livestream
self.ffmpeg_process = None
......@@ -53,6 +57,16 @@ class VARecorder:
self.audio_queue = queue.Queue()
self.video_queue = queue.Queue()
# buffer for stream data
self.audio_samples_per_frame = round(self.sample_rate / self.fps)
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
assert self.slice_frame >= self.prev_frame, "slice_frame must be greater than or equal to prev_frame"
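        # Illustrative numbers assuming the default fps=16 and sample_rate=16000 (not part of the commit):
        #   audio_samples_per_frame = round(16000 / 16) = 1000, so each buffered slice carries
        #   slice_frame video frames plus slice_frame * 1000 audio samples, and the last prev_frame
        #   frames of the generated video are kept with it so an immediate switch can be conditioned
        #   on already-published frames (see buffer_stream / truncate_stream_buffer below).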
def init_sockets(self):
# TCP socket for send and recv video and audio data
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -128,7 +142,7 @@ class VARecorder:
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
return
if self.realtime:
if self.realtime and i < data.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
......@@ -337,7 +351,9 @@ class VARecorder:
def start(self, width: int, height: int):
self.set_video_size(width, height)
duration = 1.0
self.pub_livestream(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16), torch.zeros(int(self.sample_rate * duration), dtype=torch.float16))
frames = int(self.fps * duration)
samples = int(self.sample_rate * (frames / self.fps))
self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16))
time.sleep(duration)
def set_video_size(self, width: int, height: int):
......@@ -353,11 +369,13 @@ class VARecorder:
self.start_ffmpeg_process_whip()
else:
self.start_ffmpeg_process_local()
self.realtime = False
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
self.video_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
......@@ -377,6 +395,75 @@ class VARecorder:
self.stoppable_t = time.time() + M / self.sample_rate + 3
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = images[i:end_frame]
aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.audio_queue.put(aud)
self.video_queue.put(img)
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
......@@ -385,6 +472,12 @@ class VARecorder:
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.audio_queue:
self.audio_queue.put(None)
......
......@@ -18,6 +18,8 @@ class X264VARecorder:
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.livestream_url = livestream_url
......@@ -33,16 +35,29 @@ class X264VARecorder:
self.whip_shared_lib = None
self.whip_shared_handle = None
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.realtime = True
# queue for send data to whip shared api
self.queue = queue.Queue()
self.worker_thread = None
# buffer for stream data
self.target_sample_rate = 48000
self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
self.target_chunks_per_frame = self.target_samples_per_frame * 2
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
assert self.slice_frame >= self.prev_frame, "slice_frame must be greater than or equal to prev_frame"
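        # Illustrative numbers assuming fps=16 (the 48 kHz output rate is fixed above):
        #   target_samples_per_frame = round(48000 / 16) = 3000 samples per video frame
        #   target_chunks_per_frame  = 3000 * 2 = 6000 int16 values sliced per frame in worker();
        #   the factor of 2 presumably reflects an interleaved two-value-per-sample layout — this
        #   reading is an inference from the old audio_chunk formula, not stated in the commit.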
def worker(self):
try:
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
audio_chunk = round(48000 * 2 / self.fps)
audio_samples = round(48000 / self.fps)
while True:
try:
if self.queue is None:
......@@ -55,14 +70,16 @@ class X264VARecorder:
for i in range(images.shape[0]):
t0 = time.time()
cur_audio = audios[i * audio_chunk : (i + 1) * audio_chunk].flatten()
cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples)
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
cur_video = images[i].flatten()
video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
time.sleep(max(0, packet_secs - (time.time() - t0)))
if self.realtime and i < images.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except: # noqa
......@@ -115,24 +132,78 @@ class X264VARecorder:
self.start_libx264_whip_shared_api(width, height)
self.worker_thread = threading.Thread(target=self.worker)
self.worker_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
audio_datas, image_datas = self.convert_data(audios, images)
self.queue.put((audio_datas, image_datas))
logger.info(f"Published {N} frames and {M} audio samples")
self.stoppable_t = time.time() + M / self.sample_rate + 3
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = image_datas[i:end_frame]
aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.queue.put((aud, img))
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
......@@ -142,6 +213,12 @@ class X264VARecorder:
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.queue:
self.queue.put(None)
......@@ -219,7 +296,7 @@ if __name__ == "__main__":
cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
recorder.pub_livestream(images, torch.tensor(cur_audio_array, dtype=torch.float32))
recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
i += interval
time.sleep(interval - (time.time() - t0))
......@@ -237,7 +314,7 @@ if __name__ == "__main__":
if started:
logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
started = False
recorder.pub_livestream(images, cur_audio_array)
recorder.buffer_stream(images, cur_audio_array, images)
i += interval
time.sleep(interval - (time.time() - t0))
......
import argparse
import asyncio
import base64
import copy
import json
import mimetypes
import os
......@@ -1030,13 +1031,17 @@ async def api_v1_share_get(share_id: str):
@app.get("/api/v1/voices/list")
async def api_v1_voices_list():
async def api_v1_voices_list(request: Request):
try:
version = request.query_params.get("version", "all")
if volcengine_tts_client is None:
return error_response("Volcengine TTS client not loaded", 500)
voices = volcengine_tts_client.get_voice_list()
if voices is None:
return error_response("No voice list found", 404)
if version != "all":
voices = copy.deepcopy(voices)
voices["voices"] = [v for v in voices["voices"] if v["version"] == version]
return voices
except Exception as e:
traceback.print_exc()
......
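The voices endpoint above now accepts an optional version query parameter; any value other than the default "all" filters the returned entries by their "version" field. A minimal client sketch; the base URL and the example version value are assumptions, not part of this change:

import requests

BASE_URL = "http://localhost:8000"  # assumed deployment address

# Default behaviour, unchanged: every voice is returned
all_voices = requests.get(f"{BASE_URL}/api/v1/voices/list").json()

# New: only voices whose "version" field equals the requested value
filtered = requests.get(f"{BASE_URL}/api/v1/voices/list", params={"version": "v2"}).json()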
......@@ -34,7 +34,7 @@ HEADERS = {"Authorization": f"Bearer {WORKER_SECRET_KEY}", "Content-Type": "appl
STOPPED = False
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
TARGET_RANK = WORLD_SIZE - 1
TARGET_RANK = int(os.getenv("WORKER_RANK", "0")) % WORLD_SIZE
async def ping_life(server_url, worker_identity, keys):
......@@ -251,14 +251,17 @@ async def main(args):
logger.warning("Main loop cancelled, do not shut down")
finally:
try:
if ping_task:
ping_task.cancel()
await sync_subtask()
except Exception:
logger.warning(f"Sync subtask failed: {traceback.format_exc()}")
if RANK == TARGET_RANK and sub["task_id"] in RUNNING_SUBTASKS:
try:
await report_task(status=status, **sub)
except: # noqa
except Exception:
logger.warning(f"Report failed: {traceback.format_exc()}")
if ping_task:
ping_task.cancel()
await sync_subtask()
async def shutdown(loop):
......
......@@ -40,8 +40,8 @@ class BaseWorker:
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
set_parallel_config(config)
# same as va_recorder rank and worker main ping rank
self.out_video_rank = self.world_size - 1
# same as va_recorder rank
self.out_video_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
torch.set_grad_enabled(False)
self.runner = RUNNER_REGISTER[config["model_cls"]](config)
self.input_info = set_input_info(args)
......
import os
from abc import ABC
import torch
......@@ -139,19 +140,30 @@ class BaseRunner(ABC):
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
signal_rank = world_size - 1
stop_rank = int(os.getenv("WORKER_RANK", "0")) % world_size # same as worker hub target_rank
pause_rank = int(os.getenv("READER_RANK", "0")) % world_size # same as va_reader target_rank
stopped = 0
if rank == signal_rank and hasattr(self, "stop_signal") and self.stop_signal:
stopped, paused = 0, 0
if rank == stop_rank and hasattr(self, "stop_signal") and self.stop_signal:
stopped = 1
if rank == pause_rank and hasattr(self, "pause_signal") and self.pause_signal:
paused = 1
if world_size > 1:
if rank == signal_rank:
t = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
if rank == stop_rank:
t1 = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
else:
t = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
dist.broadcast(t, src=signal_rank)
stopped = t.item()
t1 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
if rank == pause_rank:
t2 = torch.tensor([paused], dtype=torch.int32).to(device=AI_DEVICE)
else:
t2 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
dist.broadcast(t1, src=stop_rank)
dist.broadcast(t2, src=pause_rank)
stopped = t1.item()
paused = t2.item()
if stopped == 1:
raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
if paused == 1:
raise Exception(f"find rank: {rank} pause_signal, pause running, it's an expected behavior")
......@@ -18,9 +18,7 @@ from loguru import logger
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.deploy.common.va_controller import VAController
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
......@@ -694,15 +692,15 @@ class WanAudioRunner(WanRunner): # type:ignore
)
if "video_super_resolution" in self.config and self.vsr_model is not None:
logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
# logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
video_seg = self.vsr_model.super_resolve_frames(
video_seg,
seed=self.config["video_super_resolution"]["seed"],
scale=self.config["video_super_resolution"]["scale"],
)
if self.va_recorder:
self.va_recorder.pub_livestream(video_seg, audio_seg)
if self.va_controller.recorder is not None:
self.va_controller.pub_livestream(video_seg, audio_seg, self.gen_video[:, :, :useful_length])
elif self.input_info.return_result_tensor:
self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg)
......@@ -721,85 +719,35 @@ class WanAudioRunner(WanRunner): # type:ignore
world_size = dist.get_world_size()
return rank, world_size
def init_va_recorder(self):
output_video_path = self.input_info.save_result_path
self.va_recorder = None
if isinstance(output_video_path, dict):
output_video_path = output_video_path["data"]
logger.info(f"init va_recorder with output_video_path: {output_video_path}")
rank, world_size = self.get_rank_and_world_size()
if output_video_path and rank == world_size - 1:
record_fps = self.config.get("target_fps", 16)
audio_sr = self.config.get("audio_sr", 16000)
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
record_fps = self.config["video_frame_interpolation"]["target_fps"]
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and output_video_path.startswith("http"):
self.va_recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=output_video_path,
fps=record_fps,
sample_rate=audio_sr,
)
else:
self.va_recorder = VARecorder(
livestream_url=output_video_path,
fps=record_fps,
sample_rate=audio_sr,
)
def init_va_reader(self):
audio_path = self.input_info.audio_path
self.va_reader = None
if isinstance(audio_path, dict):
assert audio_path["type"] == "stream", f"unexcept audio_path: {audio_path}"
rank, world_size = self.get_rank_and_world_size()
target_fps = self.config.get("target_fps", 16)
max_num_frames = self.config.get("target_video_length", 81)
audio_sr = self.config.get("audio_sr", 16000)
prev_frames = self.config.get("prev_frame_length", 5)
self.va_reader = VAReader(
rank=rank,
world_size=world_size,
stream_url=audio_path["data"],
sample_rate=audio_sr,
segment_duration=max_num_frames / target_fps,
prev_duration=prev_frames / target_fps,
target_rank=1,
)
def run_main(self):
try:
self.init_va_recorder()
self.init_va_reader()
logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
self.va_controller = VAController(self)
logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
if self.va_reader is None:
# fixed audio segments inputs
if self.va_controller.reader is None:
return super().run_main()
self.va_reader.start()
rank, world_size = self.get_rank_and_world_size()
if rank == world_size - 1:
assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 2"
self.va_recorder.start(self.input_info.target_shape[1], self.input_info.target_shape[0])
if world_size > 1:
dist.barrier()
self.va_controller.start()
self.init_run()
if self.config.get("compile", False):
if self.config.get("compile", False) and hasattr(self.model, "comple"):
self.model.select_graph_for_compile(self.input_info)
self.video_segment_num = "unlimited"
fetch_timeout = self.va_reader.segment_duration + 1
# stream audio input, video segment num is unlimited
self.video_segment_num = 1000000
segment_idx = 0
fail_count = 0
max_fail_count = 10
fail_count, max_fail_count = 0, 10
self.va_controller.before_control()
while True:
with ProfilingContext4DebugL1(f"stream segment get audio segment {segment_idx}"):
self.check_stop()
audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
control = self.va_controller.next_control()
if control.action == "switch":
self.prev_video = control.data
elif control.action == "wait":
time.sleep(0.01)
continue
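                # Summary of the NextControl actions handled here (descriptive comment, not part of the commit):
                #   "switch": control.data is the tail of the already-published video; it becomes
                #             prev_video and execution falls through to fetch fresh audio
                #   "wait":   the recorder buffer already holds >= min_stay_queue_num slices, so
                #             sleep briefly and retry
                #   "fetch":  default path, fall through and pull the next audio segment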
audio_array = self.va_controller.reader.get_audio_segment()
if audio_array is None:
fail_count += 1
logger.warning(f"Failed to get audio chunk {fail_count} times")
......@@ -808,22 +756,27 @@ class WanAudioRunner(WanRunner): # type:ignore
continue
with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
fail_count = 0
self.init_run_segment(segment_idx, audio_array)
latents = self.run_segment(segment_idx)
self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment(segment_idx)
segment_idx += 1
try:
# reset pause signal
self.pause_signal = False
self.init_run_segment(segment_idx, audio_array)
self.check_stop()
latents = self.run_segment(segment_idx)
self.check_stop()
self.gen_video = self.run_vae_decoder(latents)
self.check_stop()
self.end_run_segment(segment_idx)
segment_idx += 1
fail_count = 0
except Exception as e:
if "pause_signal, pause running" in str(e):
logger.warning(f"model infer audio pause: {e}, should continue")
else:
raise
finally:
if hasattr(self.model, "inputs"):
self.end_run()
if self.va_reader:
self.va_reader.stop()
self.va_reader = None
if self.va_recorder:
self.va_recorder.stop()
self.va_recorder = None
self.va_controller.clear()
@ProfilingContext4DebugL1("Process after vae decoder")
def process_images_after_vae_decoder(self):
......
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
os.environ["PROFILING_DEBUG_LEVEL"] = "2"
# please do not set envs in this file, it will be imported by the __init__.py file
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["DTYPE"] = "BF16"
# os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
# os.environ["PROFILING_DEBUG_LEVEL"] = "2"
import json
......