Unverified Commit cd777631 authored by LiangLiu's avatar LiangLiu Committed by GitHub
Browse files

Fix x264 recorder (#361)

parent 0379ae88
......@@ -11,7 +11,7 @@ from loguru import logger
from scipy.signal import resample
class VAX64Recorder:
class X264VARecorder:
def __init__(
self,
whip_shared_path: str,
......@@ -19,7 +19,7 @@ class VAX64Recorder:
fps: float = 16.0,
sample_rate: int = 16000,
):
assert livestream_url.startswith("http"), "VAX64Recorder only support whip http livestream"
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
......@@ -57,11 +57,11 @@ class VAX64Recorder:
t0 = time.time()
cur_audio = audios[i * audio_chunk : (i + 1) * audio_chunk].flatten()
audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
self.whip_shared_lib.pushRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples)
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples)
cur_video = images[i].flatten()
video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
self.whip_shared_lib.pushRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
......@@ -76,25 +76,25 @@ class VAX64Recorder:
finally:
logger.info("Audio push worker thread stopped")
def start_libx264_whip_shared_api(self):
def start_libx264_whip_shared_api(self, width: int, height: int):
self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path)
# define function argtypes and restype
self.whip_shared_lib.initStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.initStream.restype = ctypes.c_void_p
self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p
self.whip_shared_lib.pushRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
self.whip_shared_lib.pushRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.destroyStream.argtypes = [ctypes.c_void_p]
self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p]
whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8"))
self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initStream(whip_url, 1, 1, 0))
self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height))
logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}")
def convert_data(self, audios, images):
# Convert audio data to 16-bit integer format
audio_datas = np.clip(np.round(audios * 32767), -32768, 32767).astype(np.int16)
audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
# Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg
image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
......@@ -112,12 +112,12 @@ class VAX64Recorder:
return
self.width = width
self.height = height
self.start_libx264_whip_shared_api()
self.start_libx264_whip_shared_api(width, height)
self.worker_thread = threading.Thread(target=self.worker)
self.worker_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: np.ndarray):
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert C == 3, "Input must be [N, H, W, C] with C=3"
......@@ -154,7 +154,7 @@ class VAX64Recorder:
# Destroy WHIP shared API
if self.whip_shared_lib and self.whip_shared_handle:
self.whip_shared_lib.destroyStream(self.whip_shared_handle)
self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
self.whip_shared_handle = None
self.whip_shared_lib = None
logger.warning("WHIP shared API destroyed")
......@@ -192,18 +192,18 @@ def create_simple_video(frames=10, height=480, width=640):
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 352
width = 452
height = 352
recorder = VAX64Recorder(
whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/src/libagora_go_whip.so",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll1&eip=10.120.114.82:8000",
recorder = X264VARecorder(
whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
fps=fps,
sample_rate=sample_rate,
)
recorder.start(width, height)
time.sleep(5)
# time.sleep(5)
audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
......@@ -219,7 +219,7 @@ if __name__ == "__main__":
cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
recorder.pub_livestream(images, cur_audio_array)
recorder.pub_livestream(images, torch.tensor(cur_audio_array, dtype=torch.float32))
i += interval
time.sleep(interval - (time.time() - t0))
......@@ -230,7 +230,7 @@ if __name__ == "__main__":
t0 = time.time()
start = int(i * sample_rate)
end = int((i + interval) * sample_rate)
cur_audio_array = audio_array[start:end]
cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"{i} / {secs} s")
......
......@@ -19,7 +19,7 @@ from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_x64_recorder import VAX64Recorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
......@@ -682,7 +682,7 @@ class WanAudioRunner(WanRunner): # type:ignore
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and output_video_path.startswith("http"):
self.va_recorder = VAX64Recorder(
self.va_recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=output_video_path,
fps=record_fps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment