Commit e2778d0d authored by litzh

Initial commit
# -*- coding: utf-8 -*-
"""
Audio Source Separation Module
Separates different voice tracks in audio, supports multi-person audio separation
"""
import base64
import io
import os
import tempfile
import traceback
from collections import defaultdict
from typing import Dict, Optional, Union
import torch
import torch.serialization
import torchaudio
from loguru import logger
# Import pyannote.audio for speaker diarization
from pyannote.audio import Audio, Pipeline
# Fix for PyTorch 2.6 compatibility: allow pyannote.audio classes in torch.load
# PyTorch 2.6 changed torch.load default to weights_only=True for security
try:
# Add safe globals for pyannote.audio classes
# This allows torch.load to work with pyannote.audio model files
from pyannote.audio.core.task import Specifications
torch.serialization.add_safe_globals([Specifications])
except (ImportError, AttributeError) as e:
# If pyannote.audio is not installed or class doesn't exist, log warning
# The actual error will be handled when Pipeline.from_pretrained is called
logger.debug(f"Could not add pyannote.audio safe globals (may need to use weights_only=False): {e}")
_origin_torch_load = torch.load
def our_torch_load(checkpoint_file, *args, **kwargs):
kwargs["weights_only"] = False
return _origin_torch_load(checkpoint_file, *args, **kwargs)
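# Note: our_torch_load is only swapped in around Pipeline.from_pretrained (see
# _init_pyannote below), where torch.load is temporarily replaced and then restored
# in a finally block. Forcing weights_only=False bypasses PyTorch 2.6's safe-loading
# check, so it should only be applied to model files from a trusted source.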
class AudioSeparator:
"""
Audio separator for separating different voice tracks in audio using pyannote.audio
Supports multi-person conversation separation, maintains duration (other speakers' tracks are empty)
"""
def __init__(
self,
model_path: str = None,
device: str = None,
sample_rate: int = 16000,
):
"""
Initialize audio separator
Args:
model_path: Model path (if using custom model), default uses pyannote/speaker-diarization-community-1
device: Device ('cpu', 'cuda', etc.), None for auto selection
sample_rate: Target sample rate, default 16000
"""
self.sample_rate = sample_rate
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self._init_pyannote(model_path)
def _init_pyannote(self, model_path: str = None):
"""Initialize pyannote.audio pipeline"""
try:
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
model_name = model_path or "pyannote/speaker-diarization-community-1"
# Fix for PyTorch 2.6: use safe_globals context manager to allow pyannote.audio classes
# PyTorch 2.6 changed torch.load default to weights_only=True
try:
from pyannote.audio.core.task import Specifications
safe_globals_context = torch.serialization.safe_globals([Specifications])
except (ImportError, AttributeError):
# If Specifications class is not available, use empty context
safe_globals_context = torch.serialization.safe_globals([])
try:
torch.load = our_torch_load
# Try loading with token if available
if huggingface_token:
self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
else:
# Try without token (may work for public models)
self.pipeline = Pipeline.from_pretrained(model_name)
except Exception as e:
if "gated" in str(e).lower() or "token" in str(e).lower():
raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
# If safe_globals didn't work, try with weights_only=False as fallback
if "weights_only" in str(e).lower() or "Unsupported global" in str(e):
logger.warning(f"PyTorch 2.6 compatibility issue detected, attempting fallback: {e}")
# Note: We can't directly control weights_only in Pipeline.from_pretrained,
# but the safe_globals should have worked. If not, the error will be raised.
raise RuntimeError(f"Failed to load pyannote model: {e}")
finally:
torch.load = _origin_torch_load
# Move pipeline to specified device
if self.device:
self.pipeline.to(torch.device(self.device))
# Initialize Audio helper for waveform loading
self.pyannote_audio = Audio()
logger.info("Initialized pyannote.audio speaker diarization pipeline")
except Exception as e:
logger.error(f"Failed to initialize pyannote: {e}")
raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")
def separate_speakers(
self,
audio_path: Union[str, bytes],
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""
Separate different speakers in audio
Args:
audio_path: Audio file path or bytes data
num_speakers: Specified number of speakers, None for auto detection
min_speakers: Minimum number of speakers
max_speakers: Maximum number of speakers
Returns:
Dict containing:
- speakers: List of speaker audio segments, each containing:
                    - speaker_id: Speaker label assigned by the diarization pipeline (e.g., "SPEAKER_00")
- audio: torch.Tensor audio data [channels, samples]
- segments: List of (start_time, end_time) tuples
- sample_rate: Sample rate
"""
try:
# Load audio
if isinstance(audio_path, bytes):
                # Try to infer the audio format from the byte data
                # Check for WAV format (RIFF header)
                is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
                # Check for MP3 format (ID3 tag or MPEG frame header)
                is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"
                # Choose the file suffix based on the detected format
if is_wav:
suffix = ".wav"
elif is_mp3:
suffix = ".mp3"
else:
                    # Default to WAV; loading will raise later if the format is actually unsupported
suffix = ".wav"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
tmp_file.write(audio_path)
tmp_audio_path = tmp_file.name
try:
result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
finally:
                    # Make sure the temporary file is removed
try:
os.unlink(tmp_audio_path)
except Exception as e:
logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
return result
else:
return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)
except Exception as e:
logger.error(f"Speaker separation failed: {traceback.format_exc()}")
raise RuntimeError(f"Audio separation error: {e}")
def _separate_speakers_internal(
self,
audio_path: str,
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""Internal method: execute speaker separation"""
# Load audio
waveform, original_sr = torchaudio.load(audio_path)
if original_sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Ensure waveform is float32 and normalized (pyannote expects this format)
if waveform.dtype != torch.float32:
waveform = waveform.float()
# Ensure waveform is in range [-1, 1] (normalize if needed)
if waveform.abs().max() > 1.0:
waveform = waveform / waveform.abs().max()
if self.pipeline is None:
raise RuntimeError("Pyannote pipeline not initialized")
return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)
def _separate_with_pyannote(
self,
audio_path: str,
waveform: torch.Tensor,
num_speakers: Optional[int],
min_speakers: int,
max_speakers: int,
) -> Dict:
"""Use pyannote.audio for speaker diarization"""
try:
# Use waveform dict to avoid AudioDecoder dependency issues
# Pipeline can accept either file path or waveform dict
# Using waveform dict is more reliable when torchcodec is not properly installed
audio_input = {
"waveform": waveform,
"sample_rate": self.sample_rate,
}
# Run speaker diarization
output = self.pipeline(
audio_input,
min_speakers=min_speakers if num_speakers is None else num_speakers,
max_speakers=max_speakers if num_speakers is None else num_speakers,
)
# Extract audio segments for each speaker
speakers_dict = defaultdict(list)
for turn, speaker in output.speaker_diarization:
print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
start_time = turn.start
end_time = turn.end
start_sample = int(start_time * self.sample_rate)
end_sample = int(end_time * self.sample_rate)
# Extract audio segment for this time period
segment_audio = waveform[:, start_sample:end_sample]
speakers_dict[speaker].append((start_time, end_time, segment_audio))
# Generate complete audio for each speaker (other speakers' segments are empty)
speakers = []
audio_duration = waveform.shape[1] / self.sample_rate
num_samples = waveform.shape[1]
for speaker_id, segments in speakers_dict.items():
# Create zero-filled audio
speaker_audio = torch.zeros_like(waveform)
# Fill in this speaker's segments
for start_time, end_time, segment_audio in segments:
start_sample = int(start_time * self.sample_rate)
end_sample = int(end_time * self.sample_rate)
# Ensure no out-of-bounds
end_sample = min(end_sample, num_samples)
segment_len = end_sample - start_sample
if segment_len > 0 and segment_audio.shape[1] > 0:
actual_len = min(segment_len, segment_audio.shape[1])
speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len]
speakers.append(
{
"speaker_id": speaker_id,
"audio": speaker_audio,
"segments": [(s[0], s[1]) for s in segments],
"sample_rate": self.sample_rate,
}
)
logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
return {"speakers": speakers, "method": "pyannote"}
except Exception as e:
logger.error(f"Pyannote separation failed: {e}")
raise RuntimeError(f"Audio separation failed: {e}")
def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None):
"""
Save speaker audio to file
Args:
speaker_audio: Audio tensor [channels, samples]
output_path: Output path
sample_rate: Sample rate, if None uses self.sample_rate
"""
sr = sample_rate if sample_rate else self.sample_rate
torchaudio.save(output_path, speaker_audio, sr)
logger.info(f"Saved speaker audio to {output_path}")
def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str:
"""
Convert speaker audio tensor to base64 encoded string without saving to file
Args:
speaker_audio: Audio tensor [channels, samples]
sample_rate: Sample rate, if None uses self.sample_rate
format: Audio format (default: "wav")
Returns:
Base64 encoded audio string
"""
sr = sample_rate if sample_rate else self.sample_rate
# Use BytesIO to save audio to memory
buffer = io.BytesIO()
torchaudio.save(buffer, speaker_audio, sr, format=format)
# Get the audio bytes
audio_bytes = buffer.getvalue()
# Encode to base64
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
return audio_base64
def separate_and_save(
self,
audio_path: Union[str, bytes],
output_dir: str,
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""
Separate audio and save to files
Args:
audio_path: Input audio path or bytes data
output_dir: Output directory
num_speakers: Specified number of speakers
min_speakers: Minimum number of speakers
max_speakers: Maximum number of speakers
Returns:
Separation result dictionary, containing output file paths
"""
os.makedirs(output_dir, exist_ok=True)
result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)
output_paths = []
for speaker in result["speakers"]:
speaker_id = speaker["speaker_id"]
output_path = os.path.join(output_dir, f"{speaker_id}.wav")
self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
output_paths.append(output_path)
speaker["output_path"] = output_path
result["output_paths"] = output_paths
return result
def separate_audio_tracks(
audio_path: str,
output_dir: str = None,
num_speakers: int = None,
model_path: str = None,
) -> Dict:
"""
Convenience function: separate different audio tracks
Args:
audio_path: Audio file path
output_dir: Output directory, if None does not save files
num_speakers: Number of speakers
model_path: Model path (optional)
Returns:
Separation result dictionary
"""
separator = AudioSeparator(model_path=model_path)
if output_dir:
return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
else:
return separator.separate_speakers(audio_path, num_speakers=num_speakers)
if __name__ == "__main__":
# Test code
import sys
if len(sys.argv) < 2:
print("Usage: python audio_separator.py <audio_path> [output_dir] [num_speakers]")
sys.exit(1)
audio_path = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None
separator = AudioSeparator()
result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
print(f"Separated audio into {len(result['speakers'])} speakers:")
for speaker in result["speakers"]:
print(f" Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
if "output_path" in speaker:
print(f" Saved to: {speaker['output_path']}")
import io
import os
import traceback
from typing import Dict, List, Union
import numpy as np
import torch
from PIL import Image, ImageDraw
from loguru import logger
from ultralytics import YOLO
# Try to import transformers for Grounding DINO
try:
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
logger.warning("transformers not available, Grounding DINO method will not work")
class FaceDetector:
"""
Face detection using multiple methods
    Supports two detection methods:
1. YOLO World (method='yolo'):
- Open-vocabulary detection
- Supports various face types: human, animal, anime, sketch
- More flexible but slower
- Can detect custom classes via text description
2. Grounding DINO (method='grounding'):
- Open-vocabulary object detection
- Supports various face types via text prompts
- Requires transformers library
- Good balance between accuracy and flexibility
"""
def __init__(
self,
method: str = "yolo",
model_path: str = None,
conf_threshold: float = None,
device: str = None,
custom_classes: List[str] = None,
cascade_path: str = None,
):
"""
Initialize face detector
Args:
method: Detection method. Options:
- "yolo": Use YOLO World (supports various face types)
- "grounding": Use Grounding DINO (requires transformers)
Default: "yolo"
            model_path: Model path. For method="yolo" a YOLO World weights file;
                for method="grounding" a HuggingFace model id. None uses a default model.
            conf_threshold: Confidence threshold. If None, an adaptive default is chosen
                based on the number of classes.
            device: Device for inference ('cpu', 'cuda', '0', '1', etc.), None for auto
            custom_classes: List of class names describing the faces to detect.
                Default: ["human face", "animal face", "anime face", "sketch face"]
                Examples: ["face"], ["animal face"], ["human face", "animal face"]
            cascade_path: Accepted for compatibility but currently unused
"""
self.method = method.lower()
self.device = device
if self.method == "yolo":
# Initialize YOLO World detector
            # Set custom classes (default covers human, animal, anime and sketch faces)
if custom_classes is None:
custom_classes = ["human face", "animal face", "anime face", "sketch face"]
self.custom_classes = custom_classes
            # Adaptive confidence threshold based on class specificity
            if conf_threshold is None:
                if len(custom_classes) > 1:
                    # Multiple classes: use a lower threshold to catch all detections
                    conf_threshold = 0.1
                elif len(custom_classes) == 1:
                    # Single class (generic "face" or a specific one such as "animal face"):
                    # a moderate threshold works well
                    conf_threshold = 0.15
                else:
                    conf_threshold = 0.25
self.conf_threshold = conf_threshold
if model_path is None:
# Use YOLO World model for open-vocabulary detection
logger.info("Loading YOLO World model for face detection")
try:
# Try to load YOLO World small model first (lighter and faster)
self.model = YOLO("yolov8s-world.pt")
except Exception as e:
logger.warning(f"Failed to load yolov8s-world.pt, trying yolov8m-world.pt: {e}")
try:
self.model = YOLO("yolov8m-world.pt")
except Exception as e2:
logger.warning(f"Failed to load yolov8m-world.pt, trying yolov8l-world.pt: {e2}")
self.model = YOLO("yolov8l-world.pt")
# Set custom classes for YOLO World
# YOLO World can detect any object described in natural language
self.model.set_classes(self.custom_classes)
else:
logger.info(f"Loading YOLO World model from {model_path}")
self.model = YOLO(model_path)
logger.info(f"Face detector initialized with YOLO World, custom classes: {self.custom_classes}, confidence threshold: {self.conf_threshold}")
self.face_cascade = None
elif self.method == "grounding":
# Initialize Grounding DINO detector
if not TRANSFORMERS_AVAILABLE:
raise ImportError("transformers library is required for Grounding DINO. Install it with: pip install transformers torch")
# Set up proxy for HuggingFace model download
# Check if proxy is already set, if not try to use common proxy settings
if not os.environ.get("HTTP_PROXY") and not os.environ.get("http_proxy"):
# Try to use HTTPS_PROXY for HTTP requests as well if available
https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
if https_proxy:
os.environ["HTTP_PROXY"] = https_proxy
os.environ["http_proxy"] = https_proxy
logger.info(f"Using proxy from HTTPS_PROXY: {https_proxy}")
# Log proxy settings
http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
if http_proxy or https_proxy:
logger.info(f"Using proxy - HTTP: {http_proxy}, HTTPS: {https_proxy}")
            # Set custom classes (default covers human, animal, anime and sketch faces)
if custom_classes is None:
custom_classes = ["human face", "animal face", "anime face", "sketch face"]
self.custom_classes = custom_classes
# Adaptive confidence threshold
if conf_threshold is None:
if len(custom_classes) > 1:
conf_threshold = 0.1
else:
conf_threshold = 0.3 # Grounding DINO typically needs higher threshold
self.conf_threshold = conf_threshold
# Load Grounding DINO model
model_id = "IDEA-Research/grounding-dino-base" # or "grounding-dino-tiny" for faster inference
if model_path is not None:
model_id = model_path
logger.info(f"Loading Grounding DINO model: {model_id}")
try:
# Grounding DINO requires trust_remote_code=True
self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id, trust_remote_code=True)
if device:
self.model = self.model.to(device)
logger.info(f"Face detector initialized with Grounding DINO, custom classes: {self.custom_classes}, confidence threshold: {self.conf_threshold}")
except Exception as e:
error_msg = str(e)
if "connection" in error_msg.lower() or "proxy" in error_msg.lower() or "network" in error_msg.lower():
logger.error(f"Failed to download model. Please check your network connection and proxy settings.")
logger.error(f"Current proxy settings - HTTP_PROXY: {http_proxy}, HTTPS_PROXY: {https_proxy}")
logger.error("You can set proxy with: export http_proxy=... && export https_proxy=...")
raise
self.face_cascade = None
else:
raise ValueError(f"Unknown method: {method}. Must be 'yolo', or 'grounding'")
def detect_faces(
self,
image: Union[str, Image.Image, bytes, np.ndarray],
return_image: bool = False,
) -> Dict:
"""
Detect faces in image
Args:
image: Input image, can be path, PIL Image, bytes or numpy array
return_image: Whether to return annotated image with detection boxes
Returns:
Dict containing:
- faces: List of face detection results, each containing:
- bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates)
- confidence: Confidence score (0.0-1.0)
- class_id: Class ID
- class_name: Class name
- face_type: Type of face detected
- image (optional): PIL Image with detection boxes drawn (if return_image=True)
"""
try:
if self.method == "grounding":
return self._detect_faces_grounding(image, return_image)
elif self.method == "yolo":
return self._detect_faces_yolo(image, return_image)
except Exception as e:
logger.error(f"Face detection failed: {traceback.format_exc()}")
raise RuntimeError(f"Face detection error: {e}")
def _detect_faces_yolo(
self,
image: Union[str, Image.Image, bytes, np.ndarray],
return_image: bool = False,
) -> Dict:
"""Detect faces using YOLO World"""
# Load image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, np.ndarray):
img = Image.fromarray(image).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
# Use YOLO World for open-vocabulary detection
# YOLO World detects objects based on the custom classes set via set_classes()
results = self.model.predict(
source=img,
conf=self.conf_threshold,
device=self.device,
verbose=False,
)
faces = []
annotated_img = img.copy() if return_image else None
if len(results) > 0:
result = results[0]
boxes = result.boxes
if boxes is not None and len(boxes) > 0:
for i in range(len(boxes)):
# Get bounding box coordinates (xyxy format)
bbox = boxes.xyxy[i].cpu().numpy().tolist()
confidence = float(boxes.conf[i].cpu().numpy())
class_id = int(boxes.cls[i].cpu().numpy())
# Get class name from custom classes list
# YOLO World returns class_id that corresponds to index in custom_classes
if class_id < len(self.custom_classes):
class_name = self.custom_classes[class_id]
else:
class_name = result.names.get(class_id, "unknown")
# Determine face type based on class name
# For "face" class, it can detect all types of faces
if class_name.lower() == "face":
face_type = "face" # Generic face type (can be human, animal, anime, etc.)
elif any(keyword in class_name.lower() for keyword in ["human", "person"]):
face_type = "human"
elif any(keyword in class_name.lower() for keyword in ["animal", "cat", "dog", "bird", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"]):
face_type = "animal"
elif any(keyword in class_name.lower() for keyword in ["anime", "cartoon", "manga"]):
face_type = "anime"
elif any(keyword in class_name.lower() for keyword in ["sketch", "line", "drawing"]):
face_type = "sketch"
else:
logger.debug(f"Dropped unused detected result: {class_name}")
face_type = None
face_info = {
"bbox": bbox, # [x1, y1, x2, y2] - absolute pixel coordinates
"confidence": confidence,
"class_id": class_id,
"class_name": class_name,
"face_type": face_type,
}
if face_type is not None:
faces.append(face_info)
# Draw annotations on image if needed
if return_image and annotated_img is not None:
draw = ImageDraw.Draw(annotated_img)
x1, y1, x2, y2 = bbox
# Draw bounding box
draw.rectangle(
[x1, y1, x2, y2],
outline="red",
width=2,
)
# Draw label
label = f"{class_name} {confidence:.2f}"
draw.text((x1, y1 - 15), label, fill="red")
result_dict = {"faces": faces}
if return_image and annotated_img is not None:
result_dict["image"] = annotated_img
logger.info(f"Detected {len(faces)} faces using YOLO World")
return result_dict
def _calculate_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
"""
Calculate Intersection over Union (IoU) between two bounding boxes
Args:
bbox1: [x1, y1, x2, y2] format
bbox2: [x1, y1, x2, y2] format
Returns:
IoU value between 0 and 1
"""
x1_1, y1_1, x2_1, y2_1 = bbox1
x1_2, y1_2, x2_2, y2_2 = bbox2
# Calculate intersection area
inter_x1 = max(x1_1, x1_2)
inter_y1 = max(y1_1, y1_2)
inter_x2 = min(x2_1, x2_2)
inter_y2 = min(y2_1, y2_2)
if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
return 0.0
inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
# Calculate union area
area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
union_area = area1 + area2 - inter_area
if union_area == 0:
return 0.0
return inter_area / union_area
def _calculate_bbox_area(self, bbox: List[float]) -> float:
"""Calculate the area of a bounding box"""
x1, y1, x2, y2 = bbox
return (x2 - x1) * (y2 - y1)
def _calculate_containment(self, bbox_small: List[float], bbox_large: List[float]) -> float:
"""
Calculate how much of bbox_small is contained in bbox_large
Returns the ratio of intersection area to bbox_small area
"""
x1_s, y1_s, x2_s, y2_s = bbox_small
x1_l, y1_l, x2_l, y2_l = bbox_large
# Calculate intersection
inter_x1 = max(x1_s, x1_l)
inter_y1 = max(y1_s, y1_l)
inter_x2 = min(x2_s, x2_l)
inter_y2 = min(y2_s, y2_l)
if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
return 0.0
inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
small_area = (x2_s - x1_s) * (y2_s - y1_s)
if small_area == 0:
return 0.0
return inter_area / small_area
def _apply_nms(self, faces: List[Dict], iou_threshold: float = 0.4, containment_threshold: float = 0.6) -> List[Dict]:
"""
Apply Non-Maximum Suppression (NMS) to remove duplicate detections.
When detections overlap, keeps the one with larger area (preferring whole objects over parts).
Args:
faces: List of face detection dictionaries
iou_threshold: IoU threshold for considering detections as duplicates
containment_threshold: If a smaller box is contained in a larger box by this ratio, suppress it
Returns:
Filtered list of faces with duplicates removed
"""
if len(faces) == 0:
return faces
# Sort by area (largest first), then by confidence as tie-breaker
# This ensures we keep the larger detection (whole object) over smaller ones (parts)
for face in faces:
face["_area"] = self._calculate_bbox_area(face["bbox"])
sorted_faces = sorted(faces, key=lambda x: (x["_area"], x["confidence"]), reverse=True)
keep = []
suppressed = set()
for i, face in enumerate(sorted_faces):
if i in suppressed:
continue
keep.append(face)
bbox_i = face["bbox"]
area_i = face["_area"]
# Suppress overlapping detections (prefer larger area)
for j in range(i + 1, len(sorted_faces)):
if j in suppressed:
continue
bbox_j = sorted_faces[j]["bbox"]
area_j = sorted_faces[j]["_area"]
# Check IoU overlap
iou = self._calculate_iou(bbox_i, bbox_j)
if iou > iou_threshold:
# If overlapping, suppress the smaller one
suppressed.add(j)
continue
# Check if smaller box is mostly contained in larger box
# (e.g., face is contained in whole animal body)
# Since we sorted by area, area_i >= area_j for j > i
if area_j < area_i:
containment = self._calculate_containment(bbox_j, bbox_i)
if containment > containment_threshold:
suppressed.add(j)
# Clean up temporary area field
for face in keep:
face.pop("_area", None)
logger.info(f"NMS filtered {len(faces)} detections to {len(keep)} (IoU threshold: {iou_threshold}, containment threshold: {containment_threshold}, prefer larger area)")
return keep
def _detect_faces_grounding(
self,
image: Union[str, Image.Image, bytes, np.ndarray],
return_image: bool = False,
) -> Dict:
"""Detect faces using Grounding DINO"""
# Load image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, np.ndarray):
img = Image.fromarray(image).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
# Prepare text prompt - join custom classes with ". " separator
text_prompt = ". ".join(self.custom_classes)
if not text_prompt.endswith("."):
text_prompt += "."
# Process image and text
inputs = self.processor(images=img, text=text_prompt, return_tensors="pt")
if self.device:
inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
# Run inference
with torch.no_grad():
outputs = self.model(**inputs)
# Post-process results
# Note: Grounding DINO uses 'threshold' instead of 'box_threshold'
results = self.processor.post_process_grounded_object_detection(
outputs,
input_ids=inputs["input_ids"],
threshold=self.conf_threshold,
text_threshold=self.conf_threshold,
target_sizes=[img.size[::-1]], # [height, width]
)
faces = []
annotated_img = img.copy() if return_image else None
if len(results) > 0:
result = results[0]
# Get detections
# Use text_labels instead of labels to avoid FutureWarning
boxes = result.get("boxes", [])
text_labels = result.get("text_labels", [])
# Fallback to labels if text_labels not available
if not text_labels:
labels = result.get("labels", [])
# Convert label IDs to class names if needed
text_labels = [self.custom_classes[label] if isinstance(label, int) and label < len(self.custom_classes) else str(label) for label in labels]
scores = result.get("scores", [])
for i, (box, label, score) in enumerate(zip(boxes, text_labels, scores)):
# Grounding DINO returns boxes as [x1, y1, x2, y2]
if isinstance(box, torch.Tensor):
bbox = box.tolist()
else:
bbox = list(box)
# Ensure it's [x1, y1, x2, y2] format
if len(bbox) == 4:
bbox = [float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])]
else:
# If it's in center format, convert
x_center, y_center, width, height = bbox
x1 = x_center - width / 2
y1 = y_center - height / 2
x2 = x_center + width / 2
y2 = y_center + height / 2
bbox = [float(x1), float(y1), float(x2), float(y2)]
# Get class name from label
# Grounding DINO may return multiple class names concatenated
class_name_raw = label if isinstance(label, str) else self.custom_classes[0]
# If class_name contains multiple classes, try to extract the most specific one
# Priority: specific classes (animal, anime, sketch) > human > generic face
class_name = class_name_raw
if isinstance(class_name_raw, str) and len(self.custom_classes) > 1:
class_name_lower = class_name_raw.lower()
# Check for specific classes first
if any(keyword in class_name_lower for keyword in ["animal"]):
for c in self.custom_classes:
if "animal" in c.lower():
class_name = c
break
elif any(keyword in class_name_lower for keyword in ["anime", "cartoon"]):
for c in self.custom_classes:
if any(k in c.lower() for k in ["anime", "cartoon"]):
class_name = c
break
elif any(keyword in class_name_lower for keyword in ["sketch", "line", "drawing"]):
for c in self.custom_classes:
if any(k in c.lower() for k in ["sketch", "line", "drawing"]):
class_name = c
break
elif any(keyword in class_name_lower for keyword in ["human", "person"]):
for c in self.custom_classes:
if any(k in c.lower() for k in ["human", "person"]):
class_name = c
break
# Determine face type based on class name
if class_name.lower() == "face":
face_type = "face"
elif any(keyword in class_name.lower() for keyword in ["human", "person"]):
face_type = "human"
elif any(keyword in class_name.lower() for keyword in ["animal", "cat", "dog", "bird"]):
face_type = "animal"
elif any(keyword in class_name.lower() for keyword in ["anime", "cartoon", "manga"]):
face_type = "anime"
elif any(keyword in class_name.lower() for keyword in ["sketch", "line", "drawing"]):
face_type = "sketch"
else:
face_type = class_name.lower()
face_info = {
"bbox": bbox,
"confidence": float(score),
"class_id": i,
"class_name": class_name,
"face_type": face_type,
}
faces.append(face_info)
# Draw annotations if needed
if return_image and annotated_img is not None:
draw = ImageDraw.Draw(annotated_img)
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
label = f"{class_name} {score:.2f}"
draw.text((x1, y1 - 15), label, fill="red")
# Apply NMS to remove duplicate detections (only when multiple classes are specified)
if len(self.custom_classes) > 1:
faces = self._apply_nms(faces, iou_threshold=0.4, containment_threshold=0.6)
# Re-draw annotations after NMS if needed
if return_image and annotated_img is not None and len(faces) > 0:
annotated_img = img.copy()
draw = ImageDraw.Draw(annotated_img)
for face in faces:
x1, y1, x2, y2 = face["bbox"]
draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
label = f"{face['class_name']} {face['confidence']:.2f}"
draw.text((x1, y1 - 15), label, fill="red")
result_dict = {"faces": faces}
if return_image and annotated_img is not None:
result_dict["image"] = annotated_img
logger.info(f"Detected {len(faces)} faces using Grounding DINO (after NMS)")
return result_dict
def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict:
"""
Detect faces from byte data
Args:
image_bytes: Image byte data
**kwargs: Additional parameters passed to detect_faces
Returns:
Detection result dictionary
"""
return self.detect_faces(image_bytes, **kwargs)
def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]:
"""
Extract detected face regions
Args:
image: Input image
expand_ratio: Bounding box expansion ratio to include more context
Returns:
List of extracted face region images
"""
result = self.detect_faces(image)
faces = result["faces"]
# Load original image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
face_regions = []
img_width, img_height = img.size
for face in faces:
x1, y1, x2, y2 = face["bbox"]
# Expand bounding box
width = x2 - x1
height = y2 - y1
expand_x = width * expand_ratio
expand_y = height * expand_ratio
x1 = max(0, int(x1 - expand_x))
y1 = max(0, int(y1 - expand_y))
x2 = min(img_width, int(x2 + expand_x))
y2 = min(img_height, int(y2 + expand_y))
# Crop region
face_region = img.crop((x1, y1, x2, y2))
face_regions.append(face_region)
return face_regions
def count_faces(self, image: Union[str, Image.Image, bytes]) -> int:
"""
Count number of faces in image
Args:
image: Input image
Returns:
Number of detected faces
"""
result = self.detect_faces(image, return_image=False)
return len(result["faces"])
def detect_faces_in_image(
image_path: str,
method: str = "grounding",
model_path: str = None,
conf_threshold: float = None,
return_image: bool = False,
custom_classes: List[str] = None,
) -> Dict:
"""
Convenience function: detect faces in image
Args:
image_path: Image path
        method: Detection method ("yolo" or "grounding"), default "grounding"
        model_path: Model path (YOLO World weights or a Grounding DINO model id)
        conf_threshold: Confidence threshold (None for an adaptive default)
        return_image: Whether to return annotated image
        custom_classes: List of custom class names
            (default: ["human face", "animal face", "anime face", "sketch face"])
Returns:
Detection result dictionary containing:
- faces: List of face detection results with bbox coordinates [x1, y1, x2, y2]
Each face contains: bbox, confidence, class_id, class_name, face_type
- image (optional): Annotated image with detection boxes
Examples:
# Detect faces using YOLO World with default "face" class
result = detect_faces_in_image("image.jpg", method="yolo")
# Detect with YOLO World and custom classes
result = detect_faces_in_image("image.jpg", method="yolo",
custom_classes=["human face", "animal face"])
# Detect with Grounding DINO
result = detect_faces_in_image("image.jpg", method="grounding",
custom_classes=["animal face"])
"""
detector = FaceDetector(
method=method,
model_path=model_path,
conf_threshold=conf_threshold,
custom_classes=custom_classes,
)
return detector.detect_faces(image_path, return_image=return_image)
if __name__ == "__main__":
# Test code
import sys
if len(sys.argv) < 2:
print("Usage: python face_detector.py <image_path>")
sys.exit(1)
image_path = sys.argv[1]
detector = FaceDetector()
result = detector.detect_faces(image_path, return_image=True)
print(f"Detected {len(result['faces'])} faces:")
for i, face in enumerate(result["faces"]):
print(f" Face {i + 1}: {face}")
output_path = "detected_faces.png"
result["image"].save(output_path)
print(f"Annotated image saved to: {output_path}")
import json
import sys
from loguru import logger
class Pipeline:
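    # Pipeline config parser. Expected JSON shape, inferred from the parsing logic
    # below (illustrative only; field names are the ones this file reads):
    #
    #   {
    #     "data": {
    #       "<task>": {"<model_cls>": {"<stage>": {
    #           "<worker_name>": {"inputs": [...], "outputs": [...], "queue": "optional"}
    #       }}}
    #     },
    #     "meta": {
    #       "special_types": {"<name>": "<TYPE>"},
    #       "monitor": {...},
    #       "model_name_inner_to_outer": {...},
    #       "model_name_outer_to_inner": {...}
    #     }
    #   }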
def __init__(self, pipeline_json_file):
self.pipeline_json_file = pipeline_json_file
        with open(pipeline_json_file) as f:
            x = json.load(f)
self.data = x["data"]
self.meta = x["meta"]
self.inputs = {}
self.outputs = {}
self.temps = {}
self.model_lists = []
self.types = {}
self.queues = set()
self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
self.tidy_pipeline()
def init_dict(self, base, task, model_cls):
if task not in base:
base[task] = {}
if model_cls not in base[task]:
base[task][model_cls] = {}
    # Tidy one task item, e.g. ['t2v', 'wan2.1', 'multi_stage']
def tidy_task(self, task, model_cls, stage, v3):
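        # Classify names for this stage: a name produced by an earlier worker and fully
        # consumed inside the stage becomes a temp; a name consumed but never produced
        # here is a stage input; produced names that still have unconsumed uses at the
        # end are the stage outputs.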
out2worker = {}
out2num = {}
cur_inps = set()
cur_temps = set()
cur_types = {}
for worker_name, worker_item in v3.items():
prevs = []
for inp in worker_item["inputs"]:
cur_types[inp] = self.get_type(inp)
if inp in out2worker:
prevs.append(out2worker[inp])
out2num[inp] -= 1
if out2num[inp] <= 0:
cur_temps.add(inp)
else:
cur_inps.add(inp)
worker_item["previous"] = prevs
for out in worker_item["outputs"]:
cur_types[out] = self.get_type(out)
out2worker[out] = worker_name
if out not in out2num:
out2num[out] = 0
out2num[out] += 1
if "queue" not in worker_item:
worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
self.queues.add(worker_item["queue"])
cur_outs = [out for out, num in out2num.items() if num > 0]
self.inputs[task][model_cls][stage] = list(cur_inps)
self.outputs[task][model_cls][stage] = cur_outs
self.temps[task][model_cls][stage] = list(cur_temps)
self.types[task][model_cls][stage] = cur_types
    # Tidy previous (dependency) workers and queue names
def tidy_pipeline(self):
for task, v1 in self.data.items():
for model_cls, v2 in v1.items():
for stage, v3 in v2.items():
self.init_dict(self.inputs, task, model_cls)
self.init_dict(self.outputs, task, model_cls)
self.init_dict(self.temps, task, model_cls)
self.init_dict(self.types, task, model_cls)
self.tidy_task(task, model_cls, stage, v3)
self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
logger.info(f"inputs: {self.inputs}")
logger.info(f"outputs: {self.outputs}")
logger.info(f"temps: {self.temps}")
logger.info(f"types: {self.types}")
logger.info(f"model_lists: {self.model_lists}")
logger.info(f"queues: {self.queues}")
def get_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
item = item[k]
return item
    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
def get_worker(self, keys):
return self.get_item_by_keys(keys)
    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_workers(self, keys):
return self.get_item_by_keys(keys)
    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_inputs(self, keys):
item = self.inputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in inputs!")
item = item[k]
return item
    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_outputs(self, keys):
item = self.outputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in outputs!")
item = item[k]
return item
    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_temps(self, keys):
item = self.temps
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in temps!")
item = item[k]
return item
    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_types(self, keys):
item = self.types
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in types!")
item = item[k]
return item
def check_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
return False
item = item[k]
return True
def get_model_lists(self):
return self.model_lists
def get_type(self, name):
return self.meta["special_types"].get(name, "OBJECT")
def get_monitor_config(self):
return self.meta["monitor"]
def get_queues(self):
return self.queues
def inner_model_name(self, name):
return self.model_name_outer_to_inner.get(name, name)
def outer_model_name(self, name):
return self.model_name_inner_to_outer.get(name, name)
if __name__ == "__main__":
pipeline = Pipeline(sys.argv[1])
print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"]))
print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"]))
# -*- coding: utf-8 -*-
import asyncio
import io
import json
import os
import struct
import uuid
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, List, Optional
import websockets
from loguru import logger
from pydub import AudioSegment
# Protocol definitions (from podcasts_protocols)
class MsgType(IntEnum):
"""Message type enumeration"""
Invalid = 0
FullClientRequest = 0b1
AudioOnlyClient = 0b10
FullServerResponse = 0b1001
AudioOnlyServer = 0b1011
FrontEndResultServer = 0b1100
Error = 0b1111
ServerACK = AudioOnlyServer
class MsgTypeFlagBits(IntEnum):
"""Message type flag bits"""
NoSeq = 0
PositiveSeq = 0b1
LastNoSeq = 0b10
NegativeSeq = 0b11
WithEvent = 0b100
class VersionBits(IntEnum):
"""Version bits"""
Version1 = 1
class HeaderSizeBits(IntEnum):
"""Header size bits"""
HeaderSize4 = 1
HeaderSize8 = 2
HeaderSize12 = 3
HeaderSize16 = 4
class SerializationBits(IntEnum):
"""Serialization method bits"""
Raw = 0
JSON = 0b1
Thrift = 0b11
Custom = 0b1111
class CompressionBits(IntEnum):
"""Compression method bits"""
None_ = 0
Gzip = 0b1
Custom = 0b1111
class EventType(IntEnum):
"""Event type enumeration"""
None_ = 0
StartConnection = 1
StartTask = 1
FinishConnection = 2
FinishTask = 2
ConnectionStarted = 50
TaskStarted = 50
ConnectionFailed = 51
TaskFailed = 51
ConnectionFinished = 52
TaskFinished = 52
StartSession = 100
CancelSession = 101
FinishSession = 102
SessionStarted = 150
SessionCanceled = 151
SessionFinished = 152
SessionFailed = 153
UsageResponse = 154
ChargeData = 154
TaskRequest = 200
UpdateConfig = 201
AudioMuted = 250
SayHello = 300
TTSSentenceStart = 350
TTSSentenceEnd = 351
TTSResponse = 352
TTSEnded = 359
PodcastRoundStart = 360
PodcastRoundResponse = 361
PodcastRoundEnd = 362
PodcastEnd = 363
@dataclass
class Message:
"""Message object"""
version: VersionBits = VersionBits.Version1
header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
type: MsgType = MsgType.Invalid
flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
serialization: SerializationBits = SerializationBits.JSON
compression: CompressionBits = CompressionBits.None_
event: EventType = EventType.None_
session_id: str = ""
connect_id: str = ""
sequence: int = 0
error_code: int = 0
payload: bytes = b""
@classmethod
def from_bytes(cls, data: bytes) -> "Message":
"""Create message object from bytes"""
if len(data) < 3:
raise ValueError(f"Data too short: expected at least 3 bytes, got {len(data)}")
type_and_flag = data[1]
msg_type = MsgType(type_and_flag >> 4)
flag = MsgTypeFlagBits(type_and_flag & 0b00001111)
msg = cls(type=msg_type, flag=flag)
msg.unmarshal(data)
return msg
def marshal(self) -> bytes:
"""Serialize message to bytes"""
buffer = io.BytesIO()
header = [
(self.version << 4) | self.header_size,
(self.type << 4) | self.flag,
(self.serialization << 4) | self.compression,
]
header_size = 4 * self.header_size
if padding := header_size - len(header):
header.extend([0] * padding)
buffer.write(bytes(header))
writers = self._get_writers()
for writer in writers:
writer(buffer)
return buffer.getvalue()
def unmarshal(self, data: bytes) -> None:
"""Deserialize message from bytes"""
buffer = io.BytesIO(data)
version_and_header_size = buffer.read(1)[0]
self.version = VersionBits(version_and_header_size >> 4)
self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)
buffer.read(1)
serialization_compression = buffer.read(1)[0]
self.serialization = SerializationBits(serialization_compression >> 4)
self.compression = CompressionBits(serialization_compression & 0b00001111)
header_size = 4 * self.header_size
read_size = 3
if padding_size := header_size - read_size:
buffer.read(padding_size)
readers = self._get_readers()
for reader in readers:
reader(buffer)
remaining = buffer.read()
if remaining:
raise ValueError(f"Unexpected data after message: {remaining}")
def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
"""Get list of writer functions"""
writers = []
if self.flag == MsgTypeFlagBits.WithEvent:
writers.extend([self._write_event, self._write_session_id])
if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
writers.append(self._write_sequence)
elif self.type == MsgType.Error:
writers.append(self._write_error_code)
else:
raise ValueError(f"Unsupported message type: {self.type}")
writers.append(self._write_payload)
return writers
def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
"""Get list of reader functions"""
readers = []
if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
readers.append(self._read_sequence)
elif self.type == MsgType.Error:
readers.append(self._read_error_code)
if self.flag == MsgTypeFlagBits.WithEvent:
readers.extend([self._read_event, self._read_session_id, self._read_connect_id])
readers.append(self._read_payload)
return readers
def _write_event(self, buffer: io.BytesIO) -> None:
buffer.write(struct.pack(">i", self.event))
def _write_session_id(self, buffer: io.BytesIO) -> None:
if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed]:
return
session_id_bytes = self.session_id.encode("utf-8")
size = len(session_id_bytes)
if size > 0xFFFFFFFF:
raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")
buffer.write(struct.pack(">I", size))
if size > 0:
buffer.write(session_id_bytes)
def _write_sequence(self, buffer: io.BytesIO) -> None:
buffer.write(struct.pack(">i", self.sequence))
def _write_error_code(self, buffer: io.BytesIO) -> None:
buffer.write(struct.pack(">I", self.error_code))
def _write_payload(self, buffer: io.BytesIO) -> None:
size = len(self.payload)
if size > 0xFFFFFFFF:
raise ValueError(f"Payload size ({size}) exceeds max(uint32)")
buffer.write(struct.pack(">I", size))
buffer.write(self.payload)
def _read_event(self, buffer: io.BytesIO) -> None:
event_bytes = buffer.read(4)
if event_bytes:
self.event = EventType(struct.unpack(">i", event_bytes)[0])
def _read_session_id(self, buffer: io.BytesIO) -> None:
if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
return
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
session_id_bytes = buffer.read(size)
if len(session_id_bytes) == size:
self.session_id = session_id_bytes.decode("utf-8")
def _read_connect_id(self, buffer: io.BytesIO) -> None:
if self.event in [EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
self.connect_id = buffer.read(size).decode("utf-8")
def _read_sequence(self, buffer: io.BytesIO) -> None:
sequence_bytes = buffer.read(4)
if sequence_bytes:
self.sequence = struct.unpack(">i", sequence_bytes)[0]
def _read_error_code(self, buffer: io.BytesIO) -> None:
error_code_bytes = buffer.read(4)
if error_code_bytes:
self.error_code = struct.unpack(">I", error_code_bytes)[0]
def _read_payload(self, buffer: io.BytesIO) -> None:
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
self.payload = buffer.read(size)
async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
"""Receive message from websocket"""
try:
data = await websocket.recv()
if isinstance(data, str):
raise ValueError(f"Unexpected text message: {data}")
elif isinstance(data, bytes):
msg = Message.from_bytes(data)
# logger.debug(f"Received: {msg}")
return msg
else:
raise ValueError(f"Unexpected message type: {type(data)}")
except Exception as e:
logger.error(f"Failed to receive message: {e}")
raise
async def wait_for_event(websocket: websockets.WebSocketClientProtocol, msg_type: MsgType, event_type: EventType) -> Message:
    """Wait for a specific event; any other message is treated as an error"""
    msg = await receive_message(websocket)
    if msg.type != msg_type or msg.event != event_type:
        raise ValueError(f"Unexpected message: {msg}")
    return msg
async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
"""Start connection"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.StartConnection
msg.payload = b"{}"
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
"""Finish connection"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.FinishConnection
msg.payload = b"{}"
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def start_session(websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str) -> None:
"""Start session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.StartSession
msg.session_id = session_id
msg.payload = payload
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def finish_session(websocket: websockets.WebSocketClientProtocol, session_id: str) -> None:
"""Finish session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.FinishSession
msg.session_id = session_id
msg.payload = b"{}"
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
class PodcastRoundPostProcessor:
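    # Accumulates per-round MP3 segments into one merged audio track, keeps a running
    # subtitle timestamp list (start/end/text/speaker per round), and persists the
    # merged audio through the provided data_manager.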
def __init__(self, session_id, data_manager):
self.session_id = session_id
self.data_manager = data_manager
self.temp_merged_audio_name = "merged_audio.mp3"
self.output_merged_audio_name = f"{session_id}-merged_audio.mp3"
        self.subtitle_timestamps = []  # subtitle timestamps recorded per round
        self.current_audio_duration = 0.0  # running duration of the merged audio (seconds)
        self.merged_audio = None  # AudioSegment holding the merged audio
self.merged_audio_bytes = None
async def init(self):
if self.data_manager:
await self.data_manager.create_podcast_temp_session_dir(self.session_id)
async def postprocess_round(self, current_round, voice, audio, podcast_texts):
text = ""
if podcast_texts:
text = podcast_texts[-1].get("text", "")
logger.debug(f"Processing round: {current_round}, voice: {voice}, text: {text}, audio: {len(audio)} bytes")
new_segment = AudioSegment.from_mp3(io.BytesIO(bytes(audio)))
round_duration = len(new_segment) / 1000.0
if self.merged_audio is None:
self.merged_audio = new_segment
else:
self.merged_audio = self.merged_audio + new_segment
        # Save the merged audio to a temporary file (for real-time access by the frontend)
merged_io = io.BytesIO()
self.merged_audio.export(merged_io, format="mp3")
self.merged_audio_bytes = merged_io.getvalue()
if self.data_manager:
await self.data_manager.save_podcast_temp_session_file(self.session_id, self.temp_merged_audio_name, self.merged_audio_bytes)
merged_file_size = len(self.merged_audio_bytes)
        # Record the subtitle timestamp for this round
self.subtitle_timestamps.append(
{
"start": self.current_audio_duration,
"end": self.current_audio_duration + round_duration,
"text": text,
"speaker": voice,
}
)
self.current_audio_duration += round_duration
logger.debug(f"Merged audio updated: {merged_file_size} bytes, duration: {self.current_audio_duration:.2f}s")
return {
"url": f"/api/v1/podcast/audio?session_id={self.session_id}&filename={self.temp_merged_audio_name}",
"size": merged_file_size,
"duration": self.current_audio_duration,
"round": current_round,
"text": text,
"speaker": voice,
}
async def postprocess_final(self):
if self.data_manager:
await self.data_manager.save_podcast_output_file(self.output_merged_audio_name, self.merged_audio_bytes)
return {
"subtitles": self.subtitle_timestamps,
"audio_name": self.output_merged_audio_name,
}
async def cleanup(self):
if self.data_manager:
await self.data_manager.clear_podcast_temp_session_dir(self.session_id)
self.data_manager = None
class VolcEnginePodcastClient:
"""
    VolcEngine Podcast client
    Supports multiple podcast types:
    - action=0: text to podcast
    - action=3: NLP text to podcast
    - action=4: podcast generated from a prompt
"""
def __init__(self):
self.endpoint = "wss://openspeech.bytedance.com/api/v3/sami/podcasttts"
self.appid = os.getenv("VOLCENGINE_PODCAST_APPID")
self.access_token = os.getenv("VOLCENGINE_PODCAST_ACCESS_TOKEN")
self.app_key = "aGjiRDfUWi"
self.proxy = os.getenv("HTTPS_PROXY", None)
if self.proxy:
logger.info(f"volcengine podcast use proxy: {self.proxy}")
async def podcast_request(
self,
session_id: str,
data_manager=None,
text: str = "",
input_url: str = "",
prompt_text: str = "",
nlp_texts: str = "",
action: int = 0,
resource_id: str = "volc.service_type.10050",
encoding: str = "mp3",
input_id: str = "test_podcast",
speaker_info: str = '{"random_order":false}',
use_head_music: bool = False,
use_tail_music: bool = False,
only_nlp_text: bool = False,
return_audio_url: bool = False,
skip_round_audio_save: bool = False,
on_round_complete: Optional[Callable] = None,
):
"""
        Execute a podcast request
        Args:
            session_id: Session identifier used for temp/output file naming
            data_manager: Optional storage manager for temp and output files
            text: Input text (used when action=0)
            input_url: Web URL or file URL (used when action=0)
            prompt_text: Prompt text (required when action=4)
            nlp_texts: NLP texts (required when action=3)
            action: Podcast type (0/3/4)
            resource_id: Audio resource ID
            encoding: Audio format (mp3/wav)
            input_id: Unique input identifier
            speaker_info: Podcast speaker info
            use_head_music: Whether to add opening music
            use_tail_music: Whether to add closing music
            only_nlp_text: Whether to return only the podcast text
            return_audio_url: Whether to return an audio URL
            skip_round_audio_save: Whether to skip saving per-round audio
            on_round_complete: Callback invoked when a round finishes
"""
if not self.appid or not self.access_token:
logger.error("APP ID or Access Key is required")
return None, None
headers = {
"X-Api-App-Id": self.appid,
"X-Api-App-Key": self.app_key,
"X-Api-Access-Key": self.access_token,
"X-Api-Resource-Id": resource_id,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
is_podcast_round_end = True
audio_received = False
last_round_id = -1
task_id = ""
websocket = None
retry_num = 5
audio = bytearray()
voice = ""
current_round = 0
podcast_texts = []
post_processor = PodcastRoundPostProcessor(session_id, data_manager)
await post_processor.init()
try:
while retry_num > 0:
                # Establish the WebSocket connection
websocket = await websockets.connect(self.endpoint, additional_headers=headers)
logger.debug(f"WebSocket connected: {websocket.response.headers}")
                # Build request parameters (input_text is only included when no input_url is given)
                req_params = {
                    "input_id": input_id,
                    "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
                    "prompt_text": prompt_text,
                    "action": action,
                    "use_head_music": use_head_music,
                    "use_tail_music": use_tail_music,
                    "input_info": {
                        "input_url": input_url,
                        "return_audio_url": return_audio_url,
                        "only_nlp_text": only_nlp_text,
                    },
                    "speaker_info": json.loads(speaker_info) if speaker_info else None,
                    "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
                }
                if not input_url:
                    req_params["input_text"] = text
logger.debug(f"Request params: {json.dumps(req_params, indent=2, ensure_ascii=False)}")
if not is_podcast_round_end:
req_params["retry_info"] = {"retry_task_id": task_id, "last_finished_round_id": last_round_id}
# Start connection
await start_connection(websocket)
await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted)
session_id = str(uuid.uuid4())
if not task_id:
task_id = session_id
# Start session
await start_session(websocket, json.dumps(req_params).encode(), session_id)
await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted)
# Finish session
await finish_session(websocket, session_id)
while True:
msg = await receive_message(websocket)
                    # Audio data chunk
if msg.type == MsgType.AudioOnlyServer and msg.event == EventType.PodcastRoundResponse:
if not audio_received and audio:
audio_received = True
audio.extend(msg.payload)
                    # Error message
elif msg.type == MsgType.Error:
raise RuntimeError(f"Server error: {msg.payload.decode()}")
elif msg.type == MsgType.FullServerResponse:
                        # Podcast round start
if msg.event == EventType.PodcastRoundStart:
data = json.loads(msg.payload.decode())
if data.get("text"):
filtered_payload = {"text": data.get("text"), "speaker": data.get("speaker")}
podcast_texts.append(filtered_payload)
voice = data.get("speaker")
current_round = data.get("round_id")
if current_round == -1:
voice = "head_music"
if current_round == 9999:
voice = "tail_music"
is_podcast_round_end = False
logger.debug(f"New round started: {data}")
                        # Podcast round end
if msg.event == EventType.PodcastRoundEnd:
data = json.loads(msg.payload.decode())
logger.debug(f"Podcast round end: {data}")
if data.get("is_error"):
break
is_podcast_round_end = True
last_round_id = current_round
if audio:
round_info = await post_processor.postprocess_round(current_round, voice, audio, podcast_texts)
if on_round_complete:
await on_round_complete(round_info)
audio.clear()
                        # Podcast end
if msg.event == EventType.PodcastEnd:
data = json.loads(msg.payload.decode())
logger.info(f"Podcast end: {data}")
                        # Session finished
if msg.event == EventType.SessionFinished:
break
if not audio_received and not only_nlp_text:
raise RuntimeError("No audio data received")
                # Finish the connection
await finish_connection(websocket)
await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionFinished)
                # Podcast finished, save the final audio file
if is_podcast_round_end:
podcast_info = await post_processor.postprocess_final()
return podcast_info
else:
logger.error(f"Current podcast not finished, resuming from round {last_round_id}")
retry_num -= 1
await asyncio.sleep(1)
if websocket:
await websocket.close()
finally:
await post_processor.cleanup()
if websocket:
await websocket.close()
return None
async def test(args):
"""
    Podcast test function
    Args:
        args: dict containing all podcast parameters
"""
client = VolcEnginePodcastClient()
    # Set default parameters
params = {
"text": "",
"input_url": "https://zhuanlan.zhihu.com/p/607822576",
"prompt_text": "",
"nlp_texts": "",
"action": 0,
"resource_id": "volc.service_type.10050",
"encoding": "mp3",
"input_id": "test_podcast",
"speaker_info": '{"random_order":false}',
"use_head_music": False,
"use_tail_music": False,
"only_nlp_text": False,
"return_audio_url": True,
"skip_round_audio_save": False,
"output_dir": "output",
}
    # Override default parameters
if args:
params.update(args)
await client.podcast_request(**params)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--text", default="", help="Input text Use when action in [0]")
parser.add_argument("--input_url", default="", help="Web url or file url Use when action in [0]")
parser.add_argument("--prompt_text", default="", help="Input Prompt Text must not empty when action in [4]")
parser.add_argument("--nlp_texts", default="", help="Input NLP Texts must not empty when action in [3]")
parser.add_argument("--resource_id", default="volc.service_type.10050", help="Audio Resource ID")
parser.add_argument("--encoding", default="mp3", choices=["mp3", "wav"], help="Audio format")
parser.add_argument("--input_id", default="test_podcast", help="Unique input identifier")
parser.add_argument("--speaker_info", default='{"random_order":false}', help="Podcast Speaker Info")
parser.add_argument("--use_head_music", default=False, action="store_true", help="Enable head music")
parser.add_argument("--use_tail_music", default=False, action="store_true", help="Enable tail music")
parser.add_argument("--only_nlp_text", default=False, action="store_true", help="Enable only podcast text when action in [0, 4]")
parser.add_argument("--return_audio_url", default=False, action="store_true", help="Enable return audio url that can download")
parser.add_argument("--action", default=0, type=int, choices=[0, 3, 4], help="different podcast type")
parser.add_argument("--skip_round_audio_save", default=False, action="store_true", help="skip round audio save")
parser.add_argument("--output_dir", default="output", help="Output directory")
args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None and not (isinstance(v, bool) and not v)}
asyncio.run(test(kwargs))
# -*- coding: utf-8 -*-
import asyncio
import os
import struct
import subprocess
import sys
import tempfile
import time
import uuid
from typing import Optional, Tuple
import aiohttp
import numpy as np
import soundfile as sf
from aiohttp import ClientWebSocketResponse
# Protobuf imports
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
from loguru import logger
# ============================================================================
# Generated protocol buffer code (from tts.proto)
# ============================================================================
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\ttts.proto\x12\x03tts"\x8a\x01\n\rSubtitleEntry\x12\x15\n\rstart_time_ms\x18\x01 \x01(\r\x12\x13\n\x0b\x65nd_time_ms\x18\x02 \x01(\r\x12\x0f\n\x07speaker\x18\x03 \x01(\t\x12\r\n\x05style\x18\x04 \x01(\t\x12\x1f\n\x08language\x18\x05 \x01(\x0e\x32\r.tts.Language\x12\x0c\n\x04text\x18\x06 \x01(\t"\x88\x01\n\nAudioChunk\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\x12\x17\n\x0f\x61udio_chunk_seq\x18\x02 \x01(\x05\x12\x15\n\ris_last_chunk\x18\x03 \x01(\x08\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x14\n\x0c\x61udio_format\x18\x05 \x01(\t\x12\x12\n\ndisable_ns\x18\x06 \x01(\x08"\x84\x04\n\nTtsRequest\x12-\n\x0cmessage_type\x18\x01 \x01(\x0e\x32\x17.tts.RequestMessageType\x12\x0e\n\x06\x61pp_id\x18\x02 \x01(\t\x12\x15\n\rapp_signature\x18\x03 \x01(\t\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x16\n\x0etext_chunk_seq\x18\x05 \x01(\x05\x12\x1a\n\x12is_last_text_chunk\x18\x06 \x01(\x08\x12 \n\ttext_type\x18\x07 \x01(\x0e\x32\r.tts.TextType\x12\x0f\n\x07speaker\x18\x08 \x01(\t\x12\x1f\n\x08language\x18\t \x01(\x0e\x32\r.tts.Language\x12\r\n\x05style\x18\n \x01(\t\x12\r\n\x05speed\x18\x0b \x01(\x02\x12\x0e\n\x06volume\x18\x0c \x01(\x02\x12\r\n\x05pitch\x18\r \x01(\x02\x12\x15\n\rstream_output\x18\x0e \x01(\x08\x12\x19\n\x11\x61udio_sample_rate\x18\x0f \x01(\x05\x12*\n\x0e\x61udio_encoding\x18\x10 \x01(\x0e\x32\x12.tts.AudioEncoding\x12\x18\n\x10output_subtitles\x18\x11 \x01(\x08\x12\x12\n\nsession_id\x18\x12 \x01(\t\x12%\n\x0cupload_audio\x18\x13 \x01(\x0b\x32\x0f.tts.AudioChunk\x12\x1a\n\x12pronunciation_dict\x18\x14 \x03(\t"\xe7\x02\n\x0bTtsResponse\x12$\n\x0bstatus_code\x18\x01 \x01(\x0e\x32\x0f.tts.StatusCode\x12\x14\n\x0c\x65rror_detail\x18\x02 \x01(\t\x12\x14\n\x0ctime_cost_ms\x18\x03 \x01(\r\x12*\n\x0e\x61udio_encoding\x18\x04 \x01(\x0e\x32\x12.tts.AudioEncoding\x12\x17\n\x0f\x61udio_chunk_seq\x18\x05 \x01(\x05\x12\x12\n\naudio_data\x18\x06 \x01(\x0c\x12\x1b\n\x13is_last_audio_chunk\x18\x07 \x01(\x08\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12%\n\tsubtitles\x18\t \x03(\x0b\x32\x12.tts.SubtitleEntry\x12\x0f\n\x07speaker\x18\n \x01(\t\x12\x1a\n\x12request_char_count\x18\x0b \x01(\r\x12(\n\rerror_subcode\x18\x0c \x01(\x0e\x32\x11.tts.ErrorSubCode*\xa9\x01\n\x12RequestMessageType\x12\x1c\n\x18\x43LIENT_SYNTHESIS_REQUEST\x10\x00\x12\x19\n\x15\x43LIENT_FINISH_REQUEST\x10\x01\x12\x1d\n\x19\x43LIENT_UPLOAD_CLONE_AUDIO\x10\x02\x12\x1c\n\x18\x43LIENT_QUERY_CLONE_AUDIO\x10\x03\x12\x1d\n\x19\x43LIENT_DELETE_CLONE_AUDIO\x10\x04*\x1f\n\x08TextType\x12\t\n\x05PLAIN\x10\x00\x12\x08\n\x04SSML\x10\x01*A\n\x08Language\x12\t\n\x05ZH_CN\x10\x00\x12\t\n\x05\x45N_US\x10\x01\x12\x11\n\rZH_CN_SICHUAN\x10\x02\x12\x0c\n\x08ZH_CN_HK\x10\x03**\n\rAudioEncoding\x12\x07\n\x03PCM\x10\x00\x12\x07\n\x03WAV\x10\x01\x12\x07\n\x03MP3\x10\x02*\xa7\x01\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07TIMEOUT\x10\x02\x12\x13\n\x0fINVALID_REQUEST\x10\x03\x12\x12\n\x0eINTERNAL_ERROR\x10\x04\x12\x18\n\x14UPLOAD_AUDIO_SUCCESS\x10\x05\x12\x17\n\x13QUERY_AUDIO_SUCCESS\x10\x06\x12\x18\n\x14\x44\x45LETE_AUDIO_SUCCESS\x10\x07*\xe1\x02\n\x0c\x45rrorSubCode\x12\x0c\n\x08\x45RR_NONE\x10\x00\x12\x16\n\x12\x45RR_BASE_FILE_READ\x10\x65\x12\x17\n\x13\x45RR_BASE_FILE_WRITE\x10\x66\x12\x1c\n\x18\x45RR_BASE_INVALID_SEQ_NUM\x10g\x12\x1e\n\x1a\x45RR_BASE_SPEAKER_NOT_FOUND\x10h\x12\x14\n\x0f\x45RR_AC_INTERNAL\x10\xc9\x01\x12\x16\n\x11\x45RR_AC_LONG_AUDIO\x10\xca\x01\x12\x15\n\x10\x45RR_AC_LONG_TEXT\x10\xcb\x01\x12\x1f\n\x1a\x45RR_AC_AUDIO_TEXT_MISMATCH\x10\xcc\x01\x12 
\n\x1b\x45RR_AC_UNAUTHORIZED_SPEAKER\x10\xcd\x01\x12\x1b\n\x16\x45RR_AC_INVALID_SPEAKER\x10\xce\x01\x12\x17\n\x12\x45RR_AC_SHORT_AUDIO\x10\xcf\x01\x12\x16\n\x11\x45RR_AC_SHORT_TEXT\x10\xd0\x01\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "tts_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._options = None
_globals["_REQUESTMESSAGETYPE"]._serialized_start = 1180
_globals["_REQUESTMESSAGETYPE"]._serialized_end = 1349
_globals["_TEXTTYPE"]._serialized_start = 1351
_globals["_TEXTTYPE"]._serialized_end = 1382
_globals["_LANGUAGE"]._serialized_start = 1384
_globals["_LANGUAGE"]._serialized_end = 1449
_globals["_AUDIOENCODING"]._serialized_start = 1451
_globals["_AUDIOENCODING"]._serialized_end = 1493
_globals["_STATUSCODE"]._serialized_start = 1496
_globals["_STATUSCODE"]._serialized_end = 1663
_globals["_ERRORSUBCODE"]._serialized_start = 1666
_globals["_ERRORSUBCODE"]._serialized_end = 2019
_globals["_SUBTITLEENTRY"]._serialized_start = 19
_globals["_SUBTITLEENTRY"]._serialized_end = 157
_globals["_AUDIOCHUNK"]._serialized_start = 160
_globals["_AUDIOCHUNK"]._serialized_end = 296
_globals["_TTSREQUEST"]._serialized_start = 299
_globals["_TTSREQUEST"]._serialized_end = 815
_globals["_TTSRESPONSE"]._serialized_start = 818
_globals["_TTSRESPONSE"]._serialized_end = 1177
# ============================================================================
# Retrieve the generated protobuf classes from _globals (created by the builder above)
SubtitleEntry = _globals.get("SubtitleEntry")
AudioChunk = _globals.get("AudioChunk")
TtsRequest = _globals.get("TtsRequest")
TtsResponse = _globals.get("TtsResponse")
RequestMessageType = _globals.get("RequestMessageType")
TextType = _globals.get("TextType")
Language = _globals.get("Language")
AudioEncoding = _globals.get("AudioEncoding")
StatusCode = _globals.get("StatusCode")
ErrorSubCode = _globals.get("ErrorSubCode")
# Verify that all required classes are available
if not all([SubtitleEntry, AudioChunk, TtsRequest, TtsResponse, RequestMessageType, TextType, Language, AudioEncoding, StatusCode, ErrorSubCode]):
raise RuntimeError("Failed to load protobuf classes. Please check protobuf installation.")
# ============================================================================
# Configuration parameters
RECEIVE_TIMEOUT = 30 # Receive timeout (seconds)
# Language mapping
lang_id2str_mapping = {
Language.ZH_CN: "ZH_CN",
Language.ZH_CN_SICHUAN: "ZH_CN_SICHUAN",
Language.ZH_CN_HK: "ZH_CN_HK",
Language.EN_US: "EN_US",
}
lang_str2id_mapping = {v: k for k, v in lang_id2str_mapping.items()}
# Audio encoding mapping
codec_id2str_mapping = {
AudioEncoding.PCM: "pcm",
AudioEncoding.WAV: "wav",
AudioEncoding.MP3: "mp3",
}
codec_str2id_mapping = {v: k for k, v in codec_id2str_mapping.items()}
def parse_response(protocol_type: int, data: bytes) -> TtsResponse:
try:
response = TtsResponse()
response.ParseFromString(data)
return response
except Exception as e:
raise ValueError(f"Failed to parse response: {str(e)}")
def create_synthesis_request(
message_type,
text: str,
text_chunk_seq: int = 0,
is_last_text_chunk: bool = False,
app_id: str = "",
app_signature: str = "",
text_type: TextType = TextType.PLAIN,
speaker: str = "default",
language: Language = Language.ZH_CN,
style: str = "",
speed: float = 1,
volume: float = 0,
pitch: float = 0,
stream_output: bool = True,
audio_sample_rate: int = 24000,
audio_encoding: AudioEncoding = AudioEncoding.PCM,
output_subtitles: bool = False,
session_id: str = "",
upload_data: Optional[AudioChunk] = None,
) -> TtsRequest:
request = TtsRequest()
request.message_type = message_type
request.app_id = app_id
request.text = text
request.text_chunk_seq = text_chunk_seq
request.is_last_text_chunk = is_last_text_chunk
request.text_type = text_type
request.speaker = speaker
request.language = language
request.style = style
request.speed = speed
request.volume = volume
request.pitch = pitch
request.stream_output = stream_output
request.audio_sample_rate = audio_sample_rate
request.audio_encoding = audio_encoding
request.output_subtitles = output_subtitles
request.session_id = session_id
if upload_data is not None:
request.upload_audio.CopyFrom(upload_data)
return request
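# Wire framing shared by serialize_request() and receive_full_message():
#   byte 0       protocol type (0x01)
#   bytes 1..4   big-endian payload length (struct "!I")
#   bytes 5..    serialized TtsRequest / TtsResponse protobuf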
def serialize_request(request: TtsRequest) -> bytes:
request_bytes = request.SerializeToString()
request_length = struct.pack("!I", len(request_bytes))
full_request = b"\x01" + request_length + request_bytes
return full_request
async def receive_full_message(websocket: ClientWebSocketResponse) -> Tuple[int, bytes]:
try:
# Receive data
message = await asyncio.wait_for(websocket.receive_bytes(), timeout=RECEIVE_TIMEOUT)
if len(message) < 5:
raise ValueError("Invalid response: too short")
protocol_type = message[0]
if protocol_type != 0x01:
raise ValueError("Unsupported protocol type")
protocol_length = struct.unpack("!I", message[1:5])[0]
data = message[5:]
if len(data) != protocol_length:
logger.info(f"Length error {protocol_length}, got {len(data)}")
# If data is incomplete, continue receiving
while len(data) < protocol_length:
try:
chunk = await asyncio.wait_for(websocket.receive_bytes(), timeout=RECEIVE_TIMEOUT)
if not chunk:
raise ValueError("Got disconnected or empty data")
data += chunk
logger.info(f"Received additional {len(chunk)} bytes, total {len(data)}/{protocol_length}")
except asyncio.TimeoutError:
raise ValueError(f"Timeout while receiving message. Got {len(data)}/{protocol_length} bytes")
return protocol_type, data
except asyncio.TimeoutError:
raise ValueError(f"Response timed out after {RECEIVE_TIMEOUT} seconds")
except aiohttp.WSServerHandshakeError as e:
# WebSocket handshake error, may contain error information
error_msg = f"WebSocket handshake error: {str(e)}"
if hasattr(e, "message") and e.message:
error_msg = e.message
raise ValueError(error_msg)
except Exception as e:
error_str = str(e)
# Check if it's a WebSocket close message error
if "1009" in error_str:
raise ValueError("Audio file too large or format not supported. Please use WAV/MP3 audio file (max size limit).")
elif "1000" in error_str or "WSMsgType" in error_str:
# WebSocket close message, try to extract error information
if "1009" in error_str:
raise ValueError("Message too large. Audio file may be too big or in unsupported format.")
else:
raise ValueError(f"WebSocket connection closed: {error_str}")
raise ValueError(f"Error receiving data: {str(e)}")
class SenseTimeTTSClient:
"""
SenseTime TTS Client
Parameter ranges:
- speed: 0.5~2.0 (1.0 is normal speed)
- volume: -12~12 dB (0 is normal volume)
- pitch: -24~24 halftone (0 is normal pitch)
"""
def __init__(self, url=None, app_id=None, apikey=None):
self.url = url or os.getenv("SENSETIME_TTS_URL")
self.app_id = app_id or os.getenv("SENSETIME_APP_ID")
self.apikey = apikey or os.getenv("SENSETIME_APIKEY")
if not self.apikey:
raise ValueError("SENSETIME_APIKEY is not set")
if not self.app_id:
raise ValueError("SENSETIME_APP_ID is not set")
if not self.url:
raise ValueError("SENSETIME_TTS_URL is not set")
async def _receive_loop(self, websocket, session_id, params, result_dict):
"""Continuously receive server responses in a loop"""
is_running = True
data = b""
seq = -1
subtitles = []
first_latency = None
try:
while is_running:
try:
ptype, data_bytes = await receive_full_message(websocket)
response = parse_response(ptype, data_bytes)
if response.status_code == StatusCode.SUCCESS:
chunk_seq = response.audio_chunk_seq
is_last_chunk = response.is_last_audio_chunk
stream = params.get("stream_output", True)
# Check sequence number
valid = chunk_seq == seq + 1
seq = chunk_seq
if not valid:
logger.warning(f"Session {session_id} Invalid seq")
is_running = False
break
if chunk_seq == 0:
start_time = result_dict.get("start_time")
if start_time is not None:
first_latency = (time.time() - start_time) * 1000
logger.info(f"Session {session_id} stream({int(stream)}) Got first package, cost(ms): {first_latency:.3f}")
if response.audio_data:
data += response.audio_data
logger.info(f"Audio seq:{chunk_seq},is_last:{is_last_chunk} data length: {len(response.audio_data)} bytes")
if response.subtitles:
for subtitle in response.subtitles:
start_time_ms = subtitle.start_time_ms
end_time_ms = subtitle.end_time_ms
fmt_sub = f" {subtitle.text} ({start_time_ms}-{end_time_ms}ms)"
subtitles.append(fmt_sub)
if response.is_last_audio_chunk:
start_time = result_dict.get("start_time")
whole_cost = time.time() - start_time if start_time else 0
if len(data) > 0:
sample_rate = params.get("sample_rate", 24000)
duration = len(data) / 2 / sample_rate
rtf = whole_cost / duration if duration > 0 else 0
if len(subtitles) > 0:
joint_sub = "\t".join(subtitles)
logger.info(f"Session {session_id} subtile:{joint_sub}")
out_info = f"spk {params.get('speaker', 'default')} "
out_info += f"stream {int(stream)} "
if first_latency is not None:
out_info += f"latency {first_latency:.3f} ms "
out_info += f"cost {whole_cost:.3f} secs "
if params.get("audio_format") == "pcm":
out_info += f"duration {duration:.3f} secs "
out_info += f"RTF {rtf:.3f}"
logger.info(f"Session {session_id} done, {out_info}")
result_dict["audio_data"] = data
result_dict["subtitles"] = subtitles
result_dict["success"] = True
is_running = False
elif response.status_code == StatusCode.INTERNAL_ERROR:
error_msg = response.error_detail if response.error_detail else "Internal error"
logger.error(f"INTERNAL_ERROR in response: {error_msg}")
result_dict["error"] = error_msg
result_dict["success"] = False
is_running = False
break
elif response.status_code == StatusCode.ERROR:
error_msg = response.error_detail if response.error_detail else "Unknown error"
logger.error(f"ERROR in response: {error_msg}")
result_dict["error"] = error_msg
result_dict["success"] = False
is_running = False
break
elif response.status_code == StatusCode.UPLOAD_AUDIO_SUCCESS:
if response.speaker == "":
logger.error("ERROR: Got none speaker for UPLOAD_AUDIO_SUCCESS")
result_dict["error"] = "Got none speaker for UPLOAD_AUDIO_SUCCESS"
else:
logger.info(f"OK, Got speaker id {response.speaker} session id {response.session_id}")
result_dict["speaker"] = response.speaker
result_dict["session_id"] = response.session_id
result_dict["success"] = True
is_running = False
break
elif response.status_code == StatusCode.QUERY_AUDIO_SUCCESS:
logger.info(f"Query speaker {response.speaker} successful")
result_dict["speaker"] = response.speaker
result_dict["success"] = True
is_running = False
break
elif response.status_code == StatusCode.DELETE_AUDIO_SUCCESS:
logger.info(f"Delete speaker {response.speaker} successful")
result_dict["success"] = True
is_running = False
break
else:
# Handle other error status codes, return error details directly
error_msg = response.error_detail if response.error_detail else "Unknown error"
logger.error(f"Error in response: {error_msg}")
result_dict["error"] = error_msg
result_dict["success"] = False
is_running = False
break
except asyncio.CancelledError:
logger.info("Receive loop cancelled")
is_running = False
break
except Exception as e:
logger.error(f"Error in receive loop: {e}")
result_dict["error"] = str(e)
break
except Exception as e:
logger.error(f"Receive loop terminated: {e}")
result_dict["error"] = str(e)
logger.info("Exit receive loop.")
async def tts_request(
self,
text,
speaker="M20",
style="正常",
speed=1.0,
volume=0,
pitch=0,
language="ZH_CN",
output="tts_output.wav",
sample_rate=24000,
audio_format="wav",
stream_output=True,
output_subtitles=False,
):
"""
Execute TTS request
Args:
text: Text to convert
speaker: Speaker, common values include "M20", "F12", "zhili", "nvguo59", or ID returned by audioclone
style: Speaker style, common values include "正常" (normal), "高兴" (happy), "愤怒" (angry), etc.
speed: Speech rate (0.5~2.0, 1.0 is normal speed)
volume: Volume (-12~12 dB, 0 is normal volume)
pitch: Pitch (-24~24 halftone, 0 is normal pitch)
language: Language, options: "ZH_CN", "ZH_CN_SICHUAN", "ZH_CN_HK", "EN_US"
output: Output file path
sample_rate: Sample rate, options: 8000, 16000, 24000, 32000, 48000
audio_format: Audio format, options: "pcm", "wav", "mp3"
stream_output: Whether to stream output
output_subtitles: Whether to output subtitles
"""
# Validate parameter ranges
if not (0.5 <= speed <= 2.0):
logger.warning(f"speed {speed} is out of valid range [0.5, 2.0], using default value 1.0")
speed = 1.0
if not (-12 <= volume <= 12):
logger.warning(f"volume {volume} is out of valid range [-12, 12], using default value 0")
volume = 0
if not (-24 <= pitch <= 24):
logger.warning(f"pitch {pitch} is out of valid range [-24, 24], using default value 0")
pitch = 0
if language not in lang_str2id_mapping:
logger.warning(f"language {language} is invalid, using default value ZH_CN")
language = "ZH_CN"
if audio_format not in codec_str2id_mapping:
logger.warning(f"audio_format {audio_format} is invalid, using default value pcm")
audio_format = "pcm"
logger.info(f"Connecting to {self.url}...")
headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
result_dict = {"success": False, "audio_data": None, "subtitles": [], "error": None}
try:
async with aiohttp.ClientSession(headers=headers) as session:
async with session.ws_connect(self.url) as websocket:
logger.info("WebSocket connection established")
session_id = str(uuid.uuid4())
params = {
"speaker": speaker,
"style": style,
"speed": speed,
"volume": volume,
"pitch": pitch,
"language": language,
"sample_rate": sample_rate,
"audio_format": audio_format,
"stream_output": stream_output,
"output_subtitles": output_subtitles,
}
# Set start time (before sending request)
start_time = time.time()
result_dict["start_time"] = start_time
# Start receive loop
receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, params, result_dict))
# Simulate streaming: send character by character
for i, chunk in enumerate(text):
if not receive_task.done():
is_last = i == len(text) - 1
request = create_synthesis_request(
message_type=RequestMessageType.CLIENT_SYNTHESIS_REQUEST,
app_id=self.app_id,
text=chunk,
text_chunk_seq=i,
is_last_text_chunk=is_last,
session_id=session_id,
speaker=speaker,
style=style,
speed=speed,
output_subtitles=output_subtitles,
audio_sample_rate=sample_rate,
language=lang_str2id_mapping[language],
volume=volume,
audio_encoding=codec_str2id_mapping[audio_format],
stream_output=stream_output,
pitch=pitch,
)
full_request = serialize_request(request)
await websocket.send_bytes(full_request)
# Wait for receive task to complete
await receive_task
if result_dict["success"] and result_dict["audio_data"]:
audio_data = result_dict["audio_data"]
# Save audio file
if audio_format == "pcm":
if not output.endswith(".wav"):
output += ".wav"
audio_np = np.frombuffer(audio_data, dtype=np.int16)
sf.write(output, audio_np, samplerate=sample_rate, subtype="PCM_16")
else:
if not output.endswith(f".{audio_format}"):
output += f".{audio_format}"
with open(output, "wb") as fp:
fp.write(audio_data)
logger.info(f"audio saved to {output}, audio size: {len(audio_data) / 1024:.2f} KB")
os.chmod(output, 0o644)
return True
else:
error_msg = result_dict.get("error", "Unknown error")
logger.warning(f"SenseTimeTTSClient tts request failed: {error_msg}")
return False
except Exception as e:
logger.warning(f"SenseTimeTTSClient tts request failed: {e}")
return False
async def upload_audio_clone(
self,
audio_path,
audio_text,
disable_ns=False,
):
"""
Upload audio for voice cloning
Args:
audio_path: Audio file path
audio_text: Text corresponding to the audio
disable_ns: Whether to disable audio noise reduction processing
Returns:
tuple: (success: bool, result: str)
- success: True indicates success, False indicates failure
- result: Returns speaker_id on success, error message string on failure
"""
logger.info(f"Connecting to {self.url}...")
headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
result_dict = {"success": False, "speaker": None, "session_id": None, "error": None}
try:
async with aiohttp.ClientSession(headers=headers) as session:
async with session.ws_connect(self.url) as websocket:
logger.info("WebSocket connection established")
session_id = str(uuid.uuid4())
# Start receive loop
receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, {}, result_dict))
# Read and send audio
# Check file format, if it's a video file (e.g., MP4), extract audio first
tmp_audio_path = None
original_audio_path = audio_path
try:
file_ext = os.path.splitext(audio_path)[1].lower()
if file_ext in [".mp4", ".mov", ".avi", ".mkv", ".flv"]:
# Video file, need to extract audio first
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio:
tmp_audio_path = tmp_audio.name
try:
# Use ffmpeg to extract audio
cmd = ["ffmpeg", "-i", audio_path, "-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", "-y", tmp_audio_path]
proc = await asyncio.create_subprocess_exec(*cmd, stderr=asyncio.subprocess.PIPE)
try:
_, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
raise ValueError("Audio extraction timeout. Video file may be too large.")
if proc.returncode != 0:
raise ValueError(f"Failed to extract audio from video: {stderr.decode(errors='ignore')}")
logger.info(f"Extracted audio from video file to {tmp_audio_path}")
audio_path = tmp_audio_path
except subprocess.TimeoutError:
raise ValueError("Audio extraction timeout. Video file may be too large.")
except FileNotFoundError:
raise ValueError("ffmpeg not found. Please install ffmpeg to process video files.")
except Exception as e:
raise ValueError(f"Failed to extract audio: {str(e)}")
with open(audio_path, "rb") as fp:
audio_bytes = fp.read()
# Check file size (recommended not to exceed 10MB)
if len(audio_bytes) > 10 * 1024 * 1024:
logger.warning(f"Audio file size ({len(audio_bytes) / 1024 / 1024:.2f} MB) may be too large")
audio_chunk = AudioChunk()
audio_chunk.audio_data = audio_bytes
audio_chunk.audio_chunk_seq = 0
audio_chunk.is_last_chunk = 1
audio_chunk.text = audio_text
audio_chunk.disable_ns = disable_ns
finally:
# Clean up temporary files
if tmp_audio_path and os.path.exists(tmp_audio_path):
try:
os.unlink(tmp_audio_path)
except Exception:
pass
request = create_synthesis_request(
message_type=RequestMessageType.CLIENT_UPLOAD_CLONE_AUDIO,
app_id=self.app_id,
text="",
session_id=session_id,
upload_data=audio_chunk,
)
full_request = serialize_request(request)
await websocket.send_bytes(full_request)
logger.info(f"Sent audio chunk for cloning")
# Wait for receive task to complete
await receive_task
if result_dict["success"]:
speaker_id = result_dict.get("speaker")
logger.info(f"SenseTimeTTSClient upload audio clone successful, speaker: {speaker_id}")
return True, speaker_id
else:
# Return error message string directly
error_msg = result_dict.get("error", "Unknown error")
logger.warning(f"SenseTimeTTSClient upload audio clone failed: {error_msg}")
return False, error_msg
except Exception as e:
error_msg = str(e)
logger.warning(f"SenseTimeTTSClient upload audio clone failed: {error_msg}")
return False, error_msg
async def query_speaker(self, speaker):
"""
Query if the specified speaker exists
Args:
speaker: speaker ID
"""
logger.info(f"Connecting to {self.url}...")
headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
result_dict = {"success": False, "speaker": None, "error": None}
try:
async with aiohttp.ClientSession(headers=headers) as session:
async with session.ws_connect(self.url) as websocket:
logger.info("WebSocket connection established")
session_id = str(uuid.uuid4())
# Start receive loop
receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, {}, result_dict))
# Send query request
request = create_synthesis_request(
message_type=RequestMessageType.CLIENT_QUERY_CLONE_AUDIO,
app_id=self.app_id,
text="",
session_id=session_id,
speaker=speaker,
)
full_request = serialize_request(request)
await websocket.send_bytes(full_request)
logger.info(f"Sent query for speaker {speaker}")
# Wait for receive task to complete
await receive_task
if result_dict["success"]:
logger.info(f"SenseTimeTTSClient query speaker successful")
return True
else:
error_msg = result_dict.get("error", "Unknown error")
logger.warning(f"SenseTimeTTSClient query speaker failed: {error_msg}")
return False
except Exception as e:
logger.warning(f"SenseTimeTTSClient query speaker failed: {e}")
return False
async def delete_speaker(self, speaker):
"""
Delete the specified speaker
Args:
speaker: speaker ID
"""
logger.info(f"Connecting to {self.url}...")
headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
result_dict = {"success": False, "error": None}
try:
async with aiohttp.ClientSession(headers=headers) as session:
async with session.ws_connect(self.url) as websocket:
logger.info("WebSocket connection established")
session_id = str(uuid.uuid4())
# Start receive loop
receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, {}, result_dict))
# Send delete request
request = create_synthesis_request(
message_type=RequestMessageType.CLIENT_DELETE_CLONE_AUDIO,
app_id=self.app_id,
text="",
session_id=session_id,
speaker=speaker,
)
full_request = serialize_request(request)
await websocket.send_bytes(full_request)
logger.info(f"Sent delete request for speaker {speaker}")
# Wait for receive task to complete
await receive_task
if result_dict["success"]:
logger.info(f"SenseTimeTTSClient delete speaker successful")
return True
else:
error_msg = result_dict.get("error", "Unknown error")
logger.warning(f"SenseTimeTTSClient delete speaker failed: {error_msg}")
return False
except Exception as e:
logger.warning(f"SenseTimeTTSClient delete speaker failed: {e}")
return False
async def test(args):
"""
TTS test function
Args:
args: list, e.g. [text, speaker, style, speed, volume, pitch, language, output, sample_rate, audio_format, stream_output, output_subtitles]
Provide as many as needed, from left to right.
Parameter ranges:
- speed: 0.5~2.0 (1.0 is normal speed)
- volume: -12~12 dB (0 is normal volume)
- pitch: -24~24 halftone (0 is normal pitch)
"""
client = SenseTimeTTSClient()
# Set default parameters
params = {
"text": "今天天气真不错,阳光明媚,微风轻拂,让人心情愉悦。",
"speaker": "M20",
"style": "正常",
"speed": 1.0,
"volume": 0,
"pitch": 0,
"language": "ZH_CN",
"output": "tts_output.wav",
"sample_rate": 24000,
"audio_format": "pcm",
"stream_output": True,
"output_subtitles": False,
}
keys = list(params.keys())
# Override default parameters
for i, arg in enumerate(args):
if i < len(keys):
# Type conversion
if keys[i] in ["sample_rate"]:
params[keys[i]] = int(arg)
elif keys[i] in ["stream_output", "output_subtitles"]:
# Support multiple boolean inputs
params[keys[i]] = str(arg).lower() in ("1", "true", "yes", "on")
elif keys[i] in ["speed", "volume", "pitch"]:
params[keys[i]] = float(arg)
else:
params[keys[i]] = arg
await client.tts_request(
params["text"],
params["speaker"],
params["style"],
params["speed"],
params["volume"],
params["pitch"],
params["language"],
params["output"],
params["sample_rate"],
params["audio_format"],
params["stream_output"],
params["output_subtitles"],
)
async def test_audio_clone(args):
"""
Voice cloning test function
Args:
args: list, e.g. [audio_path, audio_text, disable_ns]
Provide as many as needed, from left to right.
Parameters:
- audio_path: Audio file path (required)
- audio_text: Text corresponding to the audio (required)
- disable_ns: Whether to disable audio noise reduction processing, default False (optional, supports "1", "true", "yes", "on" for True)
"""
client = SenseTimeTTSClient()
# Set default parameters
params = {
"audio_path": "",
"audio_text": "",
"disable_ns": False,
}
keys = list(params.keys())
# Override default parameters
for i, arg in enumerate(args):
if i < len(keys):
# Type conversion
if keys[i] == "disable_ns":
# Support multiple boolean inputs
params[keys[i]] = str(arg).lower() in ("1", "true", "yes", "on")
else:
params[keys[i]] = arg
# Validate required parameters
if not params["audio_path"]:
logger.error("audio_path is required for audio clone test")
return
if not params["audio_text"]:
logger.error("audio_text is required for audio clone test")
return
# Check if file exists
if not os.path.exists(params["audio_path"]):
logger.error(f"Audio file not found: {params['audio_path']}")
return
success, result = await client.upload_audio_clone(
params["audio_path"],
params["audio_text"],
params["disable_ns"],
)
if success:
logger.info(f"Audio clone successful! Speaker ID: {result}")
else:
logger.warning(f"Audio clone failed: {result}")
if __name__ == "__main__":
# Support two test modes: regular TTS test and voice cloning test
if len(sys.argv) > 1 and sys.argv[1] == "clone":
# Voice cloning test mode: python sensetime_tts.py clone [audio_path] [audio_text] [disable_ns]
asyncio.run(test_audio_clone(sys.argv[2:]))
else:
# Regular TTS test mode: python sensetime_tts.py [text] [speaker] ...
asyncio.run(test(sys.argv[1:]))
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
import traceback
from datetime import datetime
import httpx
import torchaudio
from PIL import Image
from loguru import logger
FMT = "%Y-%m-%d %H:%M:%S"
def current_time():
return datetime.now().timestamp()
def time2str(t):
d = datetime.fromtimestamp(t)
return d.strftime(FMT)
def str2time(s):
d = datetime.strptime(s, FMT)
return d.timestamp()
def try_catch(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
logger.error(f"Error in {func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch(func):
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch_async(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
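# Map a logical input/output name to its stored filename: "<task_id>-<name>" plus a media extension (.png/.mp4)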
def data_name(x, task_id):
if x == "input_image" or x.startswith("input_image/"):
x = x + ".png"
elif x == "input_video":
x = x + ".mp4"
elif x == "input_last_frame":
x = x + ".png"
elif x == "output_video":
x = x + ".mp4"
elif x == "output_image":
x = x + ".png"
return f"{task_id}-{x}"
async def fetch_resource(url, timeout):
logger.info(f"Begin to download resource from url: {url}")
t0 = time.time()
async with httpx.AsyncClient() as client:
async with client.stream("GET", url, timeout=timeout) as response:
response.raise_for_status()
ans_bytes = []
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
ans_bytes.append(chunk)
if len(ans_bytes) > 128:
raise Exception(f"url {url} recv data is too big")
content = b"".join(ans_bytes)
logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
return content
# Validate the image, downscale it if it exceeds max_size, and apply the EXIF orientation metadata
def format_image_data(data, max_size=1280):
image = Image.open(io.BytesIO(data)).convert("RGB")
exif = image.getexif()
changed = False
w, h = image.size
assert w > 0 and h > 0, "image is empty"
logger.info(f"load image: {w}x{h}, exif: {exif}")
if w > max_size or h > max_size:
ratio = max_size / max(w, h)
w = int(w * ratio)
h = int(h * ratio)
image = image.resize((w, h))
logger.info(f"resize image to: {image.size}")
changed = True
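    # EXIF tag 274 (0x0112) is Orientation; undo the stored flip/rotation so the saved pixels are upright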
orientation_key = 274
if orientation_key and orientation_key in exif:
orientation = exif[orientation_key]
if orientation == 2:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
elif orientation == 3:
image = image.rotate(180, expand=True)
elif orientation == 4:
image = image.transpose(Image.FLIP_TOP_BOTTOM)
elif orientation == 5:
image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
elif orientation == 6:
image = image.rotate(270, expand=True)
elif orientation == 7:
image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
elif orientation == 8:
image = image.rotate(90, expand=True)
# reset orientation to 1
if orientation != 1:
logger.info(f"reset orientation from {orientation} to 1")
exif[orientation_key] = 1
changed = True
if not changed:
return data
output = io.BytesIO()
image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
return output.getvalue()
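# Transcode arbitrary media bytes (audio or video container) to raw wav/mp3 audio by piping through ffmpeg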
def media_to_audio(data, max_duration=None, sample_rate=44100, channels=2, output_format="wav"):
with tempfile.NamedTemporaryFile() as fin:
fin.write(data)
fin.flush()
ds = ["-t", str(max_duration)] if max_duration is not None else []
fmts = ["mp3", "libmp3lame"] if output_format == "mp3" else ["wav", "pcm_s16le"]
cmd = ["ffmpeg", "-i", fin.name, *ds, "-f", fmts[0], "-acodec", fmts[1], "-ar", str(sample_rate), "-ac", str(channels), "pipe:1"]
logger.info(f"media_to_audio cmd: {cmd}")
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert p.returncode == 0, f"media to {output_format} failed: {p.stderr.decode()}"
return p.stdout
def format_audio_data(data, max_duration=None):
if len(data) < 4:
raise ValueError("Audio file too short")
data = media_to_audio(data, max_duration)
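    # Decode only the first few frames to verify that the converted bytes are valid, non-empty audio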
waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
logger.info(f"load audio: {waveform.size()}, {sample_rate}")
assert waveform.numel() > 0, "audio is empty"
assert sample_rate > 0, "audio sample rate is not valid"
return data
async def preload_data(inp, inp_type, typ, val):
try:
if typ == "url":
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
data = await fetch_resource(val, timeout=timeout)
elif typ == "base64":
# Check if this is multiple base64 images (for i2i tasks)
# Frontend now sends a list of base64 strings: ["base64string1", "base64string2", ...]
if isinstance(val, list):
data = {}
for idx, encoded in enumerate(val):
if encoded.startswith("data:image"):
_, encoded = encoded.split(",", 1)
decoded = await asyncio.to_thread(base64.b64decode, encoded)
data[f"{inp}_{idx + 1}"] = decoded
else:
data = await asyncio.to_thread(base64.b64decode, val)
# For multi-person audio directory, val should be a dict with file structure
elif typ == "directory":
data = {}
for fname, b64_data in val.items():
data[fname] = await asyncio.to_thread(base64.b64decode, b64_data)
return {"type": "directory", "data": data}
elif typ == "stream":
            # stream inputs carry no bytes for the data manager to save
data = None
else:
raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
# check if valid image bytes
if inp_type == "IMAGE":
if isinstance(data, dict):
for key, value in data.items():
data[key] = await asyncio.to_thread(format_image_data, value)
return {"type": "directory", "data": data}
else:
data = await asyncio.to_thread(format_image_data, data)
elif inp_type == "AUDIO":
if typ != "stream" and typ != "directory":
data = await asyncio.to_thread(format_audio_data, data)
elif inp_type == "VIDEO":
# Video data doesn't need special formatting, just validate it's not empty
if len(data) == 0:
raise ValueError("Video file is empty")
logger.info(f"load video: {len(data)} bytes")
else:
raise Exception(f"cannot parse inp_type={inp_type} data")
return data
except Exception as e:
raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")
async def load_inputs(params, raw_inputs, types):
inputs_data = {}
for inp in raw_inputs:
item = params.pop(inp)
bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
# Handle multi-person audio directory, multiple images (for i2i tasks)
if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory":
fs = []
for fname, fdata in bytes_data["data"].items():
inputs_data[f"{inp}/{fname}"] = fdata
fs.append(f"{inp}/{fname}")
if "extra_inputs" not in params:
params["extra_inputs"] = {}
params["extra_inputs"][inp] = fs
elif bytes_data is not None:
inputs_data[inp] = bytes_data
else:
params[inp] = item
return inputs_data
def check_params(params, raw_inputs, raw_outputs, types):
stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
for x in raw_inputs + raw_outputs:
if x in params and "type" in params[x]:
if params[x]["type"] == "stream":
if types[x] == "AUDIO":
assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
elif types[x] == "VIDEO":
assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
elif params[x]["type"] == "directory":
# Multi-person audio directory is only supported for AUDIO type
assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}"
if __name__ == "__main__":
# https://github.com/recurser/exif-orientation-examples
exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
os.makedirs(out_dir, exist_ok=True)
for base_name in ["Landscape", "Portrait"]:
for i in range(9):
fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
logger.info(f"format image: {fin_name} -> {fout_name}")
with open(fin_name, "rb") as f:
data = f.read()
data = format_image_data(data)
with open(fout_name, "wb") as f:
f.write(data)
import math
import os
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE
class NextControl:
def __init__(self, action: str, data: any = None):
# action: blank_to_voice, data: prev_video tensor
# action: wait, data: None
# action: fetch, data: None
# action: switch_image, data: image_path
# action: perform_action, data: action prompt
self.action = action
self.data = data
class VAController:
def __init__(self, model_runner):
self.reader = None
self.recorder = None
self.rank = 0
self.world_size = 1
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
self.init_recorder()
self.init_reader(model_runner)
def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
if "stream_config" in input_info.__dataclass_fields__:
self.stream_config = input_info.stream_config
logger.info(f"VAController init base with stream config: {self.stream_config}")
self.audio_path = input_info.audio_path
self.output_video_path = input_info.save_result_path
if isinstance(self.output_video_path, dict):
self.output_video_path = self.output_video_path["data"]
self.audio_sr = config.get("audio_sr", 16000)
self.target_fps = config.get("target_fps", 16)
self.max_num_frames = config.get("target_video_length", 81)
self.prev_frame_length = config.get("prev_frame_length", 5)
self.record_fps = config.get("target_fps", 16)
if "video_frame_interpolation" in config and has_vfi_model:
self.record_fps = config["video_frame_interpolation"]["target_fps"]
self.record_fps = config.get("record_fps", self.record_fps)
self.tgt_h = input_info.target_shape[0]
self.tgt_w = input_info.target_shape[1]
self.record_h, self.record_w = self.tgt_h, self.tgt_w
if "video_super_resolution" in config and has_vsr_model:
_, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
self.record_w,
self.record_h,
scale=config["video_super_resolution"]["scale"],
multiple=128,
)
        # how many frames to publish to the stream as one batch
self.slice_frame = config.get("slice_frame", self.prev_frame_length)
        # estimated worst-case inference time in seconds, used when handling immediate switches with the local omni reader
slice_interval = self.slice_frame / self.record_fps
est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
est_max_switch_image_secs = config.get("est_max_switch_image_secs", 0)
est_max_switch_action_secs = config.get("est_max_switch_action_secs", 0)
self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
self.est_switch_image_end_idx = math.ceil(est_max_switch_image_secs / slice_interval)
self.est_switch_action_end_idx = math.ceil(est_max_switch_action_secs / slice_interval)
max_end_idx = max(self.est_infer_end_idx, self.est_switch_image_end_idx, self.est_switch_action_end_idx)
self.min_stay_queue_num = max_end_idx * 2 + 1
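        # Skip inference while at least this many slices are buffered: twice the worst-case number of
        # slices consumed during an inference step or an image/action switch, plus one spare slice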
def init_recorder(self):
if not self.output_video_path or self.rank != self.target_recorder_rank:
return
logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and self.output_video_path.startswith("http"):
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
self.recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
else:
from lightx2v.deploy.common.va_recorder import VARecorder
self.recorder = VARecorder(
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
stream_config=self.stream_config,
)
def init_reader(self, model_runner=None):
if not isinstance(self.audio_path, dict):
return
assert self.audio_path["type"] == "stream", f"unexcept audio_path: {self.audio_path}"
segment_duration = self.max_num_frames / self.target_fps
prev_duration = self.prev_frame_length / self.target_fps
omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
if omni_work_dir:
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
self.reader = OmniVAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
model_runner=model_runner,
huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
stream_config=self.stream_config,
va_recorder=self.recorder,
)
else:
from lightx2v.deploy.common.va_reader import VAReader
self.reader = VAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
)
def start(self):
self.reader.start()
if self.rank == self.target_recorder_rank:
assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
self.recorder.start(self.record_w, self.record_h)
if self.world_size > 1:
dist.barrier()
def next_control(self):
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
if isinstance(self.reader, OmniVAReader):
action_control = self.omni_reader_action_control()
if action_control is not None:
return action_control
image_control = self.omni_reader_image_control()
if image_control is not None:
return image_control
return self.omni_reader_next_control()
return NextControl(action="fetch")
def before_control(self):
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
if isinstance(self.reader, OmniVAReader):
self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)
def omni_reader_next_control(self):
immediate_switch = self.reader.get_immediate_switch()
if immediate_switch == 1:
# truncate the stream buffer to keep the max infer time length
# and broadcast the prev video tensor to all ranks
if self.rank == self.target_recorder_rank:
logger.warning(f"runner recv immediate switch, truncate stream buffer")
video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
if video_tensor is not None:
self.flag_tensor.fill_(1)
self.prev_tensor.copy_(video_tensor)
else:
self.flag_tensor.fill_(0)
dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
if self.flag_tensor.item() == 1:
dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
return NextControl(action="blank_to_voice", data=self.prev_tensor)
else:
# get the length of stream buffer, broadcast to all ranks
if self.rank == self.target_recorder_rank:
stream_buffer_length = self.recorder.get_buffer_stream_size()
self.len_tensor.copy_(stream_buffer_length)
dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
buffer_length = self.len_tensor.item()
# stream buffer is enough, skip infer
if buffer_length >= self.min_stay_queue_num:
return NextControl(action="wait")
return NextControl(action="fetch")
def omni_reader_image_control(self):
image_switch = self.reader.get_image_switch()
if not isinstance(image_switch, str) or len(image_switch) == 0:
return None
if not os.path.exists(image_switch):
logger.warning(f"Switch image path {image_switch} does not exist")
return None
# truncate the stream buffer to keep the max infer time length
if self.rank == self.target_recorder_rank:
logger.warning(f"runner recv image switch, truncate stream buffer")
self.recorder.truncate_stream_buffer(self.est_switch_image_end_idx)
return NextControl(action="switch_image", data=image_switch)
def omni_reader_action_control(self):
action_switch = self.reader.get_action_switch()
if not isinstance(action_switch, str) or len(action_switch) == 0:
return None
# truncate the stream buffer to keep the max infer time length
if self.rank == self.target_recorder_rank:
logger.warning(f"runner recv action switch, truncate stream buffer")
self.recorder.truncate_stream_buffer(self.est_switch_action_end_idx)
return NextControl(action="perform_action", data=action_switch)
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor, valid_duration=1e9):
if self.recorder.realtime:
self.recorder.buffer_stream(images, audios, gen_video, valid_duration=valid_duration)
else:
self.recorder.pub_livestream(images, audios)
def clear(self):
self.len_tensor = None
self.flag_tensor = None
self.prev_tensor = None
if self.reader is not None:
try:
self.reader.stop()
except Exception as e:
logger.warning(f"Error stopping reader: {e}")
self.reader = None
if self.recorder is not None:
try:
self.recorder.stop()
except Exception as e:
logger.warning(f"Error stopping recorder: {e}")
self.recorder = None
def __del__(self):
self.clear()
import os
import queue
import signal
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
class VAReader:
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
# int16 = 2 bytes
self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.prev_size = int(self.prev_duration * self.sample_rate) * 2
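        # chunk_size: bytes of int16 PCM per segment; prev_size: bytes of overlap carried between segments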
self.prev_chunk = None
self.buffer_size = buffer_size
self.audio_queue = queue.Queue(maxsize=self.buffer_size)
self.audio_thread = None
self.ffmpeg_process = None
self.bytes_buffer = bytearray()
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def start(self):
if self.rank == self.target_rank:
if self.stream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.stream_url.startswith("http"):
self.start_ffmpeg_process_whep()
else:
raise Exception(f"Unsupported stream URL: {self.stream_url}")
self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
self.audio_thread.start()
logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process read audio from stream"""
ffmpeg_cmd = [
"ffmpeg",
"-i",
self.stream_url,
"-vn",
# "-acodec",
# "pcm_s16le",
"-ar",
str(self.sample_rate),
"-ac",
str(self.audio_channels),
"-f",
"s16le",
"-",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def start_ffmpeg_process_whep(self):
"""Start gstream process read audio from stream"""
ffmpeg_cmd = [
"gst-launch-1.0",
"-q",
"whepsrc",
f"whep-endpoint={self.stream_url}",
"video-caps=none",
"!rtpopusdepay",
"!opusdec",
"plc=false",
"!audioconvert",
"!audioresample",
f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
"!fdsink",
"fd=1",
]
try:
self.ffmpeg_process = subprocess.Popen(
ffmpeg_cmd,
stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
bufsize=0,
)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def audio_worker(self):
logger.info("Audio pull worker thread started")
try:
while True:
if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
logger.warning("FFmpeg process exited, audio worker thread stopped")
break
self.fetch_audio_data()
time.sleep(0.01)
except: # noqa
logger.error(f"Audio pull worker error: {traceback.format_exc()}")
finally:
logger.warning("Audio pull worker thread stopped")
def fetch_audio_data(self):
"""Fetch audio data from ffmpeg process"""
try:
audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
if not audio_bytes:
return
self.bytes_buffer.extend(audio_bytes)
# logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
if len(self.bytes_buffer) >= self.chunk_size:
audio_data = self.bytes_buffer[: self.chunk_size]
self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
                # First chunk: read one full segment (e.g. audio for 81 frames).
                # Later chunks: read (segment - prev) worth of bytes and prepend the previous
                # prev-sized tail so consecutive segments overlap (e.g. 76 new + 5 carried frames).
if self.prev_chunk is None:
logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
self.chunk_size -= self.prev_size
else:
audio_data = self.prev_chunk + audio_data
self.prev_chunk = audio_data[-self.prev_size :]
try:
self.audio_queue.put_nowait(audio_data)
except queue.Full:
logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk")
self.audio_queue.get_nowait()
self.audio_queue.put_nowait(audio_data)
logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}")
except: # noqa
logger.error(f"Fetch audio data error: {traceback.format_exc()}")
def braodcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def get_audio_segment(self, timeout: float = 1.0, fetch_duration: float = None, prev_duration: float = None):
if fetch_duration is not None and self.segment_duration != fetch_duration:
logger.warning(f"ignore fetch_duration, {fetch_duration} != {self.segment_duration}")
if prev_duration is not None and self.prev_duration != prev_duration:
raise ValueError(f"prev_duration {prev_duration} != {self.prev_duration}")
audio_data = None
if self.rank == self.target_rank:
try:
audio_data = self.audio_queue.get(timeout=timeout)
except: # noqa
logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
if self.world_size > 1:
audio_data = self.braodcast_audio_data(audio_data)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data, self.segment_duration
def stop(self):
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg reader process stopped")
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
logger.error("Audio pull thread did not stop gracefully")
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
self.audio_queue = None
logger.warning("Audio pull queue cleaned")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = VAReader(
RANK,
WORLD_SIZE,
# "rtmp://localhost/live/test_audio",
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
segment_duration=1.0,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 2
try:
while True:
            audio_data, seg_duration = reader.get_audio_segment(timeout=2)
            if audio_data is not None:
                # logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
import datetime
import json
import os
import random
import subprocess
import threading
import time
import traceback
from collections import deque
from copy import deepcopy
import jsonschema
import numpy as np
import torch
import torch.distributed as dist
import zmq
from loguru import logger
try:
from bson import BSON
except ImportError:
BSON = None
logger.warning("BSON is not installed")
from scipy.signal import resample
class AudioInfo:
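    """Metadata of a PCM chunk received from the chatter process: sample_count,
    sample_rate, channel_count, sample_fmt and pts, parsed from the info dict
    of an AgentAudio message."""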
def __init__(self, info: dict):
self.sample_count = info["sample_count"]
self.sample_rate = info["sample_rate"]
self.channel_count = info["channel_count"]
self.sample_fmt = info["sample_fmt"]
self.pts = info["pts"]
def is_spec_equal(self, other: "AudioInfo") -> bool:
return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count
def duration(self) -> datetime.timedelta:
return datetime.timedelta(seconds=self.sample_count / self.sample_rate)
def __str__(self):
return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts)
class ByteBuffer:
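    """FIFO byte buffer backed by a deque of chunks.

    add() appends raw bytes, get(size) pops up to size bytes (splitting a
    chunk when needed), and mark_finished()/has_more_voice() track whether the
    audio of the current turn has ended.
    """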
def __init__(self):
self.buffer = deque()
self.current_size = 0
        # whether the audio belonging to the current turn has finished
self.audio_finished = False
def add(self, byte_data: bytes):
self.buffer.append(byte_data)
self.current_size += len(byte_data)
def get(self, size=1024):
data = bytearray()
while size > 0 and len(self.buffer) > 0:
chunk = self.buffer.popleft()
if len(chunk) <= size:
                # if the current chunk fits within size, move it entirely into data
data.extend(chunk)
self.current_size -= len(chunk)
size -= len(chunk)
else:
                # if the current chunk is larger than size, take only the needed part and keep the rest
                data.extend(chunk[:size])
                self.buffer.appendleft(chunk[size:])  # keep the remainder in the buffer
self.current_size -= size
size = 0
return bytes(data)
def mark_finished(self):
self.audio_finished = True
def has_more_voice(self):
return not self.audio_finished
def __len__(self):
return self.current_size
class ChatAdapter:
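    """Bridge between this process and the external seko-chatter binary.

    Launches the chatter process, receives BSON messages (AgentAudio,
    AgentStartPlay, AgentEndPlay, ClearAgentAudio) over a ZMQ PULL socket
    (w2f) with a PUSH socket (f2w) for the reverse direction, and buffers the
    agent's PCM audio so fixed-duration segments can be fetched.
    """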
def __init__(
self,
omni_work_dir: str,
whep_url: str,
session_id: str,
account: str,
config_files: list[str],
config_schema_path: str,
seg_duration: float,
model_runner,
huoshan_tts_voice_type,
stream_config: dict,
):
assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist"
self.omni_work_dir = omni_work_dir
self.stream_config = stream_config
self.context = zmq.Context()
self.w2f_socket = self.context.socket(zmq.PULL)
self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket)
self.f2w_socket = self.context.socket(zmq.PUSH)
self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket)
self.recv_thread = None
self.audio_buffer = ByteBuffer()
self.audio_info = None
self.chat_server_cmd = [
os.path.join(self.omni_work_dir, "bin", "seko-chatter"),
"--session-id",
session_id,
"--account",
account,
"--whep-server-url",
whep_url,
"--w2f-endpoint",
self.w2f_url,
"--f2w-endpoint",
self.f2w_url,
"--config-files",
*config_files,
]
override_config = {}
if huoshan_tts_voice_type is not None:
logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}")
override_config["TTS"] = {
"default_voice_info": {
"voice_type": huoshan_tts_voice_type,
"provider": "huoshan_stream_tts",
}
}
system_prompt = stream_config.get("system_prompt", "")
if system_prompt:
override_config["model"] = {"system_prompt": system_prompt}
logger.info(f"Omni use custom system prompt: {system_prompt}")
with open(config_schema_path, "r") as f:
schema = json.load(f)
jsonschema.validate(instance=override_config, schema=schema)
if override_config is not None:
self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)])
self.chatter_proc = None
self.seg_duration = seg_duration
self.reset_prev = False
self.status = "blank"
self.immediate_switch = 0
self.image_switch = ""
self.action_switch = ""
self.model_runner = model_runner
def launch_chat_server(self):
env = {
"RUST_LOG": "info,duplex_server=debug,backend_5o=debug",
"LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"),
"PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"),
}
self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir)
@staticmethod
def select_and_bind(socket: zmq.Socket) -> str:
        # randomly select a port between 1024 and 65535
retry_count = 20
err = None
while retry_count > 0:
try:
port = random.randint(1024, 65535)
# port = 5555
url = f"tcp://localhost:{port}"
socket.bind(url)
return url
except zmq.error.ZMQError as e:
retry_count -= 1
err = e
raise err
    # immediately switch to the given status: drop the previous chunk and set immediate_switch to 1
def immediate_switch_to(self, status):
logger.warning(f"VA reader immediate switch to {status}")
self.reset_prev = True
self.status = status
self.immediate_switch = 1
# only no action switch can be paused immediately
if self.model_runner is not None and self.model_runner.can_pause:
self.model_runner.pause_signal = True
logger.warning(f"Model runner pause signal set to True")
def set_image_switch(self, image_path):
logger.warning(f"Setting image switch: {image_path}")
self.image_switch = image_path
# only blank status and no action switch can be paused immediately
if self.model_runner is not None and self.model_runner.can_pause:
self.model_runner.pause_signal = True
logger.warning(f"Model runner set pause signal for image switch & blank status")
def set_action_switch(self, prompt):
logger.warning(f"Setting action switch: {prompt}")
self.action_switch = prompt
# only blank status can be paused immediately
if self.model_runner is not None and self.model_runner.can_pause:
self.model_runner.pause_signal = True
logger.warning(f"Model runner set pause signal for action switch & blank status")
def recv_loop(self):
while True:
try:
message = self.w2f_socket.recv()
except Exception:
logger.error(f"Error receiving message: {traceback.format_exc()}")
break
try:
message = BSON.decode(message)
msg_type = message["type"]
logger.debug("Received message type: {}".format(msg_type))
if msg_type == "AgentAudio":
audio = message["audio"]
if audio["type"] != "Pcm":
logger.error("Unsupported audio type: {}".format(audio["type"]))
continue
pcm_data = audio["data"]
audio_info = AudioInfo(audio["info"])
logger.debug("Received audio with duration: {}".format(audio_info.duration()))
if self.audio_info is None:
self.audio_info = audio_info
else:
# check if the audio info is the same
if not self.audio_info.is_spec_equal(audio_info):
raise ValueError("Audio info mismatch")
self.audio_buffer.add(pcm_data)
# if status is blank and has voice, set immediate switch to 1
if self.status == "blank" and self.has_voice(self.seg_duration):
self.immediate_switch_to("voice")
elif msg_type == "AgentStartPlay":
logger.debug("Received AgentStartPlay, create new audio buffer")
self.audio_buffer = ByteBuffer()
elif msg_type == "AgentEndPlay":
logger.debug("Received AgentEndPlay, mark audio finished")
self.audio_buffer.mark_finished()
elif msg_type == "ClearAgentAudio":
logger.warning("Received ClearAgentAudio, clear audio buffer")
self.audio_buffer = None
self.audio_info = None
if self.status == "voice":
self.status = "blank"
# self.immediate_switch_to("blank")
except Exception as e:
logger.error("Error decoding message: {}, continue".format(e))
continue
logger.warning("recv loop interrupted")
def start(self):
self.launch_chat_server()
self.recv_thread = threading.Thread(target=self.recv_loop)
self.recv_thread.start()
    def has_voice(self, duration):
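        """Return the byte count needed for `duration` seconds of S16LE audio,
        or False when no audio is buffered yet, or there are not enough bytes
        and more voice may still arrive."""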
if self.audio_info is None or self.audio_buffer.current_size == 0:
return False
bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2 # S16LE assumed
# if not has enough bytes and maybe has more voice, return False
if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice():
logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}")
return False
return bytes_count
def get_audio(self, fetch_duration) -> (bytes, AudioInfo):
bytes_count = self.has_voice(fetch_duration)
if bytes_count is False or self.audio_info is None:
return None
pcm_data = self.audio_buffer.get(bytes_count)
# the actual sample count fetched
sample_count = len(pcm_data) // (self.audio_info.channel_count * 2)
logger.debug("Fetched {} bytes audio".format(sample_count))
logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size))
audio_info = deepcopy(self.audio_info)
audio_info.sample_count = sample_count
return (pcm_data, audio_info)
def stop(self):
self.model_runner = None
if self.chatter_proc is not None:
self.chatter_proc.terminate()
self.chatter_proc.wait()
self.chatter_proc = None
self.w2f_socket.close()
self.f2w_socket.close()
def __del__(self):
self.stop()
class OmniVAReader:
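    """Audio reader backed by the Omni chat pipeline.

    The target rank pulls agent audio through a ChatAdapter, converts it to
    mono PCM at sample_rate, pads it to segment_duration, and broadcasts the
    segment, its valid duration and switch flags to every rank via
    torch.distributed.
    """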
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0625,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
model_runner=None,
huoshan_tts_voice_type=None,
stream_config: dict = {},
**kwargs,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate)
self.prev_seg_chunk = None
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.valid_duration_tensor = torch.tensor([0], dtype=torch.float32).to(device="cuda")
self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
self.chat_adapter = None
self.model_runner = model_runner
self.huoshan_tts_voice_type = huoshan_tts_voice_type
self.stream_config = stream_config
assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader"
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def init_omni_env(self):
self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/")
self.session_id = os.getenv("OMNI_SESSION_ID", "")
self.account = os.getenv("OMNI_ACCOUNT", "")
self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",")
self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None)
assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist"
assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required"
logger.info(
f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}"
)
def start(self):
if self.rank == self.target_rank:
self.init_omni_env()
assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader"
self.chat_adapter = ChatAdapter(
omni_work_dir=self.omni_work_dir,
whep_url=self.stream_url,
session_id=self.session_id,
account=self.account,
config_files=self.config_files,
config_schema_path=self.config_schema_path,
seg_duration=self.segment_duration,
model_runner=self.model_runner,
huoshan_tts_voice_type=self.huoshan_tts_voice_type,
stream_config=self.stream_config,
)
self.chat_adapter.start()
logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier")
    def broadcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
# logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
# logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
    def broadcast_valid_duration(self, valid_duration):
if self.rank == self.target_rank:
self.valid_duration_tensor.fill_(valid_duration)
dist.broadcast(self.valid_duration_tensor, src=self.target_rank)
return self.valid_duration_tensor.item()
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
# logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info):
audio = np.frombuffer(audio_data, dtype=np.int16)
sample_count = audio_info.sample_count
assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}"
# convert to mono
if audio_info.channel_count > 1:
audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1)
# logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}")
if audio_info.sample_rate != self.sample_rate:
sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate)
audio = resample(audio, sample_count).astype(np.int16)
# logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
return audio, sample_count
def prepare_audio_data(self, chat_audio_result):
sample_count = 0
audio = np.array([], dtype=np.int16)
# convert chat audio result to mono and target sample rate
if chat_audio_result is not None:
audio_data, audio_info = chat_audio_result
audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info)
valid_duration = sample_count / self.sample_rate
# if is not the first segment, concat with previous segment
if self.prev_seg_chunk is not None:
audio = np.concatenate([self.prev_seg_chunk, audio])
sample_count = len(audio)
assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}"
# pad 0 to the audio to make it the same length as all_seg_sample_count
if sample_count < self.all_seg_sample_count:
pad_count = self.all_seg_sample_count - sample_count
# logger.info(f"pad {pad_count} samples to audio")
audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0)
sample_count = len(audio)
# update prev seg chunk
self.prev_seg_chunk = audio[-self.prev_seg_sample_count :]
# logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}")
return audio.tobytes(), valid_duration
def get_fetch_duration(self):
fetch_duration = self.segment_duration
# after immediate switch, reset prev seg chunk
if self.chat_adapter is not None and self.chat_adapter.reset_prev:
self.prev_seg_chunk = None
self.chat_adapter.reset_prev = False
logger.warning(f"Reset prev seg chunk")
# first segment, fetch segment_duration, else fetch segment_duration - prev_duration
if self.prev_seg_chunk is not None:
fetch_duration -= self.prev_duration
return fetch_duration
def change_segment_duration(self, segment_duration):
if segment_duration is None or self.segment_duration == segment_duration:
return
if self.rank == self.target_rank:
logger.warning(f"segment duration changed: {self.segment_duration} -> {segment_duration}")
self.segment_duration = segment_duration
self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
if self.chat_adapter is not None:
self.chat_adapter.seg_duration = segment_duration
def get_audio_segment(self, fetch_duration: float = None, prev_duration: float = None):
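        """Fetch one audio segment for the model: on the target rank pull agent
        audio (or silence when idle) from the ChatAdapter and pad it to
        segment_duration, then broadcast the bytes and valid duration to all
        ranks. Returns (float32 mono ndarray, valid_duration in seconds)."""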
audio_data = None
valid_duration = 0
if prev_duration is not None and self.prev_duration != prev_duration:
raise ValueError(f"prev_duration {prev_duration} != {self.prev_duration}")
self.change_segment_duration(fetch_duration)
if self.rank == self.target_rank:
try:
fetch_duration = self.get_fetch_duration()
# logger.info(f"Get segment, fetch_duration: {fetch_duration}")
if self.chat_adapter.status == "voice":
audio_result = self.chat_adapter.get_audio(fetch_duration)
audio_data, valid_duration = self.prepare_audio_data(audio_result)
                    # all buffered voice has been consumed, switch back to blank naturally
                    if audio_result is None:
                        logger.info("All voice segments consumed, switching back to blank")
self.chat_adapter.status = "blank"
else:
audio_data, valid_duration = self.prepare_audio_data(None)
except Exception as e:
logger.warning(f"Failed to get voice segment: {e}")
return None, 0
if self.world_size > 1:
            audio_data = self.broadcast_audio_data(audio_data)
            valid_duration = self.broadcast_valid_duration(valid_duration)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data, valid_duration
def get_immediate_switch(self):
if self.rank == self.target_rank:
if self.chat_adapter is not None and self.chat_adapter.immediate_switch == 1:
self.immediate_switch_tensor.fill_(1)
# reset immediate switch
self.chat_adapter.immediate_switch = 0
else:
self.immediate_switch_tensor.fill_(0)
if self.world_size > 1:
dist.broadcast(self.immediate_switch_tensor, src=self.target_rank)
return self.immediate_switch_tensor.item()
def get_image_switch(self):
data = "" if self.chat_adapter is None else self.chat_adapter.image_switch
image_switch = self.broadcast_data(data)
# reset image switch
if self.chat_adapter is not None:
self.chat_adapter.image_switch = ""
return image_switch
def get_action_switch(self):
data = "" if self.chat_adapter is None else self.chat_adapter.action_switch
action_switch = self.broadcast_data(data)
# reset action switch
if self.chat_adapter is not None:
self.chat_adapter.action_switch = ""
return action_switch
def broadcast_data(self, data):
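        """Broadcast a JSON-serializable object from the target rank to all
        ranks (length tensor first, then the UTF-8 payload as a CUDA uint8
        tensor); with a single rank the data is returned unchanged."""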
if self.world_size <= 1:
return data
if self.rank == self.target_rank:
val = json.dumps(data, ensure_ascii=False).encode("utf-8")
T = torch.frombuffer(bytearray(val), dtype=torch.uint8).to(device="cuda")
S = torch.tensor([T.shape[0]], dtype=torch.int32).to(device="cuda")
else:
S = torch.zeros(1, dtype=torch.int32, device="cuda")
dist.broadcast(S, src=self.target_rank)
if self.rank != self.target_rank:
T = torch.zeros(S.item(), dtype=torch.uint8, device="cuda")
dist.broadcast(T, src=self.target_rank)
if self.rank != self.target_rank:
val = T.cpu().numpy().tobytes()
data = json.loads(val.decode("utf-8"))
return data
def stop(self):
self.model_runner = None
if self.chat_adapter is not None:
self.chat_adapter.stop()
self.chat_adapter = None
logger.warning("OmniVAReader stopped")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = OmniVAReader(
RANK,
WORLD_SIZE,
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000",
segment_duration=17 / 16,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 100000000
try:
while True:
            audio_data, valid_duration = reader.get_audio_segment()
            if audio_data is not None:
                logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
import math
import os
import queue
import socket
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torchaudio as ta
from loguru import logger
def pseudo_random(a, b):
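    """Map the fractional part of time.time() to an integer in [a, b]."""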
x = str(time.time()).split(".")[1]
y = int(float("0." + x) * 1000000)
return a + (y % (b - a + 1))
class VARecorder:
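    """Push generated video frames and audio to FFmpeg over local TCP sockets.

    FFmpeg muxes the raw rgb24 frames and s16le audio and writes to a local
    file, an RTMP endpoint, or a WHIP (WebRTC) endpoint depending on
    livestream_url; for realtime targets a scheduler thread releases buffered
    segments at the stream rate.
    """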
def __init__(
self,
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
stream_config: dict = {},
):
self.livestream_url = livestream_url
self.stream_config = stream_config
self.fps = fps
self.sample_rate = sample_rate
self.audio_port = pseudo_random(32000, 40000)
self.video_port = self.audio_port + 1
self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
self.width = None
self.height = None
self.stoppable_t = None
self.realtime = False
if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
self.realtime = True
        # ffmpeg process that muxes video and audio data and pushes to the livestream
self.ffmpeg_process = None
# TCP connection objects
self.audio_socket = None
self.video_socket = None
self.audio_conn = None
self.video_conn = None
self.audio_thread = None
self.video_thread = None
# queue for send data to ffmpeg process
self.audio_queue = queue.Queue()
self.video_queue = queue.Queue()
# buffer for stream data
self.audio_samples_per_frame = round(self.sample_rate / self.fps)
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
        assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"
def init_sockets(self):
# TCP socket for send and recv video and audio data
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.video_socket.bind(("127.0.0.1", self.video_port))
self.video_socket.listen(1)
self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.audio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.audio_socket.bind(("127.0.0.1", self.audio_port))
self.audio_socket.listen(1)
def audio_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to audio socket...")
self.audio_conn, _ = self.audio_socket.accept()
logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
while True:
try:
if self.audio_queue is None:
break
data = self.audio_queue.get()
if data is None:
logger.info("Audio thread received stop signal")
break
# Convert audio data to 16-bit integer format
audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16)
try:
self.audio_conn.send(audios[None].cpu().numpy().tobytes())
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Audio connection closed, stopping worker: {type(e).__name__}")
return
fail_time = 0
except (BrokenPipeError, OSError, ConnectionResetError):
logger.info("Audio connection closed during queue processing")
break
except Exception:
logger.error(f"Send audio data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
break
except Exception:
logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Audio push worker thread stopped")
def video_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to video socket...")
self.video_conn, _ = self.video_socket.accept()
logger.info(f"Video connection established from {self.video_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
while True:
try:
if self.video_queue is None:
break
data = self.video_queue.get()
if data is None:
logger.info("Video thread received stop signal")
break
                    # Convert to numpy and scale to [0, 255]; frames are sent to FFmpeg as raw rgb24
for i in range(data.shape[0]):
t0 = time.time()
frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
try:
self.video_conn.send(frame.tobytes())
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
return
if self.realtime and i < data.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except (BrokenPipeError, OSError, ConnectionResetError):
logger.info("Video connection closed during queue processing")
break
except Exception:
logger.error(f"Send video data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
break
except Exception:
logger.error(f"Video push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_local(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-pix_fmt",
"rgb24",
"-color_range",
"pc",
"-colorspace",
"rgb",
"-color_primaries",
"bt709",
"-color_trc",
"iec61966-2-1",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"mp4",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"flv",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_whip(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-re",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-ch_layout",
"mono",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"48000",
"-c:a",
"libopus",
"-ac",
"2",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-threads",
"1",
"-bf",
"0",
"-f",
"whip",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start(self, width: int, height: int):
self.set_video_size(width, height)
duration = 1.0
frames = int(self.fps * duration)
samples = int(self.sample_rate * (frames / self.fps))
tensor = torch.zeros((frames, height, width, 3), dtype=torch.float16)
self.pub_livestream(tensor, torch.zeros(samples, dtype=torch.float16))
time.sleep(duration)
def config_video_padding(self):
pass
def padding_video_frames(self, frames: torch.Tensor):
return frames
def try_init_sockets(self, max_try=10):
for i in range(max_try):
try:
self.init_sockets()
return True
except OSError:
self.audio_port = pseudo_random(32000, 40000)
self.video_port = self.audio_port + 1
logger.warning(f"Failed to initialize sockets {i + 1}/{max_try}: {traceback.format_exc()}")
logger.warning(f"change port to {self.audio_port} and {self.video_port}, retry ...")
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set"
return
self.width = width
self.height = height
self.config_video_padding()
self.try_init_sockets()
if self.livestream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
self.start_ffmpeg_process_local()
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
self.video_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
self.audio_queue.put(audios)
self.video_queue.put(self.padding_video_frames(images))
logger.info(f"Published {N} frames and {M} audio samples")
self.stoppable_t = time.time() + M / self.sample_rate + 3
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor, valid_duration=1e9):
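        """Slice a batch of frames/audio into slice_frame-sized segments.

        Each entry keeps the video frames, the matching audio samples, the
        trailing prev_frame frames of gen_video (returned on truncation), and a
        can_truncate flag (True when the slice lies past valid_duration); the
        entries are appended to the stream buffer for the scheduler thread.
        """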
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
valid_frames = math.ceil(valid_duration * self.fps)
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
can_truncate = valid_frames < end_frame
img = self.padding_video_frames(images[i:end_frame])
aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append([img, aud, gen, can_truncate])
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments, valid_frames: {valid_frames}")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int, check_can_truncate: bool = True):
with self.stream_buffer_lock:
            # scanning backward, find the last segment that cannot be truncated and keep everything up to it
idx = len(self.stream_buffer) - 1
while check_can_truncate and idx >= size and idx >= 0:
if not self.stream_buffer[idx][3]:
logger.warning(f"can not truncate frame: {idx}, trucecate size: {size} -> {idx + 1}")
size = idx + 1
break
idx -= 1
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
# after truncate, set the last segment can not be truncated
self.stream_buffer[-1][3] = False
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
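        """Realtime scheduler loop: roughly every slice_frame / fps seconds pop
        one buffered segment and push its audio and frames to the worker
        queues, keeping the livestream close to real time."""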
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
fail_time = 0
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen, _ = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
fail_time = 0
self.audio_queue.put(aud)
self.video_queue.put(img)
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
else:
fail_time += 1
if fail_time % 10 == 0:
logger.warning(f"No stream buffer to schedule: {fail_time} times")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.audio_queue:
self.audio_queue.put(None)
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish processing queued data (increased timeout)
queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=queue_timeout)
if self.audio_thread.is_alive():
logger.error(f"Audio push thread did not stop after {queue_timeout}s")
if self.video_thread and self.video_thread.is_alive():
self.video_thread.join(timeout=queue_timeout)
if self.video_thread.is_alive():
logger.error(f"Video push thread did not stop after {queue_timeout}s")
# Shutdown connections to signal EOF to FFmpeg
# shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
if self.audio_conn:
try:
self.audio_conn.getpeername()
self.audio_conn.shutdown(socket.SHUT_WR)
logger.info("Audio connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.video_conn:
try:
self.video_conn.getpeername()
self.video_conn.shutdown(socket.SHUT_WR)
logger.info("Video connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.ffmpeg_process:
is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
# Local MP4 files need time to write moov atom and finalize the container
timeout_seconds = 30 if is_local_file else 10
logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
logger.info(f"FFmpeg output: {self.livestream_url}")
try:
returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
if returncode == 0:
logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
else:
logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
except subprocess.TimeoutExpired:
logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
try:
self.ffmpeg_process.terminate() # SIGTERM
returncode = self.ffmpeg_process.wait(timeout=5)
logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
except subprocess.TimeoutExpired:
logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
self.ffmpeg_process.kill()
self.ffmpeg_process.wait() # Wait for kill to complete
logger.error("FFmpeg process killed with SIGKILL")
finally:
self.ffmpeg_process = None
if self.audio_conn:
try:
self.audio_conn.close()
except Exception as e:
logger.debug(f"Error closing audio connection: {e}")
finally:
self.audio_conn = None
if self.video_conn:
try:
self.video_conn.close()
except Exception as e:
logger.debug(f"Error closing video connection: {e}")
finally:
self.video_conn = None
if self.audio_socket:
try:
self.audio_socket.close()
except Exception as e:
logger.debug(f"Error closing audio socket: {e}")
finally:
self.audio_socket = None
if self.video_socket:
try:
self.video_socket.close()
except Exception as e:
logger.debug(f"Error closing video socket: {e}")
finally:
self.video_socket = None
if self.audio_queue:
while self.audio_queue.qsize() > 0:
try:
self.audio_queue.get_nowait()
except: # noqa
break
if self.video_queue:
while self.video_queue.qsize() > 0:
try:
self.video_queue.get_nowait()
except: # noqa
break
self.audio_queue = None
self.video_queue = None
logger.info("VARecorder stopped and resources cleaned up")
def __del__(self):
self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 640
height = 480
recorder = VARecorder(
# livestream_url="rtmp://localhost/live/test",
# livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
livestream_url="/path/to/output_video.mp4",
fps=fps,
sample_rate=sample_rate,
)
audio_path = "/path/to/test_b_2min.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.reshape(-1)
secs = audio_array.shape[0] // sample_rate
interval = 1
for i in range(0, secs, interval):
logger.info(f"{i} / {secs} s")
start = i * sample_rate
end = (i + interval) * sample_rate
cur_audio_array = audio_array[start:end]
logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
recorder.pub_livestream(images, cur_audio_array)
time.sleep(interval)
recorder.stop()
import ctypes
import queue
import threading
import time
import traceback
import numpy as np
import torch
import torchaudio as ta
from loguru import logger
from scipy.signal import resample
class X264VARecorder:
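    """WHIP recorder that bypasses FFmpeg: raw frames and audio are pushed
    through a libx264-based WHIP shared library loaded with ctypes.

    Audio is resampled to 48 kHz stereo s16le and video is sent as uint8 RGB
    frames via pushWhipRawAudioFrame / pushWhipRawVideoFrame.
    """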
def __init__(
self,
whip_shared_path: str,
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
self.width = None
self.height = None
self.stoppable_t = None
# only enable whip shared api for whip http livestream
self.whip_shared_path = whip_shared_path
self.whip_shared_lib = None
self.whip_shared_handle = None
assert livestream_url.startswith("http"), "X264VARecorder only support whip http livestream"
self.realtime = True
# queue for send data to whip shared api
self.queue = queue.Queue()
self.worker_thread = None
# buffer for stream data
self.target_sample_rate = 48000
self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
self.target_chunks_per_frame = self.target_samples_per_frame * 2
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
        assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"
def worker(self):
try:
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
while True:
try:
if self.queue is None:
break
data = self.queue.get()
if data is None:
logger.info("Worker thread received stop signal")
break
audios, images = data
for i in range(images.shape[0]):
t0 = time.time()
cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
cur_video = images[i].flatten()
video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
if self.realtime and i < images.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except: # noqa
logger.error(f"Send audio data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
break
except: # noqa
logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Audio push worker thread stopped")
def start_libx264_whip_shared_api(self, width: int, height: int):
self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path)
# define function argtypes and restype
self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p
self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p]
whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8"))
self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height))
logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}")
def convert_data(self, audios, images):
# Convert audio data to 16-bit integer format
audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
        # Convert to numpy and scale to [0, 255] (uint8 RGB frames)
image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
logger.info(f"image_datas: {image_datas.shape} {image_datas.dtype} {image_datas.min()} {image_datas.max()}")
        resampled_audios = resample(audio_datas, int(len(audio_datas) * 48000 / self.sample_rate))
        stereo_audios = np.stack([resampled_audios, resampled_audios], axis=-1).astype(np.int16).reshape(-1)
return stereo_audios, image_datas
def start(self, width: int, height: int):
self.set_video_size(width, height)
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set"
return
self.width = width
self.height = height
self.start_libx264_whip_shared_api(width, height)
self.worker_thread = threading.Thread(target=self.worker)
self.worker_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor, valid_duration=1e9):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
audio_datas, image_datas = self.convert_data(audios, images)
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = image_datas[i:end_frame]
aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.queue.put((aud, img))
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.queue:
self.queue.put(None)
# Wait for threads to finish
if self.worker_thread and self.worker_thread.is_alive():
self.worker_thread.join(timeout=5)
if self.worker_thread.is_alive():
logger.warning("Worker thread did not stop gracefully")
# Destroy WHIP shared API
if self.whip_shared_lib and self.whip_shared_handle:
self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
self.whip_shared_handle = None
self.whip_shared_lib = None
logger.warning("WHIP shared API destroyed")
def __del__(self):
self.stop()
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 452
height = 352
recorder = X264VARecorder(
whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
fps=fps,
sample_rate=sample_rate,
)
recorder.start(width, height)
# time.sleep(5)
audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.numpy().reshape(-1)
secs = audio_array.shape[0] // sample_rate
interval = 1
space = 10
i = 0
while i < space:
t0 = time.time()
logger.info(f"space {i} / {space} s")
cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
i += interval
time.sleep(interval - (time.time() - t0))
started = True
i = 0
while i < secs:
t0 = time.time()
start = int(i * sample_rate)
end = int((i + interval) * sample_rate)
cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"{i} / {secs} s")
if started:
logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
started = False
recorder.buffer_stream(images, cur_audio_array, images)
i += interval
time.sleep(interval - (time.time() - t0))
recorder.stop()
import os
import queue
import socket
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
from loguru import logger
def pseudo_random(a, b):
x = str(time.time()).split(".")[1]
y = int(float("0." + x) * 1000000)
return a + (y % (b - a + 1))
class VideoRecorder:
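    """Video-only counterpart of VARecorder: pushes raw rgb24 frames to FFmpeg
    over a single TCP socket and writes to a local file, RTMP, or WHIP
    endpoint depending on livestream_url."""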
def __init__(
self,
livestream_url: str,
fps: float = 16.0,
):
self.livestream_url = livestream_url
self.fps = fps
self.video_port = pseudo_random(32000, 40000)
self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
logger.info(f"VideoRecorder video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
self.width = None
self.height = None
self.stoppable_t = None
self.realtime = True
        # ffmpeg process that encodes video data and pushes to the livestream
self.ffmpeg_process = None
# TCP connection objects
self.video_socket = None
self.video_conn = None
self.video_thread = None
# queue for send data to ffmpeg process
self.video_queue = queue.Queue()
def init_sockets(self):
# TCP socket for send and recv video data
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.video_socket.bind(("127.0.0.1", self.video_port))
self.video_socket.listen(1)
def video_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to video socket...")
self.video_conn, _ = self.video_socket.accept()
logger.info(f"Video connection established from {self.video_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
while True:
try:
if self.video_queue is None:
break
data = self.video_queue.get()
if data is None:
logger.info("Video thread received stop signal")
break
                    # Convert to numpy and scale to [0, 255]; frames are sent to FFmpeg as raw rgb24
for i in range(data.shape[0]):
t0 = time.time()
frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
try:
self.video_conn.send(frame.tobytes())
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
return
if self.realtime:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except (BrokenPipeError, OSError, ConnectionResetError):
logger.info("Video connection closed during queue processing")
break
except Exception:
logger.error(f"Send video data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
break
except Exception:
logger.error(f"Video push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_local(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"rawvideo",
"-pix_fmt",
"rgb24",
"-color_range",
"pc",
"-colorspace",
"rgb",
"-color_primaries",
"bt709",
"-color_trc",
"iec61966-2-1",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"mp4",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"flv",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_whip(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-re",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-threads",
"1",
"-bf",
"0",
"-f",
"whip",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start(self, width: int, height: int):
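"""Fix the output resolution (this opens the TCP socket, spawns ffmpeg and starts the worker thread) and prime the stream with one second of black frames so ffmpeg starts receiving data immediately."""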
self.set_video_size(width, height)
duration = 1.0
self.pub_video(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16))
time.sleep(duration)
def set_video_size(self, width: int, height: int):
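"""Lock in the frame size on first call: bind the TCP socket, launch the ffmpeg variant matching the URL scheme (rtmp://, http(s) WHIP, or local file) and start the worker thread. Later calls must pass the same size."""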
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, f"Video size already set to {self.width}x{self.height}, got {width}x{height}"
return
self.width = width
self.height = height
self.init_sockets()
if self.livestream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
self.start_ffmpeg_process_local()
self.realtime = False
self.video_thread = threading.Thread(target=self.video_worker)
self.video_thread.start()
# Publish ComfyUI Image tensor to livestream
def pub_video(self, images: torch.Tensor):
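"""Queue a batch of frames for streaming; images must be a float tensor of shape [N, H, W, 3] with values expected in [0, 1] (they are scaled to uint8 and clamped in the worker)."""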
N, height, width, C = images.shape
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}]")
self.set_video_size(width, height)
self.video_queue.put(images)
logger.info(f"Published {N} frames")
self.stoppable_t = time.time() + N / self.fps + 3
def stop(self, wait=True):
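"""Stop streaming: optionally wait until queued frames should have been sent, signal the worker thread, shut down the socket so ffmpeg sees EOF, wait for ffmpeg to finalize the output, then close sockets and drain the queue."""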
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
# Send stop signals to queues
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish processing queued data (increased timeout)
queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames
if self.video_thread and self.video_thread.is_alive():
self.video_thread.join(timeout=queue_timeout)
if self.video_thread.is_alive():
logger.error(f"Video push thread did not stop after {queue_timeout}s")
# Shutdown connections to signal EOF to FFmpeg
# shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
if self.video_conn:
try:
self.video_conn.getpeername()
self.video_conn.shutdown(socket.SHUT_WR)
logger.info("Video connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.ffmpeg_process:
is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
# Local MP4 files need time to write moov atom and finalize the container
timeout_seconds = 30 if is_local_file else 10
logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
logger.info(f"FFmpeg output: {self.livestream_url}")
try:
returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
if returncode == 0:
logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
else:
logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
except subprocess.TimeoutExpired:
logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
try:
self.ffmpeg_process.terminate() # SIGTERM
returncode = self.ffmpeg_process.wait(timeout=5)
logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
except subprocess.TimeoutExpired:
logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
self.ffmpeg_process.kill()
self.ffmpeg_process.wait() # Wait for kill to complete
logger.error("FFmpeg process killed with SIGKILL")
finally:
self.ffmpeg_process = None
if self.video_conn:
try:
self.video_conn.close()
except Exception as e:
logger.debug(f"Error closing video connection: {e}")
finally:
self.video_conn = None
if self.video_socket:
try:
self.video_socket.close()
except Exception as e:
logger.debug(f"Error closing video socket: {e}")
finally:
self.video_socket = None
if self.video_queue:
while self.video_queue.qsize() > 0:
try:
self.video_queue.get_nowait()
except Exception: # queue already empty or closed
break
self.video_queue = None
logger.info("VideoRecorder stopped and resources cleaned up")
def __del__(self):
self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
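"""Generate a scrolling color-bar test clip: a float32 tensor of shape [frames, height, width, 3] with values in [0, 1]."""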
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
[1.0, 0.0, 0.0], # red
[0.0, 1.0, 0.0], # green
[0.0, 0.0, 1.0], # blue
[1.0, 1.0, 0.0], # yellow
[1.0, 0.0, 1.0], # magenta
[0.0, 1.0, 1.0], # cyan
[1.0, 1.0, 1.0], # white
[0.5, 0.5, 0.5], # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
fps = 16
width = 640
height = 480
recorder = VideoRecorder(
# livestream_url="rtmp://localhost/live/test",
# livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
livestream_url="/path/to/output_video.mp4",
fps=fps,
)
secs = 10 # 10-second test video
interval = 1
for i in range(0, secs, interval):
logger.info(f"{i} / {secs} s")
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
recorder.pub_video(images)
time.sleep(interval)
recorder.stop()
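# A streaming variant of the demo above, as a sketch only: it reuses the
# commented-out RTMP URL from the recorder setup and assumes an RTMP server
# is actually listening there.
#
#   recorder = VideoRecorder(livestream_url="rtmp://localhost/live/test", fps=fps)
#   recorder.pub_video(create_simple_video(fps * 5, height, width))  # 5 seconds of color bars
#   recorder.stop()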