import datetime import json import os import random import subprocess import threading import time import traceback from collections import deque from copy import deepcopy import jsonschema import numpy as np import torch import torch.distributed as dist import zmq from bson import BSON from loguru import logger from scipy.signal import resample class AudioInfo: def __init__(self, info: dict): self.sample_count = info["sample_count"] self.sample_rate = info["sample_rate"] self.channel_count = info["channel_count"] self.sample_fmt = info["sample_fmt"] self.pts = info["pts"] def is_spec_equal(self, other: "AudioInfo") -> bool: return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count def duration(self) -> datetime.timedelta: return datetime.timedelta(seconds=self.sample_count / self.sample_rate) def __str__(self): return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts) class ByteBuffer: def __init__(self): self.buffer = deque() self.current_size = 0 # is the audio belonging to current turn finished self.audio_finished = False def add(self, byte_data: bytes): self.buffer.append(byte_data) self.current_size += len(byte_data) def get(self, size=1024): data = bytearray() while size > 0 and len(self.buffer) > 0: chunk = self.buffer.popleft() if len(chunk) <= size: # 如果当前数据小于size,则将当前数据全部添加到data中 data.extend(chunk) self.current_size -= len(chunk) size -= len(chunk) else: # 如果当前数据大于size,则将当前数据的一部分添加到data中,剩余部分留在缓冲区 data.extend(chunk[:size]) self.buffer.appendleft(chunk[size:]) # 剩余部分留在缓冲区 self.current_size -= size size = 0 return bytes(data) def mark_finished(self): self.audio_finished = True def has_more_voice(self): return not self.audio_finished def __len__(self): return self.current_size class ChatAdapter: def __init__( self, omni_work_dir: str, whep_url: str, session_id: str, account: str, config_files: list[str], config_schema_path: str, seg_duration: float, model_runner, huoshan_tts_voice_type, ): assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist" self.omni_work_dir = omni_work_dir self.context = zmq.Context() self.w2f_socket = self.context.socket(zmq.PULL) self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket) self.f2w_socket = self.context.socket(zmq.PUSH) self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket) self.recv_thread = None self.audio_buffer = ByteBuffer() self.audio_info = None self.chat_server_cmd = [ os.path.join(self.omni_work_dir, "bin", "seko-chatter"), "--session-id", session_id, "--account", account, "--whep-server-url", whep_url, "--w2f-endpoint", self.w2f_url, "--f2w-endpoint", self.f2w_url, "--config-files", *config_files, ] override_config = {} if huoshan_tts_voice_type is not None: logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}") override_config["TTS"] = { "default_voice_info": { "voice_type": huoshan_tts_voice_type, "provider": "huoshan_stream_tts", } } with open(config_schema_path, "r") as f: schema = json.load(f) jsonschema.validate(instance=override_config, schema=schema) if override_config is not None: self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)]) self.chatter_proc = None self.seg_duration = seg_duration self.reset_prev = False self.status = "blank" self.immediate_switch = 0 self.model_runner = model_runner def launch_chat_server(self): env = { "RUST_LOG": "info,duplex_server=debug,backend_5o=debug", "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"), "PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"), } self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir) @staticmethod def select_and_bind(socket: zmq.Socket) -> str: # randomly select a port between 1024 and 6553 retry_count = 20 err = None while retry_count > 0: try: port = random.randint(1024, 65535) # port = 5555 url = f"tcp://localhost:{port}" socket.bind(url) return url except zmq.error.ZMQError as e: retry_count -= 1 err = e raise err # immediate switch to status, discard prev_bytes, set immediate_switch to 1 def immediate_switch_to(self, status): logger.warning(f"VA reader immediate switch to {status}") self.reset_prev = True self.status = status self.immediate_switch = 1 if self.model_runner is not None: self.model_runner.pause_signal = True logger.warning(f"Model runner pause signal set to True") def recv_loop(self): while True: try: message = self.w2f_socket.recv() except Exception: logger.error(f"Error receiving message: {traceback.format_exc()}") break try: message = BSON.decode(message) msg_type = message["type"] logger.debug("Received message type: {}".format(msg_type)) if msg_type == "AgentAudio": audio = message["audio"] if audio["type"] != "Pcm": logger.error("Unsupported audio type: {}".format(audio["type"])) continue pcm_data = audio["data"] audio_info = AudioInfo(audio["info"]) logger.debug("Received audio with duration: {}".format(audio_info.duration())) if self.audio_info is None: self.audio_info = audio_info else: # check if the audio info is the same if not self.audio_info.is_spec_equal(audio_info): raise ValueError("Audio info mismatch") self.audio_buffer.add(pcm_data) # if status is blank and has voice, set immediate switch to 1 if self.status == "blank" and self.has_voice(self.seg_duration): self.immediate_switch_to("voice") elif msg_type == "AgentStartPlay": logger.debug("Received AgentStartPlay, create new audio buffer") self.audio_buffer = ByteBuffer() elif msg_type == "AgentEndPlay": logger.debug("Received AgentEndPlay, mark audio finished") self.audio_buffer.mark_finished() elif msg_type == "ClearAgentAudio": logger.warning("Received ClearAgentAudio, clear audio buffer") self.audio_buffer = None self.audio_info = None if self.status == "voice": self.status = "blank" # self.immediate_switch_to("blank") except Exception as e: logger.error("Error decoding message: {}, continue".format(e)) continue logger.warning("recv loop interrupted") def start(self): self.launch_chat_server() self.recv_thread = threading.Thread(target=self.recv_loop) self.recv_thread.start() def has_voice(self, duration) -> bool: if self.audio_info is None or self.audio_buffer.current_size == 0: return False bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2 # S16LE assumed # if not has enough bytes and maybe has more voice, return False if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice(): logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}") return False return bytes_count def get_audio(self, fetch_duration) -> (bytes, AudioInfo): bytes_count = self.has_voice(fetch_duration) if bytes_count is False: return None pcm_data = self.audio_buffer.get(bytes_count) # the actual sample count fetched sample_count = len(pcm_data) // (self.audio_info.channel_count * 2) logger.debug("Fetched {} bytes audio".format(sample_count)) logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size)) audio_info = deepcopy(self.audio_info) audio_info.sample_count = sample_count return (pcm_data, audio_info) def stop(self): self.model_runner = None if self.chatter_proc is not None: self.chatter_proc.terminate() self.chatter_proc.wait() self.chatter_proc = None self.w2f_socket.close() self.f2w_socket.close() def __del__(self): self.stop() class OmniVAReader: def __init__( self, rank: int, world_size: int, stream_url: str, segment_duration: float = 5.0625, sample_rate: int = 16000, audio_channels: int = 1, buffer_size: int = 1, prev_duration: float = 0.3125, target_rank: int = 0, model_runner=None, huoshan_tts_voice_type=None, ): self.rank = rank self.world_size = world_size self.stream_url = stream_url self.segment_duration = segment_duration self.sample_rate = sample_rate self.audio_channels = audio_channels self.prev_duration = prev_duration self.all_seg_sample_count = int(self.segment_duration * self.sample_rate) self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate) self.prev_seg_chunk = None self.target_rank = target_rank % self.world_size self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda") self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda") chunk_size = int(self.segment_duration * self.sample_rate) * 2 self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda") self.chat_adapter = None self.model_runner = model_runner self.huoshan_tts_voice_type = huoshan_tts_voice_type assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader" logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}") logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz") def init_omni_env(self): self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/") self.session_id = os.getenv("OMNI_SESSION_ID", "") self.account = os.getenv("OMNI_ACCOUNT", "") self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",") self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None) assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist" assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required" logger.info( f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}" ) def start(self): if self.rank == self.target_rank: self.init_omni_env() assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader" self.chat_adapter = ChatAdapter( omni_work_dir=self.omni_work_dir, whep_url=self.stream_url, session_id=self.session_id, account=self.account, config_files=self.config_files, config_schema_path=self.config_schema_path, seg_duration=self.segment_duration, model_runner=self.model_runner, huoshan_tts_voice_type=self.huoshan_tts_voice_type, ) self.chat_adapter.start() logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully") else: logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only") if self.world_size > 1: logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier") dist.barrier() logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier") def braodcast_audio_data(self, audio_data): if self.rank == self.target_rank: if audio_data is None: self.flag_tensor.fill_(0) else: self.flag_tensor.fill_(1) self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8)) # logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}") dist.broadcast(self.flag_tensor, src=self.target_rank) if self.flag_tensor.item() == 0: return None dist.broadcast(self.audio_tensor, src=self.target_rank) if self.rank != self.target_rank: # logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}") audio_data = self.audio_tensor.cpu().numpy().tobytes() return audio_data def bytes_to_ndarray(self, audio_data): if audio_data is None: return None audio_data = np.frombuffer(audio_data, dtype=np.int16) audio_data = audio_data.astype(np.float32) / 32768.0 # logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}") return audio_data def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info): audio = np.frombuffer(audio_data, dtype=np.int16) sample_count = audio_info.sample_count assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}" # convert to mono if audio_info.channel_count > 1: audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1) # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}") if audio_info.sample_rate != self.sample_rate: sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate) audio = resample(audio, sample_count).astype(np.int16) # logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}") logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}") return audio, sample_count def prepare_audio_data(self, chat_audio_result): sample_count = 0 audio = np.array([], dtype=np.int16) # convert chat audio result to mono and target sample rate if chat_audio_result is not None: audio_data, audio_info = chat_audio_result audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info) # if is not the first segment, concat with previous segment if self.prev_seg_chunk is not None: audio = np.concatenate([self.prev_seg_chunk, audio]) sample_count = len(audio) assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}" # pad 0 to the audio to make it the same length as all_seg_sample_count if sample_count < self.all_seg_sample_count: pad_count = self.all_seg_sample_count - sample_count # logger.info(f"pad {pad_count} samples to audio") audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0) sample_count = len(audio) # update prev seg chunk self.prev_seg_chunk = audio[-self.prev_seg_sample_count :] # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}") return audio.tobytes() def get_fetch_duration(self): fetch_duration = self.segment_duration # after immediate switch, reset prev seg chunk if self.chat_adapter.reset_prev: self.prev_seg_chunk = None self.chat_adapter.reset_prev = False logger.warning(f"Reset prev seg chunk") # first segment, fetch segment_duration, else fetch segment_duration - prev_duration if self.prev_seg_chunk is not None: fetch_duration -= self.prev_duration return fetch_duration def get_audio_segment(self): audio_data = None if self.rank == self.target_rank: try: fetch_duration = self.get_fetch_duration() # logger.info(f"Get segment, fetch_duration: {fetch_duration}") if self.chat_adapter.status == "voice": audio_result = self.chat_adapter.get_audio(fetch_duration) audio_data = self.prepare_audio_data(audio_result) # think all voice segments inferred, naturally switch to blank if audio_result is None: logger.info(f"Think all voice segments inferred, naturally switch to blank") self.chat_adapter.status = "blank" else: audio_data = self.prepare_audio_data(None) except Exception as e: logger.warning(f"Failed to get voice segment: {e}") return None if self.world_size > 1: audio_data = self.braodcast_audio_data(audio_data) audio_data = self.bytes_to_ndarray(audio_data) return audio_data def get_immediate_switch(self): if self.rank == self.target_rank: if self.chat_adapter.immediate_switch == 1: self.immediate_switch_tensor.fill_(1) # reset immediate switch self.chat_adapter.immediate_switch = 0 else: self.immediate_switch_tensor.fill_(0) dist.broadcast(self.immediate_switch_tensor, src=self.target_rank) immediate_switch = self.immediate_switch_tensor.item() return immediate_switch def stop(self): self.model_runner = None if self.chat_adapter is not None: self.chat_adapter.stop() self.chat_adapter = None logger.warning("OmniVAReader stopped") def __del__(self): self.stop() if __name__ == "__main__": WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) RANK = int(os.environ.get("RANK", 0)) if WORLD_SIZE > 1: dist.init_process_group(backend="nccl") torch.cuda.set_device(dist.get_rank()) logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}") reader = OmniVAReader( RANK, WORLD_SIZE, "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000", segment_duration=17 / 16, sample_rate=16000, audio_channels=1, prev_duration=1 / 16, ) reader.start() fail_count = 0 max_fail_count = 100000000 try: while True: audio_data = reader.get_audio_segment(timeout=1) if audio_data is not None: logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]") fail_count = 0 else: fail_count += 1 if fail_count > max_fail_count: logger.warning("Failed to get audio chunk, stop reader") reader.stop() break time.sleep(0.95) finally: reader.stop()