Unverified Commit dea872a2 authored by PengGao, committed by GitHub

Api image (#515)

parent 1892a3db
from typing import Dict
from .base import MediaHandler
class ImageHandler(MediaHandler):
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_media_signatures(self) -> Dict[bytes, str]:
return {
b"\x89PNG\r\n\x1a\n": "png",
b"\xff\xd8\xff": "jpg",
b"GIF87a": "gif",
b"GIF89a": "gif",
}
def get_data_url_prefix(self) -> str:
return "data:image/"
def get_data_url_pattern(self) -> str:
return r"data:image/(\w+);base64,(.+)"
def get_default_extension(self) -> str:
return "png"
def is_base64(self, data: str) -> bool:
if data.startswith(self.get_data_url_prefix()):
return True
try:
import base64
# when the length is a valid base64 multiple of 4, validate the whole string strictly (raises on invalid input)
if len(data) % 4 == 0:
base64.b64decode(data, validate=True)
# decode only a short prefix and sniff known magic bytes instead of decoding the whole payload
decoded = base64.b64decode(data[:100])
for signature in self.get_media_signatures().keys():
if decoded.startswith(signature):
return True
# WEBP files start with a RIFF header; the "WEBP" FourCC sits at byte offset 8
if len(decoded) > 12 and decoded[8:12] == b"WEBP":
return True
except Exception:
return False
return False
def detect_extension(self, data: bytes) -> str:
for signature, ext in self.get_media_signatures().items():
if data.startswith(signature):
return ext
if len(data) > 12 and data[8:12] == b"WEBP":
return "webp"
return self.get_default_extension()
_handler = ImageHandler()
def is_base64_image(data: str) -> bool:
return _handler.is_base64(data)
def save_base64_image(base64_data: str, output_dir: str) -> str:
return _handler.save_base64(base64_data, output_dir)
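# A minimal usage sketch of the module-level helpers above (the first payload is synthetic:
# the PNG signature plus padding bytes, base64-encoded, so only the magic-byte path is exercised):
#   import base64
#   fake_png = base64.b64encode(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8).decode()
#   is_base64_image(fake_png)                          # True: decoded bytes start with the PNG signature
#   is_base64_image("data:image/png;base64,iVBORw0=")  # True: the data-URL prefix short-circuits
#   is_base64_image("plainly not an image")            # False: strict validation fails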
#!/usr/bin/env python
"""Example script to run the LightX2V server."""
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from lightx2v.server.main import run_server
def main():
parser = argparse.ArgumentParser(description="Run LightX2V inference server")
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
args = parser.parse_args()
run_server(args)
if __name__ == "__main__":
main()
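# A possible invocation of this example script, assuming it is saved as run_server.py
# (paths are placeholders; the model class and config mirror the qwen_image scripts later in this commit):
#   python run_server.py --model_path /path/to/model --model_cls qwen_image \
#       --task t2i --config_json configs/qwen_image/qwen_image_t2i.json --port 8000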
import random
from typing import Optional
from pydantic import BaseModel, Field
@@ -5,35 +6,55 @@ from pydantic import BaseModel, Field
from ..utils.generate_task_id import generate_task_id
def generate_random_seed() -> int:
return random.randint(0, 2**32 - 1)
class TalkObject(BaseModel):
audio: str = Field(..., description="Audio path")
mask: str = Field(..., description="Mask path")
class TaskRequest(BaseModel):
class BaseTaskRequest(BaseModel):
task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
prompt: str = Field("", description="Generation prompt")
use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
negative_prompt: str = Field("", description="Negative prompt")
image_path: str = Field("", description="Base64 encoded image or URL")
num_fragments: int = Field(1, description="Number of fragments")
save_result_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
save_result_path: str = Field("", description="Save result path (optional, defaults to task_id, suffix auto-detected)")
infer_steps: int = Field(5, description="Inference steps")
target_video_length: int = Field(81, description="Target video length")
seed: int = Field(42, description="Random seed")
audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
seed: int = Field(default_factory=generate_random_seed, description="Random seed (auto-generated if not set)")
def __init__(self, **data):
super().__init__(**data)
if not self.save_result_path:
self.save_result_path = f"{self.task_id}.mp4"
self.save_result_path = f"{self.task_id}"
def get(self, key, default=None):
return getattr(self, key, default)
class VideoTaskRequest(BaseTaskRequest):
num_fragments: int = Field(1, description="Number of fragments")
target_video_length: int = Field(81, description="Target video length")
audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
class ImageTaskRequest(BaseTaskRequest):
aspect_ratio: str = Field("16:9", description="Output aspect ratio")
class TaskRequest(BaseTaskRequest):
num_fragments: int = Field(1, description="Number of fragments")
target_video_length: int = Field(81, description="Target video length (video only)")
audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)")
class TaskStatusMessage(BaseModel):
task_id: str = Field(..., description="Task ID")
......
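# A brief illustration of the split request models above (values are placeholders; anything
# not passed falls back to the defaults declared in BaseTaskRequest):
#   video_req = VideoTaskRequest(prompt="a cat surfing", target_video_length=81)
#   image_req = ImageTaskRequest(prompt="a mountain at dawn", aspect_ratio="16:9")
# Both get an auto-generated task_id and seed, and save_result_path defaults to the task_id,
# with the file suffix filled in later by the generation service.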
import asyncio
import json
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse
import httpx
import torch
from easydict import EasyDict
from loguru import logger
from ..infer import init_runner
from ..utils.input_info import set_input_info
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .distributed_utils import DistributedManager
from .image_utils import is_base64_image, save_base64_image
from .schema import TaskRequest, TaskResponse
class FileService:
def __init__(self, cache_dir: Path):
self.cache_dir = cache_dir
self.input_image_dir = cache_dir / "inputs" / "imgs"
self.input_audio_dir = cache_dir / "inputs" / "audios"
self.output_video_dir = cache_dir / "outputs"
self._http_client = None
self._client_lock = asyncio.Lock()
self.max_retries = 3
self.retry_delay = 1.0
self.max_retry_delay = 10.0
for directory in [
self.input_image_dir,
self.output_video_dir,
self.input_audio_dir,
]:
directory.mkdir(parents=True, exist_ok=True)
async def _get_http_client(self) -> httpx.AsyncClient:
"""Get or create a persistent HTTP client with connection pooling."""
async with self._client_lock:
if self._http_client is None or self._http_client.is_closed:
timeout = httpx.Timeout(
connect=10.0,
read=30.0,
write=10.0,
pool=5.0,
)
limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
return self._http_client
async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
"""Download with exponential backoff retry logic."""
if max_retries is None:
max_retries = self.max_retries
last_exception = None
retry_delay = self.retry_delay
for attempt in range(max_retries):
try:
client = await self._get_http_client()
response = await client.get(url)
if response.status_code == 200:
return response
elif response.status_code >= 500:
logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
else:
raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)
except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
last_exception = e
except httpx.HTTPStatusError as e:
if e.response and e.response.status_code < 500:
raise
last_exception = e
except Exception as e:
logger.error(f"Unexpected error downloading {url}: {str(e)}")
last_exception = e
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
retry_delay = min(retry_delay * 2, self.max_retry_delay)
error_msg = f"All {max_retries} connection attempts failed for {url}"
if last_exception:
error_msg += f": {str(last_exception)}"
raise httpx.ConnectError(error_msg)
async def download_image(self, image_url: str) -> Path:
"""Download image with retry logic and proper error handling."""
try:
parsed_url = urlparse(image_url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL format: {image_url}")
response = await self._download_with_retry(image_url)
image_name = Path(parsed_url.path).name
if not image_name:
image_name = f"{uuid.uuid4()}.jpg"
image_path = self.input_image_dir / image_name
image_path.parent.mkdir(parents=True, exist_ok=True)
with open(image_path, "wb") as f:
f.write(response.content)
logger.info(f"Successfully downloaded image from {image_url} to {image_path}")
return image_path
except httpx.ConnectError as e:
logger.error(f"Connection error downloading image from {image_url}: {str(e)}")
raise ValueError(f"Failed to connect to {image_url}: {str(e)}")
except httpx.TimeoutException as e:
logger.error(f"Timeout downloading image from {image_url}: {str(e)}")
raise ValueError(f"Download timeout for {image_url}: {str(e)}")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error downloading image from {image_url}: {str(e)}")
raise ValueError(f"HTTP error for {image_url}: {str(e)}")
except ValueError as e:
raise
except Exception as e:
logger.error(f"Unexpected error downloading image from {image_url}: {str(e)}")
raise ValueError(f"Failed to download image from {image_url}: {str(e)}")
async def download_audio(self, audio_url: str) -> Path:
"""Download audio with retry logic and proper error handling."""
try:
parsed_url = urlparse(audio_url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL format: {audio_url}")
response = await self._download_with_retry(audio_url)
audio_name = Path(parsed_url.path).name
if not audio_name:
audio_name = f"{uuid.uuid4()}.mp3"
audio_path = self.input_audio_dir / audio_name
audio_path.parent.mkdir(parents=True, exist_ok=True)
with open(audio_path, "wb") as f:
f.write(response.content)
logger.info(f"Successfully downloaded audio from {audio_url} to {audio_path}")
return audio_path
except httpx.ConnectError as e:
logger.error(f"Connection error downloading audio from {audio_url}: {str(e)}")
raise ValueError(f"Failed to connect to {audio_url}: {str(e)}")
except httpx.TimeoutException as e:
logger.error(f"Timeout downloading audio from {audio_url}: {str(e)}")
raise ValueError(f"Download timeout for {audio_url}: {str(e)}")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error downloading audio from {audio_url}: {str(e)}")
raise ValueError(f"HTTP error for {audio_url}: {str(e)}")
except ValueError as e:
raise
except Exception as e:
logger.error(f"Unexpected error downloading audio from {audio_url}: {str(e)}")
raise ValueError(f"Failed to download audio from {audio_url}: {str(e)}")
def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
file_extension = Path(filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = self.input_image_dir / unique_filename
with open(file_path, "wb") as f:
f.write(file_content)
return file_path
def get_output_path(self, save_result_path: str) -> Path:
video_path = Path(save_result_path)
if not video_path.is_absolute():
return self.output_video_dir / save_result_path
return video_path
async def cleanup(self):
"""Cleanup resources including HTTP client."""
async with self._client_lock:
if self._http_client and not self._http_client.is_closed:
await self._http_client.aclose()
self._http_client = None
class TorchrunInferenceWorker:
"""Worker class for torchrun-based distributed inference"""
def __init__(self):
self.rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.runner = None
self.dist_manager = DistributedManager()
self.processing = False # Track if currently processing a request
def init(self, args) -> bool:
"""Initialize the worker with model and distributed setup"""
try:
# Initialize distributed process group using torchrun env vars
if self.world_size > 1:
if not self.dist_manager.init_process_group():
raise RuntimeError("Failed to initialize distributed process group")
else:
# Single GPU mode
self.dist_manager.rank = 0
self.dist_manager.world_size = 1
self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.dist_manager.is_initialized = False
# Initialize model
config = set_config(args)
if self.rank == 0:
logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")
self.runner = init_runner(config)
logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
return True
except Exception as e:
logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
return False
async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single inference request
Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
The async wrapper allows FastAPI to handle other requests while this runs.
"""
try:
# Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
if self.world_size > 1 and self.rank == 0:
task_data = self.dist_manager.broadcast_task_data(task_data)
# Run inference directly - torchrun handles the parallelization
# Using asyncio.to_thread would be risky with NCCL operations
# Instead, we rely on FastAPI's async handling and queue management
task_data["task"] = self.runner.config["task"]
task_data["return_result_tensor"] = False
task_data["negative_prompt"] = task_data.get("negative_prompt", "")
# must be converted to an EasyDict so fields can be accessed as attributes
task_data = EasyDict(task_data)
input_info = set_input_info(task_data)
# update the runner config for this task
self.runner.set_config(task_data)
# print("input_info==>", input_info)
self.runner.run_pipeline(input_info)
# Small yield to allow other async operations if needed
await asyncio.sleep(0)
# Synchronize all ranks
if self.world_size > 1:
self.dist_manager.barrier()
# Only rank 0 returns the result
if self.rank == 0:
return {
"task_id": task_data["task_id"],
"status": "success",
"save_result_path": task_data.get("video_path", task_data["save_result_path"]),
"message": "Inference completed",
}
else:
return None
except Exception as e:
logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
if self.world_size > 1:
self.dist_manager.barrier()
if self.rank == 0:
return {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Inference failed: {str(e)}",
}
else:
return None
async def worker_loop(self):
"""Non-rank-0 workers: Listen for broadcast tasks"""
while True:
try:
task_data = self.dist_manager.broadcast_task_data()
if task_data is None:
logger.info(f"Rank {self.rank} received stop signal")
break
await self.process_request(task_data)
except Exception as e:
logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
continue
def cleanup(self):
self.dist_manager.cleanup()
class DistributedInferenceService:
def __init__(self):
self.worker = None
self.is_running = False
self.args = None
def start_distributed_inference(self, args) -> bool:
self.args = args
if self.is_running:
logger.warning("Distributed inference service is already running")
return True
try:
self.worker = TorchrunInferenceWorker()
if not self.worker.init(args):
raise RuntimeError("Worker initialization failed")
self.is_running = True
logger.info(f"Rank {self.worker.rank} inference service started successfully")
return True
except Exception as e:
logger.error(f"Error starting inference service: {str(e)}")
self.stop_distributed_inference()
return False
def stop_distributed_inference(self):
if not self.is_running:
return
try:
if self.worker:
self.worker.cleanup()
logger.info("Inference service stopped")
except Exception as e:
logger.error(f"Error stopping inference service: {str(e)}")
finally:
self.worker = None
self.is_running = False
async def submit_task_async(self, task_data: dict) -> Optional[dict]:
if not self.is_running or not self.worker:
logger.error("Inference service is not started")
return None
if self.worker.rank != 0:
return None
try:
if self.worker.processing:
# If we want to support queueing, we could add the task to a queue here
# For now, tasks are processed sequentially
logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
self.worker.processing = True
result = await self.worker.process_request(task_data)
self.worker.processing = False
return result
except Exception as e:
self.worker.processing = False
logger.error(f"Failed to process task: {str(e)}")
return {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Task processing failed: {str(e)}",
}
def server_metadata(self):
assert self.args is not None, "Distributed inference service has not been started. Call start_distributed_inference() first."
return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
async def run_worker_loop(self):
"""Run the worker loop for non-rank-0 processes"""
if self.worker and self.worker.rank != 0:
await self.worker.worker_loop()
class VideoGenerationService:
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
self.file_service = file_service
self.inference_service = inference_service
async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
"""Generate video using torchrun-based inference"""
try:
task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
task_data["task_id"] = message.task_id
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled before processing")
return None
if "image_path" in message.model_fields_set and message.image_path:
if message.image_path.startswith("http"):
image_path = await self.file_service.download_image(message.image_path)
task_data["image_path"] = str(image_path)
elif is_base64_image(message.image_path):
image_path = save_base64_image(message.image_path, str(self.file_service.input_image_dir))
task_data["image_path"] = str(image_path)
else:
task_data["image_path"] = message.image_path
logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")
if "audio_path" in message.model_fields_set and message.audio_path:
if message.audio_path.startswith("http"):
audio_path = await self.file_service.download_audio(message.audio_path)
task_data["audio_path"] = str(audio_path)
elif is_base64_audio(message.audio_path):
audio_path = save_base64_audio(message.audio_path, str(self.file_service.input_audio_dir))
task_data["audio_path"] = str(audio_path)
else:
task_data["audio_path"] = message.audio_path
logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")
if "talk_objects" in message.model_fields_set and message.talk_objects:
task_data["talk_objects"] = [{} for _ in range(len(message.talk_objects))]
for index, talk_object in enumerate(message.talk_objects):
if talk_object.audio.startswith("http"):
audio_path = await self.file_service.download_audio(talk_object.audio)
task_data["talk_objects"][index]["audio"] = str(audio_path)
elif is_base64_audio(talk_object.audio):
audio_path = save_base64_audio(talk_object.audio, str(self.file_service.input_audio_dir))
task_data["talk_objects"][index]["audio"] = str(audio_path)
else:
task_data["talk_objects"][index]["audio"] = talk_object.audio
if talk_object.mask.startswith("http"):
mask_path = await self.file_service.download_image(talk_object.mask)
task_data["talk_objects"][index]["mask"] = str(mask_path)
elif is_base64_image(talk_object.mask):
mask_path = save_base64_image(talk_object.mask, str(self.file_service.input_image_dir))
task_data["talk_objects"][index]["mask"] = str(mask_path)
else:
task_data["talk_objects"][index]["mask"] = talk_object.mask
# FIXME(xxx): store the talk_objects as a config.json, then assign the path of that config.json to task_data["audio_path"]
temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
temp_path.mkdir(parents=True, exist_ok=True)
task_data["audio_path"] = str(temp_path)
config_path = temp_path / "config.json"
with open(config_path, "w") as f:
json.dump({"talk_objects": task_data["talk_objects"]}, f)
actual_save_path = self.file_service.get_output_path(message.save_result_path)
task_data["save_result_path"] = str(actual_save_path)
task_data["video_path"] = message.save_result_path
result = await self.inference_service.submit_task_async(task_data)
if result is None:
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled during processing")
return None
raise RuntimeError("Task processing failed")
if result.get("status") == "success":
return TaskResponse(
task_id=message.task_id,
task_status="completed",
save_result_path=message.save_result_path, # Return original path
)
else:
error_msg = result.get("error", "Inference failed")
raise RuntimeError(error_msg)
except Exception as e:
logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
raise
from .file_service import FileService
from .generation import ImageGenerationService, VideoGenerationService
from .inference import DistributedInferenceService, TorchrunInferenceWorker
__all__ = [
"FileService",
"DistributedInferenceService",
"TorchrunInferenceWorker",
"VideoGenerationService",
"ImageGenerationService",
]
@@ -17,20 +17,15 @@ class DistributedManager:
CHUNK_SIZE = 1024 * 1024
def init_process_group(self) -> bool:
"""Initialize process group using torchrun environment variables"""
try:
# torchrun sets these environment variables automatically
self.rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
if self.world_size > 1:
# torchrun provides rank, world_size, and the rendezvous address via environment variables,
# so init_process_group only needs the backend and init_method="env://"
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method="env://")
logger.info(f"Setup backend: {backend}")
# Set CUDA device for this rank
if torch.cuda.is_available():
torch.cuda.set_device(self.rank)
self.device = f"cuda:{self.rank}"
......
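# For reference, torchrun exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT
# for every process it spawns, which is what the code above reads back. A single-node,
# two-GPU launch might look like this (illustrative command, not taken from this repo's scripts):
#   torchrun --nproc_per_node=2 -m lightx2v.server --model_cls qwen_image --task t2i \
#       --model_path /path/to/model --port 8000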
import asyncio
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import httpx
from loguru import logger
class FileService:
def __init__(self, cache_dir: Path):
self.cache_dir = cache_dir
self.input_image_dir = cache_dir / "inputs" / "imgs"
self.input_audio_dir = cache_dir / "inputs" / "audios"
self.output_video_dir = cache_dir / "outputs"
self._http_client = None
self._client_lock = asyncio.Lock()
self.max_retries = 3
self.retry_delay = 1.0
self.max_retry_delay = 10.0
for directory in [
self.input_image_dir,
self.output_video_dir,
self.input_audio_dir,
]:
directory.mkdir(parents=True, exist_ok=True)
async def _get_http_client(self) -> httpx.AsyncClient:
async with self._client_lock:
if self._http_client is None or self._http_client.is_closed:
timeout = httpx.Timeout(
connect=10.0,
read=30.0,
write=10.0,
pool=5.0,
)
limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
return self._http_client
async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
if max_retries is None:
max_retries = self.max_retries
last_exception = None
retry_delay = self.retry_delay
for attempt in range(max_retries):
try:
client = await self._get_http_client()
response = await client.get(url)
if response.status_code == 200:
return response
elif response.status_code >= 500:
logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
else:
raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)
except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
last_exception = e
except httpx.HTTPStatusError as e:
if e.response and e.response.status_code < 500:
raise
last_exception = e
except Exception as e:
logger.error(f"Unexpected error downloading {url}: {str(e)}")
last_exception = e
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
retry_delay = min(retry_delay * 2, self.max_retry_delay)
error_msg = f"All {max_retries} connection attempts failed for {url}"
if last_exception:
error_msg += f": {str(last_exception)}"
raise httpx.ConnectError(error_msg)
async def download_media(self, url: str, media_type: str = "image") -> Path:
try:
parsed_url = urlparse(url)
if not parsed_url.scheme or not parsed_url.netloc:
raise ValueError(f"Invalid URL format: {url}")
response = await self._download_with_retry(url)
media_name = Path(parsed_url.path).name
if not media_name:
default_ext = "jpg" if media_type == "image" else "mp3"
media_name = f"{uuid.uuid4()}.{default_ext}"
if media_type == "image":
target_dir = self.input_image_dir
else:
target_dir = self.input_audio_dir
media_path = target_dir / media_name
media_path.parent.mkdir(parents=True, exist_ok=True)
with open(media_path, "wb") as f:
f.write(response.content)
logger.info(f"Successfully downloaded {media_type} from {url} to {media_path}")
return media_path
except httpx.ConnectError as e:
logger.error(f"Connection error downloading {media_type} from {url}: {str(e)}")
raise ValueError(f"Failed to connect to {url}: {str(e)}")
except httpx.TimeoutException as e:
logger.error(f"Timeout downloading {media_type} from {url}: {str(e)}")
raise ValueError(f"Download timeout for {url}: {str(e)}")
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error downloading {media_type} from {url}: {str(e)}")
raise ValueError(f"HTTP error for {url}: {str(e)}")
except ValueError:
raise
except Exception as e:
logger.error(f"Unexpected error downloading {media_type} from {url}: {str(e)}")
raise ValueError(f"Failed to download {media_type} from {url}: {str(e)}")
async def download_image(self, image_url: str) -> Path:
return await self.download_media(image_url, "image")
async def download_audio(self, audio_url: str) -> Path:
return await self.download_media(audio_url, "audio")
def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
file_extension = Path(filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = self.input_image_dir / unique_filename
with open(file_path, "wb") as f:
f.write(file_content)
return file_path
def get_output_path(self, save_result_path: str) -> Path:
video_path = Path(save_result_path)
if not video_path.is_absolute():
return self.output_video_dir / save_result_path
return video_path
async def cleanup(self):
async with self._client_lock:
if self._http_client and not self._http_client.is_closed:
await self._http_client.aclose()
self._http_client = None
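# A minimal usage sketch of the refactored FileService (assumes an event loop; the URL is a placeholder):
async def _example(cache_dir: Path) -> None:
    service = FileService(cache_dir)
    try:
        image_path = await service.download_media("https://example.com/sample.png", media_type="image")
        logger.info(f"downloaded to {image_path}")
    finally:
        await service.cleanup()  # close the pooled httpx client
# asyncio.run(_example(Path("./cache")))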
from .base import BaseGenerationService
from .image import ImageGenerationService
from .video import VideoGenerationService
__all__ = [
"BaseGenerationService",
"VideoGenerationService",
"ImageGenerationService",
]
import json
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from loguru import logger
from ...media import is_base64_audio, is_base64_image, save_base64_audio, save_base64_image
from ...schema import TaskResponse
from ..file_service import FileService
from ..inference import DistributedInferenceService
class BaseGenerationService(ABC):
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
self.file_service = file_service
self.inference_service = inference_service
@abstractmethod
def get_output_extension(self) -> str:
pass
@abstractmethod
def get_task_type(self) -> str:
pass
def _is_target_task_type(self) -> bool:
if self.inference_service.worker and self.inference_service.worker.runner:
task_type = self.inference_service.worker.runner.config.get("task", "t2v")
return task_type in self.get_task_type().split(",")
return False
async def _process_image_path(self, image_path: str, task_data: Dict[str, Any]) -> None:
if not image_path:
return
if image_path.startswith("http"):
downloaded_path = await self.file_service.download_image(image_path)
task_data["image_path"] = str(downloaded_path)
elif is_base64_image(image_path):
saved_path = save_base64_image(image_path, str(self.file_service.input_image_dir))
task_data["image_path"] = str(saved_path)
else:
task_data["image_path"] = image_path
async def _process_audio_path(self, audio_path: str, task_data: Dict[str, Any]) -> None:
if not audio_path:
return
if audio_path.startswith("http"):
downloaded_path = await self.file_service.download_audio(audio_path)
task_data["audio_path"] = str(downloaded_path)
elif is_base64_audio(audio_path):
saved_path = save_base64_audio(audio_path, str(self.file_service.input_audio_dir))
task_data["audio_path"] = str(saved_path)
else:
task_data["audio_path"] = audio_path
async def _process_talk_objects(self, talk_objects: list, task_data: Dict[str, Any]) -> None:
if not talk_objects:
return
task_data["talk_objects"] = [{} for _ in range(len(talk_objects))]
for index, talk_object in enumerate(talk_objects):
if talk_object.audio.startswith("http"):
audio_path = await self.file_service.download_audio(talk_object.audio)
task_data["talk_objects"][index]["audio"] = str(audio_path)
elif is_base64_audio(talk_object.audio):
audio_path = save_base64_audio(talk_object.audio, str(self.file_service.input_audio_dir))
task_data["talk_objects"][index]["audio"] = str(audio_path)
else:
task_data["talk_objects"][index]["audio"] = talk_object.audio
if talk_object.mask.startswith("http"):
mask_path = await self.file_service.download_image(talk_object.mask)
task_data["talk_objects"][index]["mask"] = str(mask_path)
elif is_base64_image(talk_object.mask):
mask_path = save_base64_image(talk_object.mask, str(self.file_service.input_image_dir))
task_data["talk_objects"][index]["mask"] = str(mask_path)
else:
task_data["talk_objects"][index]["mask"] = talk_object.mask
temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
temp_path.mkdir(parents=True, exist_ok=True)
task_data["audio_path"] = str(temp_path)
config_path = temp_path / "config.json"
with open(config_path, "w") as f:
json.dump({"talk_objects": task_data["talk_objects"]}, f)
def _prepare_output_path(self, save_result_path: str, task_data: Dict[str, Any]) -> None:
actual_save_path = self.file_service.get_output_path(save_result_path)
if not actual_save_path.suffix:
actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
task_data["save_result_path"] = str(actual_save_path)
task_data["video_path"] = actual_save_path.name
async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
try:
task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
task_data["task_id"] = message.task_id
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled before processing")
return None
if hasattr(message, "image_path") and message.image_path:
await self._process_image_path(message.image_path, task_data)
logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")
if hasattr(message, "audio_path") and message.audio_path:
await self._process_audio_path(message.audio_path, task_data)
logger.info(f"Task {message.task_id} audio path: {task_data.get('audio_path')}")
if hasattr(message, "talk_objects") and message.talk_objects:
await self._process_talk_objects(message.talk_objects, task_data)
self._prepare_output_path(message.save_result_path, task_data)
task_data["seed"] = message.seed
result = await self.inference_service.submit_task_async(task_data)
if result is None:
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled during processing")
return None
raise RuntimeError("Task processing failed")
if result.get("status") == "success":
actual_save_path = self.file_service.get_output_path(message.save_result_path)
if not actual_save_path.suffix:
actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
return TaskResponse(
task_id=message.task_id,
task_status="completed",
save_result_path=actual_save_path.name,
)
else:
error_msg = result.get("error", "Inference failed")
raise RuntimeError(error_msg)
except Exception as e:
logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
raise
from typing import Any, Optional
from loguru import logger
from ...schema import TaskResponse
from ..file_service import FileService
from ..inference import DistributedInferenceService
from .base import BaseGenerationService
class ImageGenerationService(BaseGenerationService):
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
super().__init__(file_service, inference_service)
def get_output_extension(self) -> str:
return ".png"
def get_task_type(self) -> str:
return "t2i,i2i"
async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
try:
task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
task_data["task_id"] = message.task_id
if hasattr(message, "aspect_ratio"):
task_data["aspect_ratio"] = message.aspect_ratio
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled before processing")
return None
if hasattr(message, "image_path") and message.image_path:
await self._process_image_path(message.image_path, task_data)
logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")
self._prepare_output_path(message.save_result_path, task_data)
task_data["seed"] = message.seed
result = await self.inference_service.submit_task_async(task_data)
if result is None:
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled during processing")
return None
raise RuntimeError("Task processing failed")
if result.get("status") == "success":
actual_save_path = self.file_service.get_output_path(message.save_result_path)
if not actual_save_path.suffix:
actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
return TaskResponse(
task_id=message.task_id,
task_status="completed",
save_result_path=actual_save_path.name,
)
else:
error_msg = result.get("error", "Inference failed")
raise RuntimeError(error_msg)
except Exception as e:
logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
raise
async def generate_image_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
return await self.generate_with_stop_event(message, stop_event)
from typing import Any, Optional
from ..file_service import FileService
from ..inference import DistributedInferenceService
from .base import BaseGenerationService
class VideoGenerationService(BaseGenerationService):
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
super().__init__(file_service, inference_service)
def get_output_extension(self) -> str:
return ".mp4"
def get_task_type(self) -> str:
return "t2v,i2v,s2v"
async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
return await super().generate_with_stop_event(message, stop_event)
async def generate_video_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
return await self.generate_with_stop_event(message, stop_event)
from .service import DistributedInferenceService
from .worker import TorchrunInferenceWorker
__all__ = [
"TorchrunInferenceWorker",
"DistributedInferenceService",
]
from typing import Optional
from loguru import logger
from .worker import TorchrunInferenceWorker
class DistributedInferenceService:
def __init__(self):
self.worker = None
self.is_running = False
self.args = None
def start_distributed_inference(self, args) -> bool:
self.args = args
if self.is_running:
logger.warning("Distributed inference service is already running")
return True
try:
self.worker = TorchrunInferenceWorker()
if not self.worker.init(args):
raise RuntimeError("Worker initialization failed")
self.is_running = True
logger.info(f"Rank {self.worker.rank} inference service started successfully")
return True
except Exception as e:
logger.error(f"Error starting inference service: {str(e)}")
self.stop_distributed_inference()
return False
def stop_distributed_inference(self):
if not self.is_running:
return
try:
if self.worker:
self.worker.cleanup()
logger.info("Inference service stopped")
except Exception as e:
logger.error(f"Error stopping inference service: {str(e)}")
finally:
self.worker = None
self.is_running = False
async def submit_task_async(self, task_data: dict) -> Optional[dict]:
if not self.is_running or not self.worker:
logger.error("Inference service is not started")
return None
if self.worker.rank != 0:
return None
try:
if self.worker.processing:
logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
self.worker.processing = True
result = await self.worker.process_request(task_data)
self.worker.processing = False
return result
except Exception as e:
self.worker.processing = False
logger.error(f"Failed to process task: {str(e)}")
return {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Task processing failed: {str(e)}",
}
def server_metadata(self):
assert self.args is not None, "Distributed inference service has not been started. Call start_distributed_inference() first."
return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
async def run_worker_loop(self):
if self.worker and self.worker.rank != 0:
await self.worker.worker_loop()
import asyncio
import json
import os
from typing import Any, Dict
import torch
from easydict import EasyDict
from loguru import logger
from lightx2v.infer import init_runner
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import set_config
from ..distributed_utils import DistributedManager
class TorchrunInferenceWorker:
def __init__(self):
self.rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.runner = None
self.dist_manager = DistributedManager()
self.processing = False
def init(self, args) -> bool:
try:
if self.world_size > 1:
if not self.dist_manager.init_process_group():
raise RuntimeError("Failed to initialize distributed process group")
else:
self.dist_manager.rank = 0
self.dist_manager.world_size = 1
self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.dist_manager.is_initialized = False
config = set_config(args)
if self.rank == 0:
logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")
self.runner = init_runner(config)
logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
return True
except Exception as e:
logger.exception(f"Rank {self.rank} initialization failed: {str(e)}")
return False
async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
try:
if self.world_size > 1 and self.rank == 0:
task_data = self.dist_manager.broadcast_task_data(task_data)
task_data["task"] = self.runner.config["task"]
task_data["return_result_tensor"] = False
task_data["negative_prompt"] = task_data.get("negative_prompt", "")
task_data = EasyDict(task_data)
input_info = set_input_info(task_data)
self.runner.set_config(task_data)
self.runner.run_pipeline(input_info)
await asyncio.sleep(0)
if self.world_size > 1:
self.dist_manager.barrier()
if self.rank == 0:
return {
"task_id": task_data["task_id"],
"status": "success",
"save_result_path": task_data.get("video_path", task_data["save_result_path"]),
"message": "Inference completed",
}
else:
return None
except Exception as e:
logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
if self.world_size > 1:
self.dist_manager.barrier()
if self.rank == 0:
return {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Inference failed: {str(e)}",
}
else:
return None
async def worker_loop(self):
while True:
try:
task_data = self.dist_manager.broadcast_task_data()
if task_data is None:
logger.info(f"Rank {self.rank} received stop signal")
break
await self.process_request(task_data)
except Exception as e:
logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
continue
def cleanup(self):
self.dist_manager.cleanup()
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
# Start API server with distributed inference service
python -m lightx2v.server \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
--port 8000
echo "Service stopped"
# {
# "prompt": "turn the style of the photo to vintage comic book",
# "image_path": "assets/inputs/imgs/snake.png",
# "infer_steps": 50
# }
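# One way to send the example payload above once the server is up (the endpoint path is
# illustrative, not confirmed; check the server's route definitions for the exact URL):
# curl -X POST http://127.0.0.1:8000/v1/tasks \
#   -H "Content-Type: application/json" \
#   -d '{"prompt": "turn the style of the photo to vintage comic book", "image_path": "assets/inputs/imgs/snake.png", "infer_steps": 50}'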
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
# Start API server with distributed inference service
python -m lightx2v.server \
--model_cls qwen_image \
--task t2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
--port 8000
echo "Service stopped"
# {
# "prompt": "a beautiful sunset over the ocean",
# "aspect_ratio": "16:9",
# "infer_steps": 50
# }
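# One way to send the text-to-image payload above (same caveat: the endpoint path is illustrative):
# curl -X POST http://127.0.0.1:8000/v1/tasks \
#   -H "Content-Type: application/json" \
#   -d '{"prompt": "a beautiful sunset over the ocean", "aspect_ratio": "16:9", "infer_steps": 50}'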