Unverified Commit bc2828b0 authored by LiangLiu, committed by GitHub

Dev/omni (#577)

Tidy VAReader & OmniVAReader
Tidy VARecorder & X264VARecorder
VARecorder stream mode now uses a buffer stream
Tidy env WORKER_RANK, READER_RANK, RECORDER_RANK
Support voice type selection
parent f3b4ba24
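A recurring theme in this change is that reader, recorder, and worker-hub duties are now pinned to ranks chosen by environment variables instead of being hard-coded to the last rank. A minimal sketch of the selection logic, assuming a 4-rank job (the env var names come from this diff; the world size and printout are illustrative):

import os
world_size = 4  # illustrative; the runners take this from torch.distributed
reader_rank = int(os.getenv("READER_RANK", "0")) % world_size  # rank that pulls the audio stream
recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % world_size  # rank that records/publishes video + audio
worker_rank = int(os.getenv("WORKER_RANK", "0")) % world_size  # rank that reports task status to the hub
print(reader_rank, recorder_rank, worker_rank)  # all three default to rank 0 when the vars are unset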
import math
import os
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE
class NextControl:
def __init__(self, action: str, data: any = None):
# action: switch, data: prev_video tensor
# action: wait, data: None
# action: fetch, data: None
self.action = action
self.data = data
class VAController:
def __init__(self, model_runner):
self.reader = None
self.recorder = None
self.rank = 0
self.world_size = 1
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
self.init_recorder()
self.init_reader(model_runner)
def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
self.audio_path = input_info.audio_path
self.output_video_path = input_info.save_result_path
if isinstance(self.output_video_path, dict):
self.output_video_path = self.output_video_path["data"]
self.audio_sr = config.get("audio_sr", 16000)
self.target_fps = config.get("target_fps", 16)
self.max_num_frames = config.get("target_video_length", 81)
self.prev_frame_length = config.get("prev_frame_length", 5)
self.record_fps = config.get("target_fps", 16)
if "video_frame_interpolation" in config and has_vfi_model:
self.record_fps = config["video_frame_interpolation"]["target_fps"]
self.record_fps = config.get("record_fps", self.record_fps)
self.tgt_h = input_info.target_shape[0]
self.tgt_w = input_info.target_shape[1]
self.record_h, self.record_w = self.tgt_h, self.tgt_w
if "video_super_resolution" in config and has_vsr_model:
_, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
self.record_w,
self.record_h,
scale=config["video_super_resolution"]["scale"],
multiple=128,
)
# how many frames to publish to the stream as one batch
self.slice_frame = config.get("slice_frame", 1)
# estimated max inference time in seconds, used for the immediate switch with local omni
slice_interval = self.slice_frame / self.record_fps
est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
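# Illustrative numbers with the defaults above (slice_frame=1, record_fps=16, est_max_infer_secs=0.6):
# slice_interval = 1/16 = 0.0625 s, so est_infer_end_idx = ceil(0.6 / 0.0625) = 10 buffered segments
# cover one worst-case inference, and min_stay_queue_num = 2 * 10 + 1 = 21 segments must stay queued
# before the controller tells the runner to wait instead of fetch.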
def init_recorder(self):
if not self.output_video_path or self.rank != self.target_recorder_rank:
return
logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and self.output_video_path.startswith("http"):
self.recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
else:
self.recorder = VARecorder(
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
def init_reader(self, model_runner=None):
if not isinstance(self.audio_path, dict):
return
assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
segment_duration = self.max_num_frames / self.target_fps
prev_duration = self.prev_frame_length / self.target_fps
omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
if omni_work_dir:
self.reader = OmniVAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
model_runner=model_runner,
huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
)
else:
self.reader = VAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
)
def start(self):
self.reader.start()
if self.rank == self.target_recorder_rank:
assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
self.recorder.start(self.record_w, self.record_h)
if self.world_size > 1:
dist.barrier()
def next_control(self):
if isinstance(self.reader, OmniVAReader):
return self.omni_reader_next_control()
return NextControl(action="fetch")
def before_control(self):
if isinstance(self.reader, OmniVAReader):
self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)
def omni_reader_next_control(self):
immediate_switch = self.reader.get_immediate_switch()
if immediate_switch == 1:
# truncate the stream buffer down to the estimated max infer length
# and broadcast the prev video tensor to all ranks
if self.rank == self.target_recorder_rank:
logger.warning(f"runner recv immediate switch, truncate stream buffer")
video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
if video_tensor is not None:
self.flag_tensor.fill_(1)
self.prev_tensor.copy_(video_tensor)
else:
self.flag_tensor.fill_(0)
dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
if self.flag_tensor.item() == 1:
dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
return NextControl(action="switch", data=self.prev_tensor)
else:
# get the length of stream buffer, broadcast to all ranks
if self.rank == self.target_recorder_rank:
stream_buffer_length = self.recorder.get_buffer_stream_size()
self.len_tensor.copy_(stream_buffer_length)
dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
buffer_length = self.len_tensor.item()
# stream buffer has enough segments queued, skip inference this round
if buffer_length >= self.min_stay_queue_num:
return NextControl(action="wait")
return NextControl(action="fetch")
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
if self.recorder.realtime:
self.recorder.buffer_stream(images, audios, gen_video)
else:
self.recorder.pub_livestream(images, audios)
def clear(self):
self.len_tensor = None
self.flag_tensor = None
self.prev_tensor = None
if self.reader is not None:
self.reader.stop()
self.reader = None
if self.recorder is not None:
self.recorder.stop()
self.recorder = None
def __del__(self):
self.clear()
import datetime
import json
import os
import random
import subprocess
import threading
import time
import traceback
from collections import deque
from copy import deepcopy
import jsonschema
import numpy as np
import torch
import torch.distributed as dist
import zmq
from bson import BSON
from loguru import logger
from scipy.signal import resample
class AudioInfo:
def __init__(self, info: dict):
self.sample_count = info["sample_count"]
self.sample_rate = info["sample_rate"]
self.channel_count = info["channel_count"]
self.sample_fmt = info["sample_fmt"]
self.pts = info["pts"]
def is_spec_equal(self, other: "AudioInfo") -> bool:
return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count
def duration(self) -> datetime.timedelta:
return datetime.timedelta(seconds=self.sample_count / self.sample_rate)
def __str__(self):
return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts)
class ByteBuffer:
def __init__(self):
self.buffer = deque()
self.current_size = 0
# whether the audio belonging to the current turn has finished
self.audio_finished = False
def add(self, byte_data: bytes):
self.buffer.append(byte_data)
self.current_size += len(byte_data)
def get(self, size=1024):
data = bytearray()
while size > 0 and len(self.buffer) > 0:
chunk = self.buffer.popleft()
if len(chunk) <= size:
# if the chunk fits within size, append the whole chunk to data
data.extend(chunk)
self.current_size -= len(chunk)
size -= len(chunk)
else:
# if the chunk is larger than size, append only part of it to data and keep the rest in the buffer
data.extend(chunk[:size])
self.buffer.appendleft(chunk[size:])  # keep the remainder in the buffer
self.current_size -= size
size = 0
return bytes(data)
def mark_finished(self):
self.audio_finished = True
def has_more_voice(self):
return not self.audio_finished
def __len__(self):
return self.current_size
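# Illustrative usage: reads can span chunk boundaries.
#   buf = ByteBuffer(); buf.add(b"abc"); buf.add(b"defgh")
#   buf.get(5)  -> b"abcde"   (len(buf) is now 3)
#   buf.get(10) -> b"fgh"     (returns whatever is left when the buffer runs short)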
class ChatAdapter:
def __init__(
self,
omni_work_dir: str,
whep_url: str,
session_id: str,
account: str,
config_files: list[str],
config_schema_path: str,
seg_duration: float,
model_runner,
huoshan_tts_voice_type,
):
assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist"
self.omni_work_dir = omni_work_dir
self.context = zmq.Context()
self.w2f_socket = self.context.socket(zmq.PULL)
self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket)
self.f2w_socket = self.context.socket(zmq.PUSH)
self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket)
self.recv_thread = None
self.audio_buffer = ByteBuffer()
self.audio_info = None
self.chat_server_cmd = [
os.path.join(self.omni_work_dir, "bin", "seko-chatter"),
"--session-id",
session_id,
"--account",
account,
"--whep-server-url",
whep_url,
"--w2f-endpoint",
self.w2f_url,
"--f2w-endpoint",
self.f2w_url,
"--config-files",
*config_files,
]
override_config = {}
if huoshan_tts_voice_type is not None:
logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}")
override_config["TTS"] = {
"default_voice_info": {
"voice_type": huoshan_tts_voice_type,
"provider": "huoshan_stream_tts",
}
}
with open(config_schema_path, "r") as f:
schema = json.load(f)
jsonschema.validate(instance=override_config, schema=schema)
if override_config is not None:
self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)])
self.chatter_proc = None
self.seg_duration = seg_duration
self.reset_prev = False
self.status = "blank"
self.immediate_switch = 0
self.model_runner = model_runner
def launch_chat_server(self):
env = {
"RUST_LOG": "info,duplex_server=debug,backend_5o=debug",
"LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"),
"PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"),
}
self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir)
@staticmethod
def select_and_bind(socket: zmq.Socket) -> str:
# randomly select a port between 1024 and 65535
retry_count = 20
err = None
while retry_count > 0:
try:
port = random.randint(1024, 65535)
# port = 5555
url = f"tcp://localhost:{port}"
socket.bind(url)
return url
except zmq.error.ZMQError as e:
retry_count -= 1
err = e
raise err
# immediately switch to the given status: mark the previous chunk for reset and set immediate_switch to 1
def immediate_switch_to(self, status):
logger.warning(f"VA reader immediate switch to {status}")
self.reset_prev = True
self.status = status
self.immediate_switch = 1
if self.model_runner is not None:
self.model_runner.pause_signal = True
logger.warning(f"Model runner pause signal set to True")
def recv_loop(self):
while True:
try:
message = self.w2f_socket.recv()
except Exception:
logger.error(f"Error receiving message: {traceback.format_exc()}")
break
try:
message = BSON.decode(message)
msg_type = message["type"]
logger.debug("Received message type: {}".format(msg_type))
if msg_type == "AgentAudio":
audio = message["audio"]
if audio["type"] != "Pcm":
logger.error("Unsupported audio type: {}".format(audio["type"]))
continue
pcm_data = audio["data"]
audio_info = AudioInfo(audio["info"])
logger.debug("Received audio with duration: {}".format(audio_info.duration()))
if self.audio_info is None:
self.audio_info = audio_info
else:
# check if the audio info is the same
if not self.audio_info.is_spec_equal(audio_info):
raise ValueError("Audio info mismatch")
self.audio_buffer.add(pcm_data)
# if status is blank and has voice, set immediate switch to 1
if self.status == "blank" and self.has_voice(self.seg_duration):
self.immediate_switch_to("voice")
elif msg_type == "AgentStartPlay":
logger.debug("Received AgentStartPlay, create new audio buffer")
self.audio_buffer = ByteBuffer()
elif msg_type == "AgentEndPlay":
logger.debug("Received AgentEndPlay, mark audio finished")
self.audio_buffer.mark_finished()
elif msg_type == "ClearAgentAudio":
logger.warning("Received ClearAgentAudio, clear audio buffer")
self.audio_buffer = None
self.audio_info = None
if self.status == "voice":
self.status = "blank"
# self.immediate_switch_to("blank")
except Exception as e:
logger.error("Error decoding message: {}, continue".format(e))
continue
logger.warning("recv loop interrupted")
def start(self):
self.launch_chat_server()
self.recv_thread = threading.Thread(target=self.recv_loop)
self.recv_thread.start()
def has_voice(self, duration) -> bool:
if self.audio_info is None or self.audio_buffer.current_size == 0:
return False
bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2 # S16LE assumed
# if there are not enough bytes and more voice may still arrive, return False
if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice():
logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}")
return False
return bytes_count  # enough voice is buffered; callers treat the non-False byte count as truthy
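# Worked example (illustrative): duration=1.0625 s, sample_rate=16000 Hz, mono S16LE
# => bytes_count = round(1.0625 * 16000) * 1 * 2 = 17000 * 2 = 34000 bytes.
# With fewer buffered bytes and has_more_voice() still True the method returns False;
# once the stream is marked finished it returns 34000 even if the buffer is shorter
# (get() then simply yields the remaining bytes).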
def get_audio(self, fetch_duration) -> (bytes, AudioInfo):
bytes_count = self.has_voice(fetch_duration)
if bytes_count is False:
return None
pcm_data = self.audio_buffer.get(bytes_count)
# the actual sample count fetched
sample_count = len(pcm_data) // (self.audio_info.channel_count * 2)
logger.debug("Fetched {} bytes audio".format(sample_count))
logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size))
audio_info = deepcopy(self.audio_info)
audio_info.sample_count = sample_count
return (pcm_data, audio_info)
def stop(self):
self.model_runner = None
if self.chatter_proc is not None:
self.chatter_proc.terminate()
self.chatter_proc.wait()
self.chatter_proc = None
self.w2f_socket.close()
self.f2w_socket.close()
def __del__(self):
self.stop()
class OmniVAReader:
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0625,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
model_runner=None,
huoshan_tts_voice_type=None,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate)
self.prev_seg_chunk = None
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
self.chat_adapter = None
self.model_runner = model_runner
self.huoshan_tts_voice_type = huoshan_tts_voice_type
assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader"
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def init_omni_env(self):
self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/")
self.session_id = os.getenv("OMNI_SESSION_ID", "")
self.account = os.getenv("OMNI_ACCOUNT", "")
self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",")
self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None)
assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist"
assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required"
logger.info(
f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}"
)
def start(self):
if self.rank == self.target_rank:
self.init_omni_env()
assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader"
self.chat_adapter = ChatAdapter(
omni_work_dir=self.omni_work_dir,
whep_url=self.stream_url,
session_id=self.session_id,
account=self.account,
config_files=self.config_files,
config_schema_path=self.config_schema_path,
seg_duration=self.segment_duration,
model_runner=self.model_runner,
huoshan_tts_voice_type=self.huoshan_tts_voice_type,
)
self.chat_adapter.start()
logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier")
def braodcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
# logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
# logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
# logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info):
audio = np.frombuffer(audio_data, dtype=np.int16)
sample_count = audio_info.sample_count
assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}"
# convert to mono
if audio_info.channel_count > 1:
audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1)
# logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}")
if audio_info.sample_rate != self.sample_rate:
sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate)
audio = resample(audio, sample_count).astype(np.int16)
# logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
return audio, sample_count
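# Worked example (illustrative): 2 s of 48 kHz stereo PCM arrives as 96000 samples per channel,
# interleaved, i.e. len(audio) == 192000 int16 values; averaging the channels yields 96000 mono
# samples, and resampling to the 16 kHz target gives sample_count = int(96000 * 16000 / 48000) = 32000.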
def prepare_audio_data(self, chat_audio_result):
sample_count = 0
audio = np.array([], dtype=np.int16)
# convert chat audio result to mono and target sample rate
if chat_audio_result is not None:
audio_data, audio_info = chat_audio_result
audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info)
# if this is not the first segment, concatenate with the previous segment
if self.prev_seg_chunk is not None:
audio = np.concatenate([self.prev_seg_chunk, audio])
sample_count = len(audio)
assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}"
# zero-pad the audio so its length matches all_seg_sample_count
if sample_count < self.all_seg_sample_count:
pad_count = self.all_seg_sample_count - sample_count
# logger.info(f"pad {pad_count} samples to audio")
audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0)
sample_count = len(audio)
# update prev seg chunk
self.prev_seg_chunk = audio[-self.prev_seg_sample_count :]
# logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}")
return audio.tobytes()
def get_fetch_duration(self):
fetch_duration = self.segment_duration
# after immediate switch, reset prev seg chunk
if self.chat_adapter.reset_prev:
self.prev_seg_chunk = None
self.chat_adapter.reset_prev = False
logger.warning(f"Reset prev seg chunk")
# for the first segment fetch segment_duration; otherwise fetch segment_duration - prev_duration
if self.prev_seg_chunk is not None:
fetch_duration -= self.prev_duration
return fetch_duration
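# Example with the defaults used elsewhere in this PR (target_video_length=81, target_fps=16,
# prev_frame_length=5): segment_duration = 81/16 = 5.0625 s and prev_duration = 5/16 = 0.3125 s,
# so the first fetch asks for 5.0625 s of audio and every later fetch asks for 5.0625 - 0.3125 = 4.75 s,
# because the last 0.3125 s is carried over in prev_seg_chunk.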
def get_audio_segment(self):
audio_data = None
if self.rank == self.target_rank:
try:
fetch_duration = self.get_fetch_duration()
# logger.info(f"Get segment, fetch_duration: {fetch_duration}")
if self.chat_adapter.status == "voice":
audio_result = self.chat_adapter.get_audio(fetch_duration)
audio_data = self.prepare_audio_data(audio_result)
# assume all buffered voice has been consumed; naturally switch back to blank
if audio_result is None:
logger.info("All buffered voice consumed, naturally switching to blank")
self.chat_adapter.status = "blank"
else:
audio_data = self.prepare_audio_data(None)
except Exception as e:
logger.warning(f"Failed to get voice segment: {e}")
return None
if self.world_size > 1:
audio_data = self.braodcast_audio_data(audio_data)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data
def get_immediate_switch(self):
if self.rank == self.target_rank:
if self.chat_adapter.immediate_switch == 1:
self.immediate_switch_tensor.fill_(1)
# reset immediate switch
self.chat_adapter.immediate_switch = 0
else:
self.immediate_switch_tensor.fill_(0)
dist.broadcast(self.immediate_switch_tensor, src=self.target_rank)
immediate_switch = self.immediate_switch_tensor.item()
return immediate_switch
def stop(self):
self.model_runner = None
if self.chat_adapter is not None:
self.chat_adapter.stop()
self.chat_adapter = None
logger.warning("OmniVAReader stopped")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = OmniVAReader(
RANK,
WORLD_SIZE,
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000",
segment_duration=17 / 16,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 100000000
try:
while True:
audio_data = reader.get_audio_segment()
if audio_data is not None:
logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
......@@ -24,6 +24,8 @@ class VARecorder:
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
self.livestream_url = livestream_url
self.fps = fps
......@@ -36,7 +38,9 @@ class VARecorder:
self.width = None
self.height = None
self.stoppable_t = None
self.realtime = True
self.realtime = False
if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
self.realtime = True
# ffmpeg process for mix video and audio data and push to livestream
self.ffmpeg_process = None
......@@ -53,6 +57,16 @@ class VARecorder:
self.audio_queue = queue.Queue()
self.video_queue = queue.Queue()
# buffer for stream data
self.audio_samples_per_frame = round(self.sample_rate / self.fps)
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
assert self.slice_frame >= self.prev_frame, "slice_frame must be greater than or equal to prev_frame"
def init_sockets(self):
# TCP socket for send and recv video and audio data
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -128,7 +142,7 @@ class VARecorder:
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
return
if self.realtime:
if self.realtime and i < data.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
......@@ -337,7 +351,9 @@ class VARecorder:
def start(self, width: int, height: int):
self.set_video_size(width, height)
duration = 1.0
self.pub_livestream(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16), torch.zeros(int(self.sample_rate * duration), dtype=torch.float16))
frames = int(self.fps * duration)
samples = int(self.sample_rate * (frames / self.fps))
self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16))
time.sleep(duration)
def set_video_size(self, width: int, height: int):
......@@ -353,11 +369,13 @@ class VARecorder:
self.start_ffmpeg_process_whip()
else:
self.start_ffmpeg_process_local()
self.realtime = False
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
self.video_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
......@@ -377,6 +395,75 @@ class VARecorder:
self.stoppable_t = time.time() + M / self.sample_rate + 3
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = images[i:end_frame]
aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.audio_queue.put(aud)
self.video_queue.put(img)
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
......@@ -385,6 +472,12 @@ class VARecorder:
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.audio_queue:
self.audio_queue.put(None)
......
......@@ -18,6 +18,8 @@ class X264VARecorder:
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.livestream_url = livestream_url
......@@ -33,16 +35,29 @@ class X264VARecorder:
self.whip_shared_lib = None
self.whip_shared_handle = None
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.realtime = True
# queue for send data to whip shared api
self.queue = queue.Queue()
self.worker_thread = None
# buffer for stream data
self.target_sample_rate = 48000
self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
self.target_chunks_per_frame = self.target_samples_per_frame * 2
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
assert self.slice_frame >= self.prev_frame, "slice_frame must be greater than or equal to prev_frame"
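# Illustrative numbers: with fps=16, target_samples_per_frame = 48000/16 = 3000 and
# target_chunks_per_frame = 3000 * 2 = 6000 int16 values per video frame (the * 2 presumably
# accounts for two interleaved channels in the converted 48 kHz audio buffer).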
def worker(self):
try:
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
audio_chunk = round(48000 * 2 / self.fps)
audio_samples = round(48000 / self.fps)
while True:
try:
if self.queue is None:
......@@ -55,14 +70,16 @@ class X264VARecorder:
for i in range(images.shape[0]):
t0 = time.time()
cur_audio = audios[i * audio_chunk : (i + 1) * audio_chunk].flatten()
cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples)
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
cur_video = images[i].flatten()
video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
time.sleep(max(0, packet_secs - (time.time() - t0)))
if self.realtime and i < images.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except: # noqa
......@@ -115,24 +132,78 @@ class X264VARecorder:
self.start_libx264_whip_shared_api(width, height)
self.worker_thread = threading.Thread(target=self.worker)
self.worker_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
audio_datas, image_datas = self.convert_data(audios, images)
self.queue.put((audio_datas, image_datas))
logger.info(f"Published {N} frames and {M} audio samples")
self.stoppable_t = time.time() + M / self.sample_rate + 3
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = image_datas[i:end_frame]
aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.queue.put((aud, img))
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
......@@ -142,6 +213,12 @@ class X264VARecorder:
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.queue:
self.queue.put(None)
......@@ -219,7 +296,7 @@ if __name__ == "__main__":
cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
recorder.pub_livestream(images, torch.tensor(cur_audio_array, dtype=torch.float32))
recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
i += interval
time.sleep(interval - (time.time() - t0))
......@@ -237,7 +314,7 @@ if __name__ == "__main__":
if started:
logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
started = False
recorder.pub_livestream(images, cur_audio_array)
recorder.buffer_stream(images, cur_audio_array, images)
i += interval
time.sleep(interval - (time.time() - t0))
......
import argparse
import asyncio
import base64
import copy
import json
import mimetypes
import os
......@@ -1030,13 +1031,17 @@ async def api_v1_share_get(share_id: str):
@app.get("/api/v1/voices/list")
async def api_v1_voices_list():
async def api_v1_voices_list(request: Request):
try:
version = request.query_params.get("version", "all")
if volcengine_tts_client is None:
return error_response("Volcengine TTS client not loaded", 500)
voices = volcengine_tts_client.get_voice_list()
if voices is None:
return error_response("No voice list found", 404)
if version != "all":
voices = copy.deepcopy(voices)
voices["voices"] = [v for v in voices["voices"] if v["version"] == version]
return voices
except Exception as e:
traceback.print_exc()
......
......@@ -34,7 +34,7 @@ HEADERS = {"Authorization": f"Bearer {WORKER_SECRET_KEY}", "Content-Type": "appl
STOPPED = False
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
TARGET_RANK = WORLD_SIZE - 1
TARGET_RANK = int(os.getenv("WORKER_RANK", "0")) % WORLD_SIZE
async def ping_life(server_url, worker_identity, keys):
......@@ -251,14 +251,17 @@ async def main(args):
logger.warning("Main loop cancelled, do not shut down")
finally:
try:
if ping_task:
ping_task.cancel()
await sync_subtask()
except Exception:
logger.warning(f"Sync subtask failed: {traceback.format_exc()}")
if RANK == TARGET_RANK and sub["task_id"] in RUNNING_SUBTASKS:
try:
await report_task(status=status, **sub)
except: # noqa
except Exception:
logger.warning(f"Report failed: {traceback.format_exc()}")
if ping_task:
ping_task.cancel()
await sync_subtask()
async def shutdown(loop):
......
......@@ -40,8 +40,8 @@ class BaseWorker:
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
set_parallel_config(config)
# same as va_recorder rank and worker main ping rank
self.out_video_rank = self.world_size - 1
# same as va_recorder rank
self.out_video_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
torch.set_grad_enabled(False)
self.runner = RUNNER_REGISTER[config["model_cls"]](config)
self.input_info = set_input_info(args)
......
import os
from abc import ABC
import torch
......@@ -139,19 +140,30 @@ class BaseRunner(ABC):
if dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
signal_rank = world_size - 1
stop_rank = int(os.getenv("WORKER_RANK", "0")) % world_size # same as worker hub target_rank
pause_rank = int(os.getenv("READER_RANK", "0")) % world_size # same as va_reader target_rank
stopped = 0
if rank == signal_rank and hasattr(self, "stop_signal") and self.stop_signal:
stopped, paused = 0, 0
if rank == stop_rank and hasattr(self, "stop_signal") and self.stop_signal:
stopped = 1
if rank == pause_rank and hasattr(self, "pause_signal") and self.pause_signal:
paused = 1
if world_size > 1:
if rank == signal_rank:
t = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
if rank == stop_rank:
t1 = torch.tensor([stopped], dtype=torch.int32).to(device=AI_DEVICE)
else:
t = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
dist.broadcast(t, src=signal_rank)
stopped = t.item()
t1 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
if rank == pause_rank:
t2 = torch.tensor([paused], dtype=torch.int32).to(device=AI_DEVICE)
else:
t2 = torch.zeros(1, dtype=torch.int32, device=AI_DEVICE)
dist.broadcast(t1, src=stop_rank)
dist.broadcast(t2, src=pause_rank)
stopped = t1.item()
paused = t2.item()
if stopped == 1:
raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
if paused == 1:
raise Exception(f"find rank: {rank} pause_signal, pause running, it's an expected behavior")
......@@ -18,9 +18,7 @@ from loguru import logger
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.deploy.common.va_controller import VAController
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
......@@ -694,15 +692,15 @@ class WanAudioRunner(WanRunner): # type:ignore
)
if "video_super_resolution" in self.config and self.vsr_model is not None:
logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
# logger.info(f"Applying video super resolution with scale {self.config['video_super_resolution']['scale']}")
video_seg = self.vsr_model.super_resolve_frames(
video_seg,
seed=self.config["video_super_resolution"]["seed"],
scale=self.config["video_super_resolution"]["scale"],
)
if self.va_recorder:
self.va_recorder.pub_livestream(video_seg, audio_seg)
if self.va_controller.recorder is not None:
self.va_controller.pub_livestream(video_seg, audio_seg, self.gen_video[:, :, :useful_length])
elif self.input_info.return_result_tensor:
self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
self.cut_audio_final[self.segment.start_frame * self._audio_processor.audio_frame_rate : self.segment.end_frame * self._audio_processor.audio_frame_rate].copy_(audio_seg)
......@@ -721,85 +719,35 @@ class WanAudioRunner(WanRunner): # type:ignore
world_size = dist.get_world_size()
return rank, world_size
def init_va_recorder(self):
output_video_path = self.input_info.save_result_path
self.va_recorder = None
if isinstance(output_video_path, dict):
output_video_path = output_video_path["data"]
logger.info(f"init va_recorder with output_video_path: {output_video_path}")
rank, world_size = self.get_rank_and_world_size()
if output_video_path and rank == world_size - 1:
record_fps = self.config.get("target_fps", 16)
audio_sr = self.config.get("audio_sr", 16000)
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
record_fps = self.config["video_frame_interpolation"]["target_fps"]
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and output_video_path.startswith("http"):
self.va_recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=output_video_path,
fps=record_fps,
sample_rate=audio_sr,
)
else:
self.va_recorder = VARecorder(
livestream_url=output_video_path,
fps=record_fps,
sample_rate=audio_sr,
)
def init_va_reader(self):
audio_path = self.input_info.audio_path
self.va_reader = None
if isinstance(audio_path, dict):
assert audio_path["type"] == "stream", f"unexcept audio_path: {audio_path}"
rank, world_size = self.get_rank_and_world_size()
target_fps = self.config.get("target_fps", 16)
max_num_frames = self.config.get("target_video_length", 81)
audio_sr = self.config.get("audio_sr", 16000)
prev_frames = self.config.get("prev_frame_length", 5)
self.va_reader = VAReader(
rank=rank,
world_size=world_size,
stream_url=audio_path["data"],
sample_rate=audio_sr,
segment_duration=max_num_frames / target_fps,
prev_duration=prev_frames / target_fps,
target_rank=1,
)
def run_main(self):
try:
self.init_va_recorder()
self.init_va_reader()
logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
self.va_controller = VAController(self)
logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
if self.va_reader is None:
# fixed audio segments inputs
if self.va_controller.reader is None:
return super().run_main()
self.va_reader.start()
rank, world_size = self.get_rank_and_world_size()
if rank == world_size - 1:
assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 2"
self.va_recorder.start(self.input_info.target_shape[1], self.input_info.target_shape[0])
if world_size > 1:
dist.barrier()
self.va_controller.start()
self.init_run()
if self.config.get("compile", False):
if self.config.get("compile", False) and hasattr(self.model, "comple"):
self.model.select_graph_for_compile(self.input_info)
self.video_segment_num = "unlimited"
fetch_timeout = self.va_reader.segment_duration + 1
# stream audio input, video segment num is unlimited
self.video_segment_num = 1000000
segment_idx = 0
fail_count = 0
max_fail_count = 10
fail_count, max_fail_count = 0, 10
self.va_controller.before_control()
while True:
with ProfilingContext4DebugL1(f"stream segment get audio segment {segment_idx}"):
self.check_stop()
audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
control = self.va_controller.next_control()
if control.action == "switch":
self.prev_video = control.data
elif control.action == "wait":
time.sleep(0.01)
continue
audio_array = self.va_controller.reader.get_audio_segment()
if audio_array is None:
fail_count += 1
logger.warning(f"Failed to get audio chunk {fail_count} times")
......@@ -808,22 +756,27 @@ class WanAudioRunner(WanRunner): # type:ignore
continue
with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
fail_count = 0
self.init_run_segment(segment_idx, audio_array)
latents = self.run_segment(segment_idx)
self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment(segment_idx)
segment_idx += 1
try:
# reset pause signal
self.pause_signal = False
self.init_run_segment(segment_idx, audio_array)
self.check_stop()
latents = self.run_segment(segment_idx)
self.check_stop()
self.gen_video = self.run_vae_decoder(latents)
self.check_stop()
self.end_run_segment(segment_idx)
segment_idx += 1
fail_count = 0
except Exception as e:
if "pause_signal, pause running" in str(e):
logger.warning(f"model infer audio pause: {e}, should continue")
else:
raise
finally:
if hasattr(self.model, "inputs"):
self.end_run()
if self.va_reader:
self.va_reader.stop()
self.va_reader = None
if self.va_recorder:
self.va_recorder.stop()
self.va_recorder = None
self.va_controller.clear()
@ProfilingContext4DebugL1("Process after vae decoder")
def process_images_after_vae_decoder(self):
......
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["DTYPE"] = "BF16"
os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
os.environ["PROFILING_DEBUG_LEVEL"] = "2"
# please do not set envs in this file; it is imported by the __init__.py file
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["DTYPE"] = "BF16"
# os.environ["SENSITIVE_LAYER_DTYPE"] = "None"
# os.environ["PROFILING_DEBUG_LEVEL"] = "2"
import json
......