"...source/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "dfc3b85ed20b048486db697a703cc542835802a1"
Commit f05a99da authored by gaclove
Browse files

feat: add audio utilities and download functionality

parent a3bc0044
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple
from loguru import logger
def is_base64_audio(data: str) -> bool:
    """Return True when ``data`` looks like base64-encoded audio.

    A string is accepted when it is an ``audio/*`` data URL, or when it is
    strictly valid base64 whose decoded prefix carries a known audio
    signature (MP3, Ogg, WAV, FLAC, or an MP4/M4A ``ftyp`` box).

    Args:
        data: Candidate string, e.g. a data URL, raw base64, or a file path.

    Returns:
        True if the payload is recognised as audio, False otherwise
        (including when the string is not valid base64).
    """
    if data.startswith("data:audio/"):
        return True

    # Padded base64 always has a length that is a multiple of four; anything
    # else is rejected without attempting to decode it.
    if len(data) % 4 != 0:
        return False

    try:
        # Strict pass over the whole string rejects non-alphabet characters.
        base64.b64decode(data, validate=True)
        # 100 characters decode to 75 bytes, enough for every signature below.
        decoded = base64.b64decode(data[:100])
    except Exception as e:
        logger.warning(f"Error checking base64 audio: {e}")
        return False

    # ID3 tag / MPEG audio frame sync, Ogg page, FLAC stream marker.
    if decoded.startswith((b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS", b"fLaC")):
        return True
    if decoded.startswith(b"RIFF") and b"WAVE" in decoded[:12]:
        return True
    # MP4/M4A files start with a 4-byte box size followed by "ftyp", so the
    # marker lives at offset 4 regardless of the box size.
    if decoded[4:8] == b"ftyp":
        return True
    # Prefixes accepted by earlier versions, kept for backward compatibility.
    return decoded[:4] in (b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18")
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
    """Split an audio data URL into its base64 payload and MIME subtype.

    Args:
        data: Either an ``audio/*`` data URL such as
            ``data:audio/x-wav;base64,<payload>`` (optional parameters like
            ``;codecs=opus`` are allowed) or a plain base64 string.

    Returns:
        ``(base64_data, format)``. ``format`` is the MIME subtype taken from
        the data URL (e.g. ``"wav"``, ``"x-wav"``). When ``data`` is not an
        audio data URL it is returned unchanged with ``format`` set to None.
    """
    if data.startswith("data:"):
        # Subtypes may contain "-", "." and "+" (x-wav, vnd.wave), and
        # parameters may precede ";base64". DOTALL keeps line-wrapped
        # payloads intact instead of truncating at the first newline.
        match = re.match(
            r"data:audio/([\w.+-]+)(?:;[^;,]+)*;base64,(.+)",
            data,
            re.DOTALL,
        )
        if match:
            return match.group(2), match.group(1)
    return data, None
def save_base64_audio(base64_data: str, output_dir: str) -> str:
    """Decode base64 audio, write it under ``output_dir`` and return the path.

    Args:
        base64_data: An ``audio/*`` data URL or a plain base64 string.
        output_dir: Directory to write into; created if it does not exist.

    Returns:
        Path of the written file, named ``<uuid4>.<ext>``. The extension
        comes from the data URL's MIME subtype when present, otherwise from
        the decoded bytes' signature, falling back to ``mp3``.

    Raises:
        ValueError: If the payload is not valid base64.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    data, format_type = extract_base64_data(base64_data)
    file_id = str(uuid.uuid4())

    try:
        # Line-wrapped base64 is legitimate, so whitespace is removed first.
        # Strict validation then rejects anything else instead of silently
        # dropping stray characters and writing a corrupted file.
        audio_data = base64.b64decode("".join(data.split()), validate=True)
    except Exception as e:
        raise ValueError(f"Invalid base64 data: {e}") from e

    if format_type:
        # MIME subtypes are not always usable file extensions
        # (audio/mpeg is MP3, audio/x-wav is WAV, ...).
        subtype_to_ext = {
            "mpeg": "mp3",
            "x-mpeg": "mp3",
            "x-mp3": "mp3",
            "wave": "wav",
            "x-wav": "wav",
            "vnd.wave": "wav",
            "mp4": "m4a",
            "x-m4a": "m4a",
            "x-flac": "flac",
        }
        subtype = format_type.lower()
        ext = subtype_to_ext.get(subtype, subtype)
    elif audio_data.startswith((b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2")):
        ext = "mp3"
    elif audio_data.startswith(b"OggS"):
        ext = "ogg"
    elif audio_data.startswith(b"RIFF") and b"WAVE" in audio_data[:12]:
        ext = "wav"
    elif audio_data.startswith(b"fLaC"):
        ext = "flac"
    elif audio_data[4:8] == b"ftyp" or audio_data[:4] in (b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"):
        # MP4/M4A: "ftyp" follows the 4-byte box size; the older prefix
        # checks are kept for backward compatibility.
        ext = "m4a"
    else:
        ext = "mp3"

    file_path = os.path.join(output_dir, f"{file_id}.{ext}")
    with open(file_path, "wb") as f:
        f.write(audio_data)

    return file_path
...@@ -12,6 +12,7 @@ from loguru import logger ...@@ -12,6 +12,7 @@ from loguru import logger
from ..infer import init_runner from ..infer import init_runner
from ..utils.set_config import set_config from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .config import server_config from .config import server_config
from .distributed_utils import create_distributed_worker from .distributed_utils import create_distributed_worker
from .image_utils import is_base64_image, save_base64_image from .image_utils import is_base64_image, save_base64_image
...@@ -134,6 +135,43 @@ class FileService: ...@@ -134,6 +135,43 @@ class FileService:
logger.error(f"Unexpected error downloading image from {image_url}: {str(e)}") logger.error(f"Unexpected error downloading image from {image_url}: {str(e)}")
raise ValueError(f"Failed to download image from {image_url}: {str(e)}") raise ValueError(f"Failed to download image from {image_url}: {str(e)}")
async def download_audio(self, audio_url: str) -> Path:
    """Download an audio file into ``self.input_audio_dir``.

    The request itself is delegated to ``self._download_with_retry``. The
    file is named after the last segment of the URL path, or
    ``<uuid4>.mp3`` when the URL has no usable file name.

    Args:
        audio_url: Absolute URL (scheme and host required).

    Returns:
        Path of the saved file.

    Raises:
        ValueError: If the URL is malformed or the download/save fails;
            the underlying exception is attached as ``__cause__``.
    """
    try:
        parsed_url = urlparse(audio_url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL format: {audio_url}")

        response = await self._download_with_retry(audio_url)

        audio_name = Path(parsed_url.path).name
        # A path ending in ".." would resolve to the parent directory rather
        # than a file, so it is treated as "no name".
        if not audio_name or audio_name in (".", ".."):
            audio_name = f"{uuid.uuid4()}.mp3"

        audio_path = self.input_audio_dir / audio_name
        audio_path.parent.mkdir(parents=True, exist_ok=True)

        with open(audio_path, "wb") as f:
            f.write(response.content)

        logger.info(f"Successfully downloaded audio from {audio_url} to {audio_path}")
        return audio_path

    except httpx.ConnectError as e:
        logger.error(f"Connection error downloading audio from {audio_url}: {str(e)}")
        raise ValueError(f"Failed to connect to {audio_url}: {str(e)}") from e
    except httpx.TimeoutException as e:
        logger.error(f"Timeout downloading audio from {audio_url}: {str(e)}")
        raise ValueError(f"Download timeout for {audio_url}: {str(e)}") from e
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error downloading audio from {audio_url}: {str(e)}")
        raise ValueError(f"HTTP error for {audio_url}: {str(e)}") from e
    except ValueError:
        # Already in the shape callers expect (e.g. invalid URL); pass through.
        raise
    except Exception as e:
        logger.error(f"Unexpected error downloading audio from {audio_url}: {str(e)}")
        raise ValueError(f"Failed to download audio from {audio_url}: {str(e)}") from e
def save_uploaded_file(self, file_content: bytes, filename: str) -> Path: def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
file_extension = Path(filename).suffix file_extension = Path(filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}" unique_filename = f"{uuid.uuid4()}{file_extension}"
...@@ -453,6 +491,16 @@ class VideoGenerationService: ...@@ -453,6 +491,16 @@ class VideoGenerationService:
else: else:
task_data["image_path"] = message.image_path task_data["image_path"] = message.image_path
if "audio_path" in message.model_fields_set and message.audio_path:
if message.audio_path.startswith("http"):
audio_path = await self.file_service.download_audio(message.audio_path)
task_data["audio_path"] = str(audio_path)
elif is_base64_audio(message.audio_path):
audio_path = save_base64_audio(message.audio_path, str(self.file_service.input_audio_dir))
task_data["audio_path"] = str(audio_path)
else:
task_data["audio_path"] = message.audio_path
actual_save_path = self.file_service.get_output_path(message.save_video_path) actual_save_path = self.file_service.get_output_path(message.save_video_path)
task_data["save_video_path"] = str(actual_save_path) task_data["save_video_path"] = str(actual_save_path)
task_data["video_path"] = message.save_video_path task_data["video_path"] = message.save_video_path
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment