"README_ORIGIN.md" did not exist on "823b62ed9478049c5e977101f373eaf4e60ac98c"
Unverified commit f7cdbcb5 authored by LiangLiu, committed by GitHub

multi-person & animate & podcast (#554)



- New service features (frontend + backend):
1. seko-talk model now supports multi-person input
2. Podcast synthesis and management
3. Support for the wan2.2 animate model

- New backend APIs:
1. Volcano Engine-based podcast WebSocket synthesis endpoint that streams audio while synthesis is still running
2. Podcast query and management endpoints
3. YOLO-based multi-person face detection endpoint
4. Multi-speaker audio separation endpoint (items 3-4 are sketched below)

- Invasive changes to the inference code:
1. Moved the animate-related input file paths (mask/image/pose, etc.) out of the hard-coded config and into the per-request input_info
2. Wrapped the animate preprocessing code into an interface for use by the service layer

@xinyiqin

---------
Co-authored-by: qinxinyi <qxy118045534@163.com>
parent 61dd69ca
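As a rough illustration only, the two new helper modules added in this commit (face detection and speaker separation, shown in full in the diffs below) could be combined for multi-person preprocessing roughly as follows; file paths are placeholders and the actual service wiring lives in the API layer, which is not shown here.

# Illustrative sketch; import paths assume the modules are importable as-is
from face_detector import FaceDetector
from audio_separator import AudioSeparator

detector = FaceDetector()  # default YOLO11n weights
faces = detector.detect_faces("group_photo.png")["faces"]
print(f"found {len(faces)} person boxes")

separator = AudioSeparator()  # pyannote speaker diarization
result = separator.separate_and_save("conversation.wav", "./separated_audio")
for spk in result["speakers"]:
    print(spk["speaker_id"], spk["output_path"])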
......@@ -94,7 +94,7 @@
}
},
"s2v": {
"seko_talk": {
"SekoTalk": {
"single_stage": {
"pipeline": {
"inputs": ["input_image", "input_audio"],
......@@ -125,12 +125,24 @@
}
}
}
},
"animate": {
"wan2.2_animate": {
"single_stage": {
"pipeline": {
"inputs": ["input_image","input_video"],
"outputs": ["output_video"]
}
}
}
}
},
"meta": {
"special_types": {
"input_image": "IMAGE",
"input_audio": "AUDIO",
"input_video": "VIDEO",
"latents": "TENSOR",
"output_video": "VIDEO"
},
......
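For reference, a hedged sketch of what the inputs of a request to the new wan2.2_animate pipeline might look like, given the inputs/outputs declared above and the {"type", "data"} envelope handled by preload_data further down; the outer fields that select model and stage are not part of this diff, and the values are placeholders.

# Hypothetical request inputs for the new animate pipeline
inputs = {
    "input_image": {"type": "base64", "data": "<base64-encoded PNG>"},
    "input_video": {"type": "base64", "data": "<base64-encoded MP4>"},
}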
......@@ -12,9 +12,6 @@
"sample_guide_scale": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"src_pose_path": "../save_results/animate/process_results/src_pose.mp4",
"src_face_path": "../save_results/animate/process_results/src_face.mp4",
"src_ref_images": "../save_results/animate/process_results/src_ref.png",
"refert_num": 1,
"replace_flag": false,
"fps": 30
......
......@@ -13,9 +13,6 @@
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "phase",
"src_pose_path": "../save_results/animate/process_results/src_pose.mp4",
"src_face_path": "../save_results/animate/process_results/src_face.mp4",
"src_ref_images": "../save_results/animate/process_results/src_ref.png",
"refert_num": 1,
"replace_flag": false,
"fps": 30,
......
......@@ -12,9 +12,6 @@
"sample_guide_scale": 1.0,
"enable_cfg": false,
"cpu_offload": false,
"src_pose_path": "../save_results/animate/process_results/src_pose.mp4",
"src_face_path": "../save_results/animate/process_results/src_face.mp4",
"src_ref_images": "../save_results/animate/process_results/src_ref.png",
"refert_num": 1,
"replace_flag": false,
"fps": 30,
......
......@@ -12,11 +12,6 @@
"sample_guide_scale": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"src_pose_path": "../save_results/replace/process_results/src_pose.mp4",
"src_face_path": "../save_results/replace/process_results/src_face.mp4",
"src_ref_images": "../save_results/replace/process_results/src_ref.png",
"src_bg_path": "../save_results/replace/process_results/src_bg.mp4",
"src_mask_path": "../save_results/replace/process_results/src_mask.mp4",
"refert_num": 1,
"fps": 30,
"replace_flag": true
......
......@@ -13,11 +13,6 @@
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "phase",
"src_pose_path": "../save_results/replace/process_results/src_pose.mp4",
"src_face_path": "../save_results/replace/process_results/src_face.mp4",
"src_ref_images": "../save_results/replace/process_results/src_ref.png",
"src_bg_path": "../save_results/replace/process_results/src_bg.mp4",
"src_mask_path": "../save_results/replace/process_results/src_mask.mp4",
"refert_num": 1,
"fps": 30,
"replace_flag": true,
......
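Per the commit description, the src_* paths removed from the configs above are no longer hard-coded; they are supplied per request via input_info. A minimal sketch, assuming the field names simply mirror the removed config keys (the exact structure of input_info is not shown in this diff):

# Hypothetical per-request input_info carrying the paths that used to be fixed
# in the config files above; replace-mode requests would also carry
# src_bg_path and src_mask_path.
input_info = {
    "src_pose_path": "/tmp/task_123/process_results/src_pose.mp4",
    "src_face_path": "/tmp/task_123/process_results/src_face.mp4",
    "src_ref_images": "/tmp/task_123/process_results/src_ref.png",
}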
# -*- coding: utf-8 -*-
"""
Audio Source Separation Module
Separates different voice tracks in audio, supports multi-person audio separation
"""
import base64
import io
import os
import tempfile
import traceback
from collections import defaultdict
from typing import Dict, Optional, Union
import torch
import torchaudio
from loguru import logger
# Import pyannote.audio for speaker diarization
from pyannote.audio import Audio, Pipeline
class AudioSeparator:
"""
Audio separator for separating different voice tracks in audio using pyannote.audio
Supports multi-person conversation separation, maintains duration (other speakers' tracks are empty)
"""
def __init__(
self,
model_path: str = None,
device: str = None,
sample_rate: int = 16000,
):
"""
Initialize audio separator
Args:
model_path: Model path (if using custom model), default uses pyannote/speaker-diarization-community-1
device: Device ('cpu', 'cuda', etc.), None for auto selection
sample_rate: Target sample rate, default 16000
"""
self.sample_rate = sample_rate
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self._init_pyannote(model_path)
def _init_pyannote(self, model_path: str = None):
"""Initialize pyannote.audio pipeline"""
try:
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
model_name = model_path or "pyannote/speaker-diarization-community-1"
try:
# Try loading with token if available
if huggingface_token:
self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
else:
# Try without token (may work for public models)
self.pipeline = Pipeline.from_pretrained(model_name)
except Exception as e:
if "gated" in str(e).lower() or "token" in str(e).lower():
raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
raise RuntimeError(f"Failed to load pyannote model: {e}")
# Move pipeline to specified device
if self.device:
self.pipeline.to(torch.device(self.device))
# Initialize Audio helper for waveform loading
self.pyannote_audio = Audio()
logger.info("Initialized pyannote.audio speaker diarization pipeline")
except Exception as e:
logger.error(f"Failed to initialize pyannote: {e}")
raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")
def separate_speakers(
self,
audio_path: Union[str, bytes],
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""
Separate different speakers in audio
Args:
audio_path: Audio file path or bytes data
num_speakers: Specified number of speakers, None for auto detection
min_speakers: Minimum number of speakers
max_speakers: Maximum number of speakers
Returns:
Dict containing:
- speakers: List of speaker audio segments, each containing:
- speaker_id: Speaker ID (0, 1, 2, ...)
- audio: torch.Tensor audio data [channels, samples]
- segments: List of (start_time, end_time) tuples
- sample_rate: Sample rate
"""
try:
# Load audio
if isinstance(audio_path, bytes):
# Try to infer the audio format from the byte data
# Check for WAV format (RIFF header)
is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
# Check for MP3 format (ID3 tag or MPEG frame header)
is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"
# Choose the temporary file suffix based on the detected format
if is_wav:
suffix = ".wav"
elif is_mp3:
suffix = ".mp3"
else:
# Default to WAV; loading will fail with an error if the format is actually unsupported
suffix = ".wav"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
tmp_file.write(audio_path)
tmp_audio_path = tmp_file.name
try:
result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
finally:
# Ensure the temporary file is removed
try:
os.unlink(tmp_audio_path)
except Exception as e:
logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
return result
else:
return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)
except Exception as e:
logger.error(f"Speaker separation failed: {traceback.format_exc()}")
raise RuntimeError(f"Audio separation error: {e}")
def _separate_speakers_internal(
self,
audio_path: str,
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""Internal method: execute speaker separation"""
# Load audio
waveform, original_sr = torchaudio.load(audio_path)
if original_sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Ensure waveform is float32 and normalized (pyannote expects this format)
if waveform.dtype != torch.float32:
waveform = waveform.float()
# Ensure waveform is in range [-1, 1] (normalize if needed)
if waveform.abs().max() > 1.0:
waveform = waveform / waveform.abs().max()
if self.pipeline is None:
raise RuntimeError("Pyannote pipeline not initialized")
return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)
def _separate_with_pyannote(
self,
audio_path: str,
waveform: torch.Tensor,
num_speakers: Optional[int],
min_speakers: int,
max_speakers: int,
) -> Dict:
"""Use pyannote.audio for speaker diarization"""
try:
# Use waveform dict to avoid AudioDecoder dependency issues
# Pipeline can accept either file path or waveform dict
# Using waveform dict is more reliable when torchcodec is not properly installed
audio_input = {
"waveform": waveform,
"sample_rate": self.sample_rate,
}
# Run speaker diarization
output = self.pipeline(
audio_input,
min_speakers=min_speakers if num_speakers is None else num_speakers,
max_speakers=max_speakers if num_speakers is None else num_speakers,
)
# Extract audio segments for each speaker
speakers_dict = defaultdict(list)
for turn, speaker in output.speaker_diarization:
print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
start_time = turn.start
end_time = turn.end
start_sample = int(start_time * self.sample_rate)
end_sample = int(end_time * self.sample_rate)
# Extract audio segment for this time period
segment_audio = waveform[:, start_sample:end_sample]
speakers_dict[speaker].append((start_time, end_time, segment_audio))
# Generate complete audio for each speaker (other speakers' segments are empty)
speakers = []
audio_duration = waveform.shape[1] / self.sample_rate
num_samples = waveform.shape[1]
for speaker_id, segments in speakers_dict.items():
# Create zero-filled audio
speaker_audio = torch.zeros_like(waveform)
# Fill in this speaker's segments
for start_time, end_time, segment_audio in segments:
start_sample = int(start_time * self.sample_rate)
end_sample = int(end_time * self.sample_rate)
# Ensure no out-of-bounds
end_sample = min(end_sample, num_samples)
segment_len = end_sample - start_sample
if segment_len > 0 and segment_audio.shape[1] > 0:
actual_len = min(segment_len, segment_audio.shape[1])
speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len]
speakers.append(
{
"speaker_id": speaker_id,
"audio": speaker_audio,
"segments": [(s[0], s[1]) for s in segments],
"sample_rate": self.sample_rate,
}
)
logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
return {"speakers": speakers, "method": "pyannote"}
except Exception as e:
logger.error(f"Pyannote separation failed: {e}")
raise RuntimeError(f"Audio separation failed: {e}")
def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None):
"""
Save speaker audio to file
Args:
speaker_audio: Audio tensor [channels, samples]
output_path: Output path
sample_rate: Sample rate, if None uses self.sample_rate
"""
sr = sample_rate if sample_rate else self.sample_rate
torchaudio.save(output_path, speaker_audio, sr)
logger.info(f"Saved speaker audio to {output_path}")
def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str:
"""
Convert speaker audio tensor to base64 encoded string without saving to file
Args:
speaker_audio: Audio tensor [channels, samples]
sample_rate: Sample rate, if None uses self.sample_rate
format: Audio format (default: "wav")
Returns:
Base64 encoded audio string
"""
sr = sample_rate if sample_rate else self.sample_rate
# Use BytesIO to save audio to memory
buffer = io.BytesIO()
torchaudio.save(buffer, speaker_audio, sr, format=format)
# Get the audio bytes
audio_bytes = buffer.getvalue()
# Encode to base64
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
return audio_base64
def separate_and_save(
self,
audio_path: Union[str, bytes],
output_dir: str,
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""
Separate audio and save to files
Args:
audio_path: Input audio path or bytes data
output_dir: Output directory
num_speakers: Specified number of speakers
min_speakers: Minimum number of speakers
max_speakers: Maximum number of speakers
Returns:
Separation result dictionary, containing output file paths
"""
os.makedirs(output_dir, exist_ok=True)
result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)
output_paths = []
for speaker in result["speakers"]:
speaker_id = speaker["speaker_id"]
output_path = os.path.join(output_dir, f"{speaker_id}.wav")
self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
output_paths.append(output_path)
speaker["output_path"] = output_path
result["output_paths"] = output_paths
return result
def separate_audio_tracks(
audio_path: str,
output_dir: str = None,
num_speakers: int = None,
model_path: str = None,
) -> Dict:
"""
Convenience function: separate different audio tracks
Args:
audio_path: Audio file path
output_dir: Output directory, if None does not save files
num_speakers: Number of speakers
model_path: Model path (optional)
Returns:
Separation result dictionary
"""
separator = AudioSeparator(model_path=model_path)
if output_dir:
return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
else:
return separator.separate_speakers(audio_path, num_speakers=num_speakers)
if __name__ == "__main__":
# Test code
import sys
if len(sys.argv) < 2:
print("Usage: python audio_separator.py <audio_path> [output_dir] [num_speakers]")
sys.exit(1)
audio_path = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None
separator = AudioSeparator()
result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
print(f"Separated audio into {len(result['speakers'])} speakers:")
for speaker in result["speakers"]:
print(f" Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
if "output_path" in speaker:
print(f" Saved to: {speaker['output_path']}")
# -*- coding: utf-8 -*-
"""
Face Detection Module using YOLO
Supports detecting faces in images, including human faces, animal faces, anime faces, sketches, etc.
"""
import io
import traceback
from typing import Dict, List, Union
import numpy as np
from PIL import Image, ImageDraw
from loguru import logger
from ultralytics import YOLO
class FaceDetector:
"""
Face detection using YOLO models
Supports detecting: human faces, animal faces, anime faces, sketch faces, etc.
"""
def __init__(self, model_path: str = None, conf_threshold: float = 0.25, device: str = None):
"""
Initialize face detector
Args:
model_path: YOLO model path, if None uses default pretrained model
conf_threshold: Confidence threshold, default 0.25
device: Device ('cpu', 'cuda', '0', '1', etc.), None for auto selection
"""
self.conf_threshold = conf_threshold
self.device = device
if model_path is None:
# Use the YOLO11 pretrained model, which detects COCO classes (including person),
# or supply a dedicated face detection model via model_path
logger.info("Loading default YOLO11n model for face detection")
try:
self.model = YOLO("yolo11n.pt") # Lightweight model
except Exception as e:
logger.warning(f"Failed to load default model, trying yolov8n: {e}")
self.model = YOLO("yolov8n.pt")
else:
logger.info(f"Loading YOLO model from {model_path}")
self.model = YOLO(model_path)
# The person class ID in the COCO dataset is 0.
# YOLO detects whole persons; for more precise face detection, use a dedicated
# face detection model such as YOLOv8-face or RetinaFace via the model_path parameter,
# or detect the person region first and then run a face detector within it.
self.target_classes = {
"person": 0, # Face (by detecting person class)
# Can be extended to detect animal faces (cat, dog, etc.) and other classes
}
def detect_faces(
self,
image: Union[str, Image.Image, bytes, np.ndarray],
return_image: bool = False,
) -> Dict:
"""
Detect faces in image
Args:
image: Input image, can be path, PIL Image, bytes or numpy array
return_image: Whether to return an annotated image with detection boxes
Returns:
Dict containing:
- faces: List of face detection results, each containing:
- bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates)
- confidence: Confidence score (0.0-1.0)
- class_id: Class ID
- class_name: Class name
- image (optional): PIL Image with detection boxes drawn (if return_image=True)
"""
try:
# Load image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, np.ndarray):
img = Image.fromarray(image).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
# Use YOLO for detection
# Note: the default YOLO model detects the person class; for more precise
# face detection, train or load a dedicated face detection model.
results = self.model.predict(
source=img,
conf=self.conf_threshold,
device=self.device,
verbose=False,
)
faces = []
annotated_img = img.copy() if return_image else None
if len(results) > 0:
result = results[0]
boxes = result.boxes
if boxes is not None and len(boxes) > 0:
for i in range(len(boxes)):
# Get bounding box coordinates (xyxy format)
bbox = boxes.xyxy[i].cpu().numpy().tolist()
confidence = float(boxes.conf[i].cpu().numpy())
class_id = int(boxes.cls[i].cpu().numpy())
# Get class name
class_name = result.names.get(class_id, "unknown")
# Process target classes (person, etc.)
# For the person class, the full-body box contains the face region.
# For more precise face detection you can:
# 1. Use a dedicated face detection model (RetinaFace, YOLOv8-face)
# 2. Run a face detector within the current person box
# 3. Use a multi-class model trained for faces, animal faces, anime faces, etc.
if class_id in self.target_classes.values():
face_info = {
"bbox": bbox, # [x1, y1, x2, y2] - absolute pixel coordinates
"confidence": confidence,
"class_id": class_id,
"class_name": class_name,
}
faces.append(face_info)
# Draw annotations on image if needed
if return_image and annotated_img is not None:
draw = ImageDraw.Draw(annotated_img)
x1, y1, x2, y2 = bbox
# Draw bounding box
draw.rectangle(
[x1, y1, x2, y2],
outline="red",
width=2,
)
# Draw label
label = f"{class_name} {confidence:.2f}"
draw.text((x1, y1 - 15), label, fill="red")
result_dict = {"faces": faces}
if return_image and annotated_img is not None:
result_dict["image"] = annotated_img
logger.info(f"Detected {len(faces)} faces in image")
return result_dict
except Exception as e:
logger.error(f"Face detection failed: {traceback.format_exc()}")
raise RuntimeError(f"Face detection error: {e}")
def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict:
"""
Detect faces from byte data
Args:
image_bytes: Image byte data
**kwargs: Additional parameters passed to detect_faces
Returns:
Detection result dictionary
"""
return self.detect_faces(image_bytes, **kwargs)
def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]:
"""
Extract detected face regions
Args:
image: Input image
expand_ratio: Bounding box expansion ratio to include more context
Returns:
List of extracted face region images
"""
result = self.detect_faces(image)
faces = result["faces"]
# Load original image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
face_regions = []
img_width, img_height = img.size
for face in faces:
x1, y1, x2, y2 = face["bbox"]
# Expand bounding box
width = x2 - x1
height = y2 - y1
expand_x = width * expand_ratio
expand_y = height * expand_ratio
x1 = max(0, int(x1 - expand_x))
y1 = max(0, int(y1 - expand_y))
x2 = min(img_width, int(x2 + expand_x))
y2 = min(img_height, int(y2 + expand_y))
# Crop region
face_region = img.crop((x1, y1, x2, y2))
face_regions.append(face_region)
return face_regions
def count_faces(self, image: Union[str, Image.Image, bytes]) -> int:
"""
Count number of faces in image
Args:
image: Input image
Returns:
Number of detected faces
"""
result = self.detect_faces(image, return_image=False)
return len(result["faces"])
def detect_faces_in_image(
image_path: str,
model_path: str = None,
conf_threshold: float = 0.25,
return_image: bool = False,
) -> Dict:
"""
Convenience function: detect faces in image
Args:
image_path: Image path
model_path: YOLO model path
conf_threshold: Confidence threshold
return_image: Whether to return annotated image
Returns:
Detection result dictionary containing:
- faces: List of face detection results with bbox coordinates [x1, y1, x2, y2]
- image (optional): Annotated image with detection boxes
"""
detector = FaceDetector(model_path=model_path, conf_threshold=conf_threshold)
return detector.detect_faces(image_path, return_image=return_image)
if __name__ == "__main__":
# Test code
import sys
if len(sys.argv) < 2:
print("Usage: python face_detector.py <image_path>")
sys.exit(1)
image_path = sys.argv[1]
detector = FaceDetector()
result = detector.detect_faces(image_path, return_image=True)
print(f"Detected {len(result['faces'])} faces:")
for i, face in enumerate(result["faces"]):
print(f" Face {i + 1}: {face}")
output_path = "detected_faces.png"
result["image"].save(output_path)
print(f"Annotated image saved to: {output_path}")
This diff is collapsed.
......@@ -69,6 +69,8 @@ def class_try_catch_async(func):
def data_name(x, task_id):
if x == "input_image":
x = x + ".png"
elif x == "input_video":
x = x + ".mp4"
elif x == "output_video":
x = x + ".mp4"
return f"{task_id}-{x}"
......@@ -165,7 +167,14 @@ async def preload_data(inp, inp_type, typ, val):
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
data = await fetch_resource(val, timeout=timeout)
elif typ == "base64":
data = base64.b64decode(val)
# Decode base64 in background thread to avoid blocking event loop
data = await asyncio.to_thread(base64.b64decode, val)
# For multi-person audio directory, val should be a dict with file structure
elif typ == "directory":
data = {}
for fname, b64_data in val.items():
data[fname] = await asyncio.to_thread(base64.b64decode, b64_data)
return {"type": "directory", "data": data}
elif typ == "stream":
# no bytes data need to be saved by data_manager
data = None
......@@ -176,8 +185,13 @@ async def preload_data(inp, inp_type, typ, val):
if inp_type == "IMAGE":
data = await asyncio.to_thread(format_image_data, data)
elif inp_type == "AUDIO":
if typ != "stream":
if typ != "stream" and typ != "directory":
data = await asyncio.to_thread(format_audio_data, data)
elif inp_type == "VIDEO":
# Video data doesn't need special formatting, just validate it's not empty
if len(data) == 0:
raise ValueError("Video file is empty")
logger.info(f"load video: {len(data)} bytes")
else:
raise Exception(f"cannot parse inp_type={inp_type} data")
return data
......@@ -191,7 +205,15 @@ async def load_inputs(params, raw_inputs, types):
for inp in raw_inputs:
item = params.pop(inp)
bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
if bytes_data is not None:
# Handle multi-person audio directory
if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory":
fs = []
for fname, fdata in bytes_data["data"].items():
inputs_data[f"{inp}/{fname}"] = fdata
fs.append(f"{inp}/{fname}")
params["extra_inputs"] = {inp: fs}
elif bytes_data is not None:
inputs_data[inp] = bytes_data
else:
params[inp] = item
......@@ -202,11 +224,15 @@ def check_params(params, raw_inputs, raw_outputs, types):
stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
for x in raw_inputs + raw_outputs:
if x in params and "type" in params[x] and params[x]["type"] == "stream":
if x in params and "type" in params[x]:
if params[x]["type"] == "stream":
if types[x] == "AUDIO":
assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
elif types[x] == "VIDEO":
assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
elif params[x]["type"] == "directory":
# Multi-person audio directory is only supported for AUDIO type
assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}"
if __name__ == "__main__":
......
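A hedged sketch of the new "directory" input type handled in the diff above, used for multi-person audio: the data field maps filenames to base64-encoded audio, and each file is stored as "<input_name>/<filename>" and listed under params["extra_inputs"]. Filenames here are placeholders.

# Hypothetical multi-person audio input using the new "directory" type
import base64

def b64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()

input_audio = {
    "type": "directory",
    "data": {
        "speaker_0.wav": b64("speaker_0.wav"),
        "speaker_1.wav": b64("speaker_1.wav"),
    },
}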
......@@ -22,8 +22,8 @@ class VolcEngineTTSClient:
def __init__(self, voices_list_file=None):
self.url = "https://openspeech.bytedance.com/api/v3/tts/unidirectional"
self.appid = os.getenv("VOLCENGINE_APPID")
self.access_token = os.getenv("VOLCENGINE_ACCESS_TOKEN")
self.appid = os.getenv("VOLCENGINE_TTS_APPID")
self.access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN")
self.proxy = os.getenv("HTTPS_PROXY", None)
if self.proxy:
logger.info(f"volcengine tts use proxy: {self.proxy}")
......
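The TTS client now reads TTS-specific environment variables; a minimal reminder of the rename, with placeholder values.

# The old VOLCENGINE_APPID / VOLCENGINE_ACCESS_TOKEN names are no longer consulted here
import os
os.environ["VOLCENGINE_TTS_APPID"] = "your_app_id"          # placeholder
os.environ["VOLCENGINE_TTS_ACCESS_TOKEN"] = "your_access_token"  # placeholder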
......@@ -14,6 +14,8 @@ class BaseDataManager:
self.template_audios_dir = None
self.template_videos_dir = None
self.template_tasks_dir = None
self.podcast_temp_session_dir = None
self.podcast_output_dir = None
async def init(self):
pass
......@@ -188,7 +190,8 @@ class BaseDataManager:
template_dir = self.get_template_dir(template_type)
if template_dir is None:
return None
return await self.save_bytes(bytes_data, None, abs_path=os.path.join(template_dir, filename))
abs_path = os.path.join(template_dir, filename)
return await self.save_bytes(bytes_data, None, abs_path=abs_path)
@class_try_catch_async
async def presign_template_url(self, template_type, filename):
......@@ -197,6 +200,46 @@ class BaseDataManager:
return None
return await self.presign_url(None, abs_path=os.path.join(template_dir, filename))
@class_try_catch_async
async def list_podcast_temp_session_files(self, session_id):
session_dir = os.path.join(self.podcast_temp_session_dir, session_id)
return await self.list_files(base_dir=session_dir)
@class_try_catch_async
async def save_podcast_temp_session_file(self, session_id, filename, bytes_data):
fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename)
await self.save_bytes(bytes_data, None, abs_path=fpath)
@class_try_catch_async
async def load_podcast_temp_session_file(self, session_id, filename):
fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename)
return await self.load_bytes(None, abs_path=fpath)
@class_try_catch_async
async def delete_podcast_temp_session_file(self, session_id, filename):
fpath = os.path.join(self.podcast_temp_session_dir, session_id, filename)
return await self.delete_bytes(None, abs_path=fpath)
@class_try_catch_async
async def save_podcast_output_file(self, filename, bytes_data):
fpath = os.path.join(self.podcast_output_dir, filename)
await self.save_bytes(bytes_data, None, abs_path=fpath)
@class_try_catch_async
async def load_podcast_output_file(self, filename):
fpath = os.path.join(self.podcast_output_dir, filename)
return await self.load_bytes(None, abs_path=fpath)
@class_try_catch_async
async def delete_podcast_output_file(self, filename):
fpath = os.path.join(self.podcast_output_dir, filename)
return await self.delete_bytes(None, abs_path=fpath)
@class_try_catch_async
async def presign_podcast_output_url(self, filename):
fpath = os.path.join(self.podcast_output_dir, filename)
return await self.presign_url(None, abs_path=fpath)
# Import data manager implementations
from .local_data_manager import LocalDataManager # noqa
......
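A brief, illustrative sequence using the new podcast helpers on a data manager instance; dm, the filename, audio_bytes, and session_id are placeholders, and the calls run inside an async handler.

# Illustrative only: persist a finished podcast, hand back a presigned URL,
# then clean up the temporary session files created during synthesis.
await dm.save_podcast_output_file("episode_001.wav", audio_bytes)
url = await dm.presign_podcast_output_url("episode_001.wav")
await dm.clear_podcast_temp_session_dir(session_id)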
import asyncio
import os
import shutil
from loguru import logger
......@@ -24,6 +25,12 @@ class LocalDataManager(BaseDataManager):
assert os.path.exists(self.template_videos_dir), f"{self.template_videos_dir} not exists!"
assert os.path.exists(self.template_tasks_dir), f"{self.template_tasks_dir} not exists!"
# podcast temp session dir and output dir
self.podcast_temp_session_dir = os.path.join(self.local_dir, "podcast_temp_session")
self.podcast_output_dir = os.path.join(self.local_dir, "podcast_output")
os.makedirs(self.podcast_temp_session_dir, exist_ok=True)
os.makedirs(self.podcast_output_dir, exist_ok=True)
@class_try_catch_async
async def save_bytes(self, bytes_data, filename, abs_path=None):
out_path = self.fmt_path(self.local_dir, filename, abs_path)
......@@ -54,6 +61,20 @@ class LocalDataManager(BaseDataManager):
prefix = base_dir if base_dir else self.local_dir
return os.listdir(prefix)
@class_try_catch_async
async def create_podcast_temp_session_dir(self, session_id):
dir_path = os.path.join(self.podcast_temp_session_dir, session_id)
os.makedirs(dir_path, exist_ok=True)
return dir_path
@class_try_catch_async
async def clear_podcast_temp_session_dir(self, session_id):
session_dir = os.path.join(self.podcast_temp_session_dir, session_id)
if os.path.isdir(session_dir):
shutil.rmtree(session_dir)
logger.info(f"cleared podcast temp session dir {session_dir}")
return True
async def test():
import torch
......
......@@ -38,6 +38,10 @@ class S3DataManager(BaseDataManager):
self.template_videos_dir = os.path.join(template_dir, "videos")
self.template_tasks_dir = os.path.join(template_dir, "tasks")
# podcast temp session dir and output dir
self.podcast_temp_session_dir = os.path.join(self.base_path, "podcast_temp_session")
self.podcast_output_dir = os.path.join(self.base_path, "podcast_output")
async def init_presign_client(self):
# init tos client for volces.com
if "volces.com" in self.endpoint_url:
......@@ -128,12 +132,42 @@ class S3DataManager(BaseDataManager):
@class_try_catch_async
async def list_files(self, base_dir=None):
prefix = base_dir if base_dir else self.base_path
response = await self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
if base_dir:
prefix = self.fmt_path(self.base_path, None, abs_path=base_dir)
else:
prefix = self.base_path
prefix = prefix + "/" if not prefix.endswith("/") else prefix
# Handle pagination for S3 list_objects_v2
files = []
continuation_token = None
page = 1
while True:
list_kwargs = {"Bucket": self.bucket_name, "Prefix": prefix, "MaxKeys": 1000}
if continuation_token:
list_kwargs["ContinuationToken"] = continuation_token
response = await self.s3_client.list_objects_v2(**list_kwargs)
if "Contents" in response:
page_files = []
for obj in response["Contents"]:
files.append(obj["Key"].replace(prefix + "/", ""))
# Remove the prefix from the key to get just the filename
key = obj["Key"]
if key.startswith(prefix):
filename = key[len(prefix) :]
if filename: # Skip empty filenames (the directory itself)
page_files.append(filename)
files.extend(page_files)
else:
logger.warning(f"[S3DataManager.list_files] Page {page}: No files found in this page.")
# Check if there are more pages
if response.get("IsTruncated", False):
continuation_token = response.get("NextContinuationToken")
page += 1
else:
break
return files
@class_try_catch_async
......@@ -149,6 +183,18 @@ class S3DataManager(BaseDataManager):
else:
return None
@class_try_catch_async
async def create_podcast_temp_session_dir(self, session_id):
pass
@class_try_catch_async
async def clear_podcast_temp_session_dir(self, session_id):
session_dir = os.path.join(self.podcast_temp_session_dir, session_id)
fs = await self.list_files(base_dir=session_dir)
logger.info(f"clear podcast temp session dir {session_dir} with files: {fs}")
for f in fs:
await self.delete_bytes(f, abs_path=os.path.join(session_dir, f))
async def test():
import torch
......
This diff is collapsed.
......@@ -186,7 +186,7 @@ class AuthManager:
try:
payload = jwt.decode(token, secret_key, algorithms=[self.jwt_algorithm])
token_type = payload.get("token_type")
if token_type != expected_type:
if token_type and token_type != expected_type:
raise HTTPException(status_code=401, detail="Token type mismatch")
return payload
except jwt.ExpiredSignatureError:
......
......@@ -46,12 +46,19 @@
<link rel="dns-prefetch" href="https://cdnjs.cloudflare.com">
<link rel="preload" href="/src/style.css" as="style">
<link rel="preload" href="/src/main.js" as="script" type="module">
<link href="https://cdn.bootcdn.net/ajax/libs/font-awesome/6.4.0/css/all.min.css" rel="stylesheet">
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css" rel="stylesheet" media="print" onload="this.media='all'">
<link href="https://cdn.bootcdn.net/ajax/libs/font-awesome/7.0.1/css/all.min.css" rel="stylesheet">
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/7.0.1/css/all.min.css" rel="stylesheet" media="print" onload="this.media='all'">
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-solid-rounded/css/uicons-solid-rounded.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-bold-rounded/css/uicons-bold-rounded.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-bold-straight/css/uicons-bold-straight.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-solid-rounded/css/uicons-solid-rounded.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-regular-rounded/css/uicons-regular-rounded.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-thin-rounded/css/uicons-thin-rounded.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-solid-straight/css/uicons-solid-straight.css'>
<link rel='stylesheet' href='https://cdn-uicons.flaticon.com/3.0.0/uicons-solid-chubby/css/uicons-solid-chubby.css'>
<link href="https://cdn.jsdelivr.net/npm/remixicon@3.5.0/fonts/remixicon.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap-icons@1.11.0/font/bootstrap-icons.css">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@tabler/icons-webfont@latest/tabler-icons.min.css">
<link href="/src/style.css" rel="stylesheet">
<style>
.seo-shell {
......
<script setup>
import { onMounted, onUnmounted, ref } from 'vue'
import router from './router'
import { init, handleLoginCallback, handleClickOutside, validateToken } from './utils/other'
import { initLanguage } from './utils/i18n'
......@@ -76,7 +77,7 @@ onMounted(async () => {
localStorage.removeItem('currentUser')
isLoggedIn.value = false
console.log('Token已过期')
showAlert('请重新登录', 'warning', {
showAlert(t('pleaseRelogin'), 'warning', {
label: t('login'),
onClick: login
})
......@@ -87,7 +88,7 @@ onMounted(async () => {
}
} catch (error) {
console.error('初始化失败', error)
showAlert('初始化失败,请刷新页面重试', 'danger')
showAlert(t('initFailedPleaseRefresh'), 'danger')
isLoggedIn.value = false
} finally {
loginLoading.value = false
......
......@@ -220,7 +220,7 @@ onMounted(() => {
<!-- Video preview -->
<video v-if="item?.outputs?.output_video"
:src="getTemplateFileUrl(item.outputs.output_video,'videos')"
:poster="getTemplateFileUrl(item.inputs.input_image,'images')"
:poster="item?.inputs?.input_image ? getTemplateFileUrl(item.inputs.input_image,'images') : undefined"
class="w-full h-auto object-contain group-hover:scale-[1.02] transition-transform duration-200"
preload="auto" playsinline webkit-playsinline
@mouseenter="playVideo($event)" @mouseleave="pauseVideo($event)"
......@@ -228,7 +228,7 @@ onMounted(() => {
@ended="onVideoEnded($event)"
@error="onVideoError($event)"></video>
<!-- Image thumbnail -->
<img v-else
<img v-else-if="item?.inputs?.input_image"
:src="getTemplateFileUrl(item.inputs.input_image,'images')"
:alt="item.params?.prompt || '模板图片'"
class="w-full h-auto object-contain group-hover:scale-[1.02] transition-transform duration-200"
......