Unverified Commit cd777631 authored by LiangLiu's avatar LiangLiu Committed by GitHub
Browse files

Fix x264 recorder (#361)

parent 0379ae88
...@@ -11,7 +11,7 @@ from loguru import logger ...@@ -11,7 +11,7 @@ from loguru import logger
from scipy.signal import resample from scipy.signal import resample
class VAX64Recorder: class X264VARecorder:
def __init__( def __init__(
self, self,
whip_shared_path: str, whip_shared_path: str,
...@@ -19,7 +19,7 @@ class VAX64Recorder: ...@@ -19,7 +19,7 @@ class VAX64Recorder:
fps: float = 16.0, fps: float = 16.0,
sample_rate: int = 16000, sample_rate: int = 16000,
): ):
assert livestream_url.startswith("http"), "VAX64Recorder only support whip http livestream" assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.livestream_url = livestream_url self.livestream_url = livestream_url
self.fps = fps self.fps = fps
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -57,11 +57,11 @@ class VAX64Recorder: ...@@ -57,11 +57,11 @@ class VAX64Recorder:
t0 = time.time() t0 = time.time()
cur_audio = audios[i * audio_chunk : (i + 1) * audio_chunk].flatten() cur_audio = audios[i * audio_chunk : (i + 1) * audio_chunk].flatten()
audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)) audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
self.whip_shared_lib.pushRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples) self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, audio_samples)
cur_video = images[i].flatten() cur_video = images[i].flatten()
video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)) video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
self.whip_shared_lib.pushRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height) self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
time.sleep(max(0, packet_secs - (time.time() - t0))) time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0 fail_time = 0
...@@ -76,25 +76,25 @@ class VAX64Recorder: ...@@ -76,25 +76,25 @@ class VAX64Recorder:
finally: finally:
logger.info("Audio push worker thread stopped") logger.info("Audio push worker thread stopped")
def start_libx264_whip_shared_api(self): def start_libx264_whip_shared_api(self, width: int, height: int):
self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path) self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path)
# define function argtypes and restype # define function argtypes and restype
self.whip_shared_lib.initStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int] self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.initStream.restype = ctypes.c_void_p self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p
self.whip_shared_lib.pushRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int] self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
self.whip_shared_lib.pushRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int] self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.destroyStream.argtypes = [ctypes.c_void_p] self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p]
whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8")) whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8"))
self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initStream(whip_url, 1, 1, 0)) self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height))
logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}") logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}")
def convert_data(self, audios, images): def convert_data(self, audios, images):
# Convert audio data to 16-bit integer format # Convert audio data to 16-bit integer format
audio_datas = np.clip(np.round(audios * 32767), -32768, 32767).astype(np.int16) audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
# Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg # Convert to numpy and scale to [0, 255], convert RGB to BGR for OpenCV/FFmpeg
image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
...@@ -112,12 +112,12 @@ class VAX64Recorder: ...@@ -112,12 +112,12 @@ class VAX64Recorder:
return return
self.width = width self.width = width
self.height = height self.height = height
self.start_libx264_whip_shared_api() self.start_libx264_whip_shared_api(width, height)
self.worker_thread = threading.Thread(target=self.worker) self.worker_thread = threading.Thread(target=self.worker)
self.worker_thread.start() self.worker_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream # Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: np.ndarray): def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
N, height, width, C = images.shape N, height, width, C = images.shape
M = audios.reshape(-1).shape[0] M = audios.reshape(-1).shape[0]
assert C == 3, "Input must be [N, H, W, C] with C=3" assert C == 3, "Input must be [N, H, W, C] with C=3"
...@@ -154,7 +154,7 @@ class VAX64Recorder: ...@@ -154,7 +154,7 @@ class VAX64Recorder:
# Destroy WHIP shared API # Destroy WHIP shared API
if self.whip_shared_lib and self.whip_shared_handle: if self.whip_shared_lib and self.whip_shared_handle:
self.whip_shared_lib.destroyStream(self.whip_shared_handle) self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
self.whip_shared_handle = None self.whip_shared_handle = None
self.whip_shared_lib = None self.whip_shared_lib = None
logger.warning("WHIP shared API destroyed") logger.warning("WHIP shared API destroyed")
...@@ -192,18 +192,18 @@ def create_simple_video(frames=10, height=480, width=640): ...@@ -192,18 +192,18 @@ def create_simple_video(frames=10, height=480, width=640):
if __name__ == "__main__": if __name__ == "__main__":
sample_rate = 16000 sample_rate = 16000
fps = 16 fps = 16
width = 352 width = 452
height = 352 height = 352
recorder = VAX64Recorder( recorder = X264VARecorder(
whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/src/libagora_go_whip.so", whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll1&eip=10.120.114.82:8000", livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
fps=fps, fps=fps,
sample_rate=sample_rate, sample_rate=sample_rate,
) )
recorder.start(width, height) recorder.start(width, height)
time.sleep(5) # time.sleep(5)
audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav" audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
audio_array, ori_sr = ta.load(audio_path) audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000) audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
...@@ -219,7 +219,7 @@ if __name__ == "__main__": ...@@ -219,7 +219,7 @@ if __name__ == "__main__":
cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32) cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
num_frames = int(interval * fps) num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width) images = create_simple_video(num_frames, height, width)
recorder.pub_livestream(images, cur_audio_array) recorder.pub_livestream(images, torch.tensor(cur_audio_array, dtype=torch.float32))
i += interval i += interval
time.sleep(interval - (time.time() - t0)) time.sleep(interval - (time.time() - t0))
...@@ -230,7 +230,7 @@ if __name__ == "__main__": ...@@ -230,7 +230,7 @@ if __name__ == "__main__":
t0 = time.time() t0 = time.time()
start = int(i * sample_rate) start = int(i * sample_rate)
end = int((i + interval) * sample_rate) end = int((i + interval) * sample_rate)
cur_audio_array = audio_array[start:end] cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
num_frames = int(interval * fps) num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width) images = create_simple_video(num_frames, height, width)
logger.info(f"{i} / {secs} s") logger.info(f"{i} / {secs} s")
......
...@@ -19,7 +19,7 @@ from torchvision.transforms.functional import resize ...@@ -19,7 +19,7 @@ from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_x64_recorder import VAX64Recorder from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel from lightx2v.models.networks.wan.audio_model import WanAudioModel
...@@ -682,7 +682,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -682,7 +682,7 @@ class WanAudioRunner(WanRunner): # type:ignore
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None) whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and output_video_path.startswith("http"): if whip_shared_path and output_video_path.startswith("http"):
self.va_recorder = VAX64Recorder( self.va_recorder = X264VARecorder(
whip_shared_path=whip_shared_path, whip_shared_path=whip_shared_path,
livestream_url=output_video_path, livestream_url=output_video_path,
fps=record_fps, fps=record_fps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment