Unverified Commit dea872a2 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

Api image (#515)

parent 1892a3db
......@@ -22,9 +22,8 @@ class WanAudioModel(WanModel):
def __init__(self, model_path, config, device):
self.config = config
self.run_device = self.config.get("run_device", "cuda")
self._load_adapter_ckpt()
super().__init__(model_path, config, device)
self._load_adapter_ckpt() # depend on run_device
def _load_adapter_ckpt(self):
if self.config.get("adapter_model_path", None) is None:
......
......@@ -2,7 +2,50 @@
## Overview
The LightX2V server is a distributed video generation service built with FastAPI that processes image-to-video tasks using a multi-process architecture with GPU support. It implements a sophisticated task queue system with distributed inference capabilities for high-throughput video generation workloads.
The LightX2V server is a distributed video/image generation service built with FastAPI that processes image-to-video and text-to-image tasks using a multi-process architecture with GPU support. It implements a sophisticated task queue system with distributed inference capabilities for high-throughput generation workloads.
## Directory Structure
```
server/
├── __init__.py
├── __main__.py # Entry point
├── main.py # Server startup
├── config.py # Configuration
├── task_manager.py # Task management
├── schema.py # Data models (VideoTaskRequest, ImageTaskRequest)
├── api/
│ ├── __init__.py
│ ├── router.py # Main router aggregation
│ ├── deps.py # Dependency injection container
│ ├── server.py # ApiServer class
│ ├── files.py # /v1/files/*
│ ├── service_routes.py # /v1/service/*
│ └── tasks/
│ ├── __init__.py
│ ├── common.py # Common task operations
│ ├── video.py # POST /v1/tasks/video
│ └── image.py # POST /v1/tasks/image
├── services/
│ ├── __init__.py
│ ├── file_service.py # File service (unified download)
│ ├── distributed_utils.py # Distributed manager
│ ├── inference/
│ │ ├── __init__.py
│ │ ├── worker.py # TorchrunInferenceWorker
│ │ └── service.py # DistributedInferenceService
│ └── generation/
│ ├── __init__.py
│ ├── base.py # Base generation service
│ ├── video.py # VideoGenerationService
│ └── image.py # ImageGenerationService
├── media/
│ ├── __init__.py
│ ├── base.py # MediaHandler base class
│ ├── image.py # ImageHandler
│ └── audio.py # AudioHandler
└── metrics/ # Prometheus metrics
```
## Architecture
......@@ -17,14 +60,16 @@ flowchart TB
Router --> FileRoutes[File APIs]
Router --> ServiceRoutes[Service Status APIs]
TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
TaskRoutes --> CreateVideoTask["POST /v1/tasks/video - Create Video Task"]
TaskRoutes --> CreateImageTask["POST /v1/tasks/image - Create Image Task"]
TaskRoutes --> CreateVideoTaskForm["POST /v1/tasks/video/form - Form Create Video"]
TaskRoutes --> CreateImageTaskForm["POST /v1/tasks/image/form - Form Create Image"]
TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]
TaskRoutes --> GetTaskStatus["GET /v1/tasks/{id}/status - Get Status"]
TaskRoutes --> GetTaskResult["GET /v1/tasks/{id}/result - Get Result"]
TaskRoutes --> StopTask["DELETE /v1/tasks/{id} - Stop Task"]
FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]
FileRoutes --> DownloadFile["GET /v1/files/download/{path} - Download File"]
ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
......@@ -36,8 +81,8 @@ flowchart TB
TaskStatus[Task Status]
TaskResult[Task Result]
CreateTask --> TaskManager
CreateTaskForm --> TaskManager
CreateVideoTask --> TaskManager
CreateImageTask --> TaskManager
TaskManager --> TaskQueue
TaskManager --> TaskStatus
TaskManager --> TaskResult
......@@ -45,17 +90,24 @@ flowchart TB
subgraph File Service
FileService[File Service]
DownloadImage[Download Image]
DownloadAudio[Download Audio]
DownloadMedia[Download Media]
SaveFile[Save File]
GetOutputPath[Get Output Path]
FileService --> DownloadImage
FileService --> DownloadAudio
FileService --> DownloadMedia
FileService --> SaveFile
FileService --> GetOutputPath
end
subgraph Media Handlers
MediaHandler[MediaHandler Base]
ImageHandler[ImageHandler]
AudioHandler[AudioHandler]
MediaHandler --> ImageHandler
MediaHandler --> AudioHandler
end
subgraph Processing Thread
ProcessingThread[Processing Thread]
NextTask[Get Next Task]
......@@ -65,17 +117,19 @@ flowchart TB
ProcessingThread --> ProcessTask
end
subgraph Video Generation Service
VideoService[Video Service]
GenerateVideo[Generate Video]
subgraph Generation Services
VideoService[VideoGenerationService]
ImageService[ImageGenerationService]
BaseService[BaseGenerationService]
VideoService --> GenerateVideo
BaseService --> VideoService
BaseService --> ImageService
end
subgraph Distributed Inference Service
InferenceService[Distributed Inference Service]
InferenceService[DistributedInferenceService]
SubmitTask[Submit Task]
Worker[Inference Worker Node]
Worker[TorchrunInferenceWorker]
ProcessRequest[Process Request]
RunPipeline[Run Inference Pipeline]
......@@ -85,15 +139,16 @@ flowchart TB
ProcessRequest --> RunPipeline
end
%% ====== Connect Modules ======
TaskQueue --> ProcessingThread
ProcessTask --> VideoService
GenerateVideo --> InferenceService
ProcessTask --> ImageService
VideoService --> InferenceService
ImageService --> InferenceService
GetTaskResult --> FileService
DownloadFile --> FileService
VideoService --> FileService
InferenceService --> TaskManager
TaskManager --> TaskStatus
ImageService --> FileService
FileService --> MediaHandler
```
## Task Processing Flow
......@@ -104,13 +159,13 @@ sequenceDiagram
participant API as API Server
participant TM as TaskManager
participant PT as Processing Thread
participant VS as VideoService
participant GS as GenerationService<br/>(Video/Image)
participant FS as FileService
participant DIS as DistributedInferenceService
participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)
C->>API: POST /v1/tasks<br/>(Create Task)
C->>API: POST /v1/tasks/video<br/>or /v1/tasks/image
API->>TM: create_task()
TM->>TM: Generate task_id
TM->>TM: Add to queue<br/>(status: PENDING)
......@@ -124,31 +179,30 @@ sequenceDiagram
PT->>TM: acquire_processing_lock()
PT->>TM: start_task()<br/>(status: PROCESSING)
PT->>VS: generate_video_with_stop_event()
PT->>PT: Select service by task type
PT->>GS: generate_with_stop_event()
alt Image is URL
VS->>FS: download_image()
GS->>FS: download_media(url, "image")
FS->>FS: HTTP download<br/>with retry
FS-->>VS: image_path
FS-->>GS: image_path
else Image is Base64
VS->>FS: save_base64_image()
FS-->>VS: image_path
GS->>GS: save_base64_image()
GS-->>GS: image_path
else Image is local path
VS->>VS: use existing path
GS->>GS: use existing path
end
alt Audio is URL
VS->>FS: download_audio()
alt Audio is URL (Video only)
GS->>FS: download_media(url, "audio")
FS->>FS: HTTP download<br/>with retry
FS-->>VS: audio_path
FS-->>GS: audio_path
else Audio is Base64
VS->>FS: save_base64_audio()
FS-->>VS: audio_path
else Audio is local path
VS->>VS: use existing path
GS->>GS: save_base64_audio()
GS-->>GS: audio_path
end
VS->>DIS: submit_task_async(task_data)
GS->>DIS: submit_task_async(task_data)
DIS->>TIW0: process_request(task_data)
Note over TIW0,TIW1: Torchrun-based Distributed Processing
......@@ -180,8 +234,8 @@ sequenceDiagram
TIW0->>DIS: Return result (only rank 0)
TIW1->>TIW1: Return None (non-rank 0)
DIS-->>VS: TaskResponse
VS-->>PT: TaskResponse
DIS-->>GS: TaskResponse
GS-->>PT: TaskResponse
PT->>TM: complete_task()<br/>(status: COMPLETED)
PT->>TM: release_processing_lock()
......@@ -195,8 +249,8 @@ sequenceDiagram
C->>API: GET /v1/tasks/{task_id}/result
API->>TM: get_task_status()
API->>FS: stream_file_response()
FS-->>API: Video Stream
API-->>C: Video File
FS-->>API: Video/Image Stream
API-->>C: Output File
```
## Task States
......@@ -214,6 +268,70 @@ stateDiagram-v2
CANCELLED --> [*]
```
## API Endpoints
### Task APIs
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/tasks/video` | POST | Create video generation task |
| `/v1/tasks/video/form` | POST | Create video task with form data |
| `/v1/tasks/image` | POST | Create image generation task |
| `/v1/tasks/image/form` | POST | Create image task with form data |
| `/v1/tasks` | GET | List all tasks |
| `/v1/tasks/queue/status` | GET | Get queue status |
| `/v1/tasks/{task_id}/status` | GET | Get task status |
| `/v1/tasks/{task_id}/result` | GET | Get task result (stream) |
| `/v1/tasks/{task_id}` | DELETE | Cancel task |
| `/v1/tasks/all/running` | DELETE | Cancel all running tasks |
### File APIs
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/files/download/{path}` | GET | Download output file |
### Service APIs
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/service/status` | GET | Get service status |
| `/v1/service/metadata` | GET | Get service metadata |
## Request Models
### VideoTaskRequest
```python
class VideoTaskRequest(BaseTaskRequest):
num_fragments: int = 1
target_video_length: int = 81
audio_path: str = ""
video_duration: int = 5
talk_objects: Optional[list[TalkObject]] = None
```
### ImageTaskRequest
```python
class ImageTaskRequest(BaseTaskRequest):
aspect_ratio: str = "16:9"
```
### BaseTaskRequest (Common Fields)
```python
class BaseTaskRequest(BaseModel):
task_id: str # auto-generated
prompt: str = ""
use_prompt_enhancer: bool = False
negative_prompt: str = ""
image_path: str = "" # URL, base64, or local path
save_result_path: str = ""
infer_steps: int = 5
seed: int # auto-generated
```
## Configuration
### Environment Variables
......@@ -223,7 +341,8 @@ see `lightx2v/server/config.py`
### Command Line Arguments
```bash
python -m lightx2v.server.main \
# Single GPU
python -m lightx2v.server \
--model_path /path/to/model \
--model_cls wan2.1_distill \
--task i2v \
......@@ -233,14 +352,14 @@ python -m lightx2v.server.main \
```
```bash
python -m lightx2v.server.main \
# Multi-GPU with torchrun
torchrun --nproc_per_node=2 -m lightx2v.server \
--model_path /path/to/model \
--model_cls wan2.1_distill \
--task i2v \
--host 0.0.0.0 \
--port 8000 \
--config_json /path/to/xxx_dist_config.json \
--nproc_per_node 2
--config_json /path/to/xxx_dist_config.json
```
## Key Features
......@@ -269,6 +388,14 @@ python -m lightx2v.server.main \
- **Streaming responses** for large video files
- **Cache management** with automatic cleanup
- **File validation** and format detection
- **Unified media handling** via MediaHandler pattern
### 4. Separate Video/Image Endpoints
- **Dedicated endpoints** for video and image generation
- **Type-specific request models** (VideoTaskRequest, ImageTaskRequest)
- **Automatic service routing** based on task type
- **Backward compatible** with legacy `/v1/tasks` endpoint
## Performance Considerations
......
......@@ -6,19 +6,14 @@ from .main import run_server
def main():
parser = argparse.ArgumentParser(description="LightX2V Server")
# Model arguments
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
parser.add_argument("--port", type=int, default=8000, help="Server port")
# Parse any additional arguments that might be passed
args, unknown = parser.parse_known_args()
# Add any unknown arguments as attributes to args
# This allows flexibility for model-specific arguments
for i in range(0, len(unknown), 2):
if unknown[i].startswith("--"):
key = unknown[i][2:]
......@@ -26,7 +21,6 @@ def main():
value = unknown[i + 1]
setattr(args, key, value)
# Run the server
run_server(args)
......
import asyncio
import gc
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Optional
from urllib.parse import urlparse
import httpx
import torch
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from .schema import (
StopTaskResponse,
TaskRequest,
TaskResponse,
)
from .service import DistributedInferenceService, FileService, VideoGenerationService
from .task_manager import TaskStatus, task_manager
class ApiServer:
    """Monolithic LightX2V API server (pre-refactor version).

    Owns the FastAPI app, the /v1/tasks, /v1/files and /v1/service routers,
    file streaming helpers, and the background task-processing thread.

    NOTE(review): ``RedirectResponse`` is used in ``_setup_routes`` but does
    not appear in this file's visible import block — presumably a latent
    NameError on the "/" route; verify against the full original file.
    """

    def __init__(self, max_queue_size: int = 10, app: Optional[FastAPI] = None):
        # Accept an externally created app (e.g. for tests) or build our own.
        self.app = app or FastAPI(title="LightX2V API", version="1.0.0")
        # Services are wired later via initialize_services().
        self.file_service = None
        self.inference_service = None
        self.video_service = None
        self.max_queue_size = max_queue_size
        # Background worker thread; started lazily on first task creation.
        self.processing_thread = None
        self.stop_processing = threading.Event()
        self.tasks_router = APIRouter(prefix="/v1/tasks", tags=["tasks"])
        self.files_router = APIRouter(prefix="/v1/files", tags=["files"])
        self.service_router = APIRouter(prefix="/v1/service", tags=["service"])
        self._setup_routes()

    def _setup_routes(self):
        """Register all route handlers and mount the three routers."""

        @self.app.get("/")
        def redirect_to_docs():
            # Root redirects to the interactive OpenAPI docs.
            return RedirectResponse(url="/docs")

        self._setup_task_routes()
        self._setup_file_routes()
        self._setup_service_routes()
        self.app.include_router(self.tasks_router)
        self.app.include_router(self.files_router)
        self.app.include_router(self.service_router)

    def _write_file_sync(self, file_path: Path, content: bytes) -> None:
        # Blocking write helper; invoked through asyncio.to_thread so async
        # handlers do not block the event loop.
        with open(file_path, "wb") as buffer:
            buffer.write(content)

    def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
        """Stream a file from the output directory as an HTTP attachment.

        Raises:
            HTTPException: 403 when the resolved path escapes the output
                directory, 404 when missing, 500 on unexpected errors.
        """
        assert self.file_service is not None, "File service is not initialized"
        try:
            resolved_path = file_path.resolve()
            # NOTE(review): prefix comparison admits sibling dirs that share
            # the prefix (e.g. /outputs_evil vs /outputs) — confirm intent.
            if not str(resolved_path).startswith(str(self.file_service.output_video_dir.resolve())):
                raise HTTPException(status_code=403, detail="Access to this file is not allowed")
            if not resolved_path.exists() or not resolved_path.is_file():
                raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
            file_size = resolved_path.stat().st_size
            actual_filename = filename or resolved_path.name
            # Set appropriate MIME type
            mime_type = "application/octet-stream"
            if actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
                mime_type = "video/mp4"
            elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
                mime_type = "image/jpeg"
            headers = {
                "Content-Disposition": f'attachment; filename="{actual_filename}"',
                "Content-Length": str(file_size),
                "Accept-Ranges": "bytes",
            }

            def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024):
                # 1 MiB chunks keep memory bounded for large files.
                with open(file_path, "rb") as file:
                    while chunk := file.read(chunk_size):
                        yield chunk

            return StreamingResponse(
                file_stream_generator(str(resolved_path)),
                media_type=mime_type,
                headers=headers,
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error occurred while processing file stream response: {e}")
            raise HTTPException(status_code=500, detail="File transfer failed")

    def _setup_task_routes(self):
        """Register task CRUD handlers on the /v1/tasks router."""

        @self.tasks_router.post("/", response_model=TaskResponse)
        async def create_task(message: TaskRequest):
            """Create video generation task"""
            try:
                # Reject unreachable remote images early, before queueing.
                if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"):
                    if not await self._validate_image_url(message.image_path):
                        raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}")
                task_id = task_manager.create_task(message)
                message.task_id = task_id
                self._ensure_processing_thread_running()
                return TaskResponse(
                    task_id=task_id,
                    task_status="pending",
                    save_result_path=message.save_result_path,
                )
            except RuntimeError as e:
                # Queue full / service unavailable.
                raise HTTPException(status_code=503, detail=str(e))
            except Exception as e:
                logger.error(f"Failed to create task: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.tasks_router.post("/form", response_model=TaskResponse)
        async def create_task_form(
            image_file: UploadFile = File(...),
            prompt: str = Form(default=""),
            save_result_path: str = Form(default=""),
            use_prompt_enhancer: bool = Form(default=False),
            negative_prompt: str = Form(default=""),
            num_fragments: int = Form(default=1),
            infer_steps: int = Form(default=5),
            target_video_length: int = Form(default=81),
            seed: int = Form(default=42),
            audio_file: UploadFile = File(None),
            video_duration: int = Form(default=5),
        ):
            """Create a video task from multipart form uploads."""
            assert self.file_service is not None, "File service is not initialized"

            async def save_file_async(file: UploadFile, target_dir: Path) -> str:
                # Persist an upload under a collision-free UUID filename;
                # returns "" when no file was provided.
                if not file or not file.filename:
                    return ""
                file_extension = Path(file.filename).suffix
                unique_filename = f"{uuid.uuid4()}{file_extension}"
                file_path = target_dir / unique_filename
                content = await file.read()
                await asyncio.to_thread(self._write_file_sync, file_path, content)
                return str(file_path)

            image_path = ""
            if image_file and image_file.filename:
                image_path = await save_file_async(image_file, self.file_service.input_image_dir)
            audio_path = ""
            if audio_file and audio_file.filename:
                audio_path = await save_file_async(audio_file, self.file_service.input_audio_dir)
            message = TaskRequest(
                prompt=prompt,
                use_prompt_enhancer=use_prompt_enhancer,
                negative_prompt=negative_prompt,
                image_path=image_path,
                num_fragments=num_fragments,
                save_result_path=save_result_path,
                infer_steps=infer_steps,
                target_video_length=target_video_length,
                seed=seed,
                audio_path=audio_path,
                video_duration=video_duration,
            )
            try:
                task_id = task_manager.create_task(message)
                message.task_id = task_id
                self._ensure_processing_thread_running()
                return TaskResponse(
                    task_id=task_id,
                    task_status="pending",
                    save_result_path=message.save_result_path,
                )
            except RuntimeError as e:
                raise HTTPException(status_code=503, detail=str(e))
            except Exception as e:
                logger.error(f"Failed to create form task: {e}")
                raise HTTPException(status_code=500, detail=str(e))

        @self.tasks_router.get("/", response_model=dict)
        async def list_tasks():
            # Full task table snapshot.
            return task_manager.get_all_tasks()

        @self.tasks_router.get("/queue/status", response_model=dict)
        async def get_queue_status():
            """Report queue occupancy and the task currently in flight."""
            service_status = task_manager.get_service_status()
            return {
                "is_processing": task_manager.is_processing(),
                "current_task": service_status.get("current_task"),
                "pending_count": task_manager.get_pending_task_count(),
                "active_count": task_manager.get_active_task_count(),
                "queue_size": self.max_queue_size,
                "queue_available": self.max_queue_size - task_manager.get_active_task_count(),
            }

        @self.tasks_router.get("/{task_id}/status")
        async def get_task_status(task_id: str):
            status = task_manager.get_task_status(task_id)
            if not status:
                raise HTTPException(status_code=404, detail="Task not found")
            return status

        @self.tasks_router.get("/{task_id}/result")
        async def get_task_result(task_id: str):
            """Stream a COMPLETED task's output file back to the client."""
            assert self.video_service is not None, "Video service is not initialized"
            assert self.file_service is not None, "File service is not initialized"
            try:
                task_status = task_manager.get_task_status(task_id)
                if not task_status:
                    raise HTTPException(status_code=404, detail="Task not found")
                if task_status.get("status") != TaskStatus.COMPLETED.value:
                    raise HTTPException(status_code=404, detail="Task not completed")
                save_result_path = task_status.get("save_result_path")
                if not save_result_path:
                    raise HTTPException(status_code=404, detail="Task result file does not exist")
                full_path = Path(save_result_path)
                if not full_path.is_absolute():
                    # Relative results are rooted in the output directory.
                    full_path = self.file_service.output_video_dir / save_result_path
                return self._stream_file_response(full_path)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Error occurred while getting task result: {e}")
                raise HTTPException(status_code=500, detail="Failed to get task result")

        @self.tasks_router.delete("/{task_id}", response_model=StopTaskResponse)
        async def stop_task(task_id: str):
            """Cancel one task and reclaim GPU memory."""
            try:
                if task_manager.cancel_task(task_id):
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    logger.info(f"Task {task_id} stopped successfully.")
                    return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
                else:
                    return StopTaskResponse(stop_status="do_nothing", reason="Task not found or already completed.")
            except Exception as e:
                logger.error(f"Error occurred while stopping task {task_id}: {str(e)}")
                return StopTaskResponse(stop_status="error", reason=str(e))

        @self.tasks_router.delete("/all/running", response_model=StopTaskResponse)
        async def stop_all_running_tasks():
            """Cancel every running task and reclaim GPU memory."""
            try:
                task_manager.cancel_all_tasks()
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                logger.info("All tasks stopped successfully.")
                return StopTaskResponse(stop_status="success", reason="All tasks stopped successfully.")
            except Exception as e:
                logger.error(f"Error occurred while stopping all tasks: {str(e)}")
                return StopTaskResponse(stop_status="error", reason=str(e))

    def _setup_file_routes(self):
        """Register the output-file download handler."""

        @self.files_router.get("/download/{file_path:path}")
        async def download_file(file_path: str):
            assert self.file_service is not None, "File service is not initialized"
            try:
                full_path = self.file_service.output_video_dir / file_path
                return self._stream_file_response(full_path)
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"Error occurred while processing file download request: {e}")
                raise HTTPException(status_code=500, detail="File download failed")

    def _setup_service_routes(self):
        """Register service status / metadata handlers."""

        @self.service_router.get("/status", response_model=dict)
        async def get_service_status():
            return task_manager.get_service_status()

        @self.service_router.get("/metadata", response_model=dict)
        async def get_service_metadata():
            assert self.inference_service is not None, "Inference service is not initialized"
            return self.inference_service.server_metadata()

    async def _validate_image_url(self, image_url: str) -> bool:
        """Best-effort HEAD probe; non-HTTP inputs are accepted unchecked."""
        if not image_url or not image_url.startswith("http"):
            return True
        try:
            parsed_url = urlparse(image_url)
            if not parsed_url.scheme or not parsed_url.netloc:
                return False
            timeout = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0)
            # verify=False disables TLS verification — presumably to tolerate
            # internal/self-signed asset hosts; confirm this is intended.
            async with httpx.AsyncClient(verify=False, timeout=timeout) as client:
                response = await client.head(image_url, follow_redirects=True)
                return response.status_code < 400
        except Exception as e:
            logger.warning(f"URL validation failed for {image_url}: {str(e)}")
            return False

    def _ensure_processing_thread_running(self):
        # (Re)start the daemon worker if it was never started or has died.
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_processing.clear()
            self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
            self.processing_thread.start()
            logger.info("Started task processing thread")

    def _task_processing_loop(self):
        """Poll the queue and run each pending task on this thread's own loop."""
        logger.info("Task processing loop started")
        # The worker thread owns a private asyncio loop so it can drive the
        # async generation services from synchronous context.
        asyncio.set_event_loop(asyncio.new_event_loop())
        loop = asyncio.get_event_loop()
        while not self.stop_processing.is_set():
            task_id = task_manager.get_next_pending_task()
            if task_id is None:
                time.sleep(1)  # idle poll interval
                continue
            task_info = task_manager.get_task(task_id)
            if task_info and task_info.status == TaskStatus.PENDING:
                logger.info(f"Processing task {task_id}")
                loop.run_until_complete(self._process_single_task(task_info))
        loop.close()
        logger.info("Task processing loop stopped")

    async def _process_single_task(self, task_info: Any):
        """Run one task end-to-end: lock, generate, record result, unlock."""
        assert self.video_service is not None, "Video service is not initialized"
        task_id = task_info.task_id
        message = task_info.message
        lock_acquired = task_manager.acquire_processing_lock(task_id, timeout=1)
        if not lock_acquired:
            logger.error(f"Task {task_id} failed to acquire processing lock")
            task_manager.fail_task(task_id, "Failed to acquire processing lock")
            return
        try:
            task_manager.start_task(task_id)
            if task_info.stop_event.is_set():
                # Cancelled while still queued.
                logger.info(f"Task {task_id} cancelled before processing")
                task_manager.fail_task(task_id, "Task cancelled")
                return
            result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)
            if result:
                task_manager.complete_task(task_id, result.save_result_path)
                logger.info(f"Task {task_id} completed successfully")
            else:
                # Falsy result: either cancelled mid-run or generation failed.
                if task_info.stop_event.is_set():
                    task_manager.fail_task(task_id, "Task cancelled during processing")
                    logger.info(f"Task {task_id} cancelled during processing")
                else:
                    task_manager.fail_task(task_id, "Generation failed")
                    logger.error(f"Task {task_id} generation failed")
        except Exception as e:
            logger.exception(f"Task {task_id} processing failed: {str(e)}")
            task_manager.fail_task(task_id, str(e))
        finally:
            if lock_acquired:
                task_manager.release_processing_lock(task_id)

    def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService):
        """Construct the file and video services around the inference backend."""
        self.file_service = FileService(cache_dir)
        self.inference_service = inference_service
        self.video_service = VideoGenerationService(self.file_service, inference_service)

    async def cleanup(self):
        """Stop the worker thread and release file-service resources."""
        self.stop_processing.set()
        if self.processing_thread and self.processing_thread.is_alive():
            self.processing_thread.join(timeout=5)
        if self.file_service:
            await self.file_service.cleanup()

    def get_app(self) -> FastAPI:
        # Expose the underlying FastAPI app for the ASGI server.
        return self.app
from .router import create_api_router
from .server import ApiServer
__all__ = [
"create_api_router",
"ApiServer",
]
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import httpx
from loguru import logger
from ..services import DistributedInferenceService, FileService, ImageGenerationService, VideoGenerationService
class ServiceContainer:
    """Process-wide registry for the backend services.

    A classic lazy singleton: obtain it via ``get_instance()`` and wire the
    concrete services once at startup with ``initialize()``.
    """

    _instance: Optional["ServiceContainer"] = None

    def __init__(self):
        # Every service starts unset; ``initialize`` populates them later.
        self.file_service: Optional[FileService] = None
        self.inference_service: Optional[DistributedInferenceService] = None
        self.video_service: Optional[VideoGenerationService] = None
        self.image_service: Optional[ImageGenerationService] = None
        self.max_queue_size: int = 10

    @classmethod
    def get_instance(cls) -> "ServiceContainer":
        """Return the shared container, creating it on first access."""
        if cls._instance is None:
            cls._instance = ServiceContainer()
        return cls._instance

    def initialize(self, cache_dir: Path, inference_service: DistributedInferenceService, max_queue_size: int = 10):
        """Build the concrete services around the given inference backend."""
        files = FileService(cache_dir)
        self.file_service = files
        self.inference_service = inference_service
        self.video_service = VideoGenerationService(files, inference_service)
        self.image_service = ImageGenerationService(files, inference_service)
        self.max_queue_size = max_queue_size
def get_services() -> ServiceContainer:
    """Convenience accessor for the singleton :class:`ServiceContainer`."""
    return ServiceContainer.get_instance()
async def validate_url_async(url: str) -> bool:
    """Best-effort reachability check for an HTTP(S) URL.

    Non-HTTP inputs (local paths, base64 payloads, empty strings) are
    accepted unchanged; only http(s) URLs are probed with a HEAD request.
    Returns False on malformed URLs, network errors, or 4xx/5xx responses.
    """
    if not url or not url.startswith("http"):
        # Nothing remote to validate.
        return True
    try:
        parts = urlparse(url)
        if not (parts.scheme and parts.netloc):
            return False
        limits = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0)
        # verify=False tolerates self-signed certificates on asset hosts.
        async with httpx.AsyncClient(verify=False, timeout=limits) as client:
            head = await client.head(url, follow_redirects=True)
            return head.status_code < 400
    except Exception as exc:
        logger.warning(f"URL validation failed for {url}: {str(exc)}")
        return False
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from .deps import get_services
router = APIRouter()
def _stream_file_response(file_path: Path, filename: str | None = None) -> StreamingResponse:
    """Stream a file from the output directory as an HTTP attachment.

    Resolves the path, enforces that it stays inside the configured output
    directory, and returns a chunked StreamingResponse with
    Content-Disposition / Content-Length headers.

    Args:
        file_path: Path of the file to stream (resolved before use).
        filename: Optional download name; defaults to the file's own name.

    Raises:
        HTTPException: 403 when the path escapes the output directory,
            404 when the file is missing, 500 on unexpected errors.
    """
    services = get_services()
    assert services.file_service is not None, "File service is not initialized"
    try:
        resolved_path = file_path.resolve()
        output_root = services.file_service.output_video_dir.resolve()
        # Bug fix: the previous str(...).startswith(...) containment check
        # wrongly admitted sibling directories sharing the prefix
        # (e.g. /outputs_evil passes a startswith("/outputs") test).
        # Path.is_relative_to compares path components, not characters.
        if not resolved_path.is_relative_to(output_root):
            raise HTTPException(status_code=403, detail="Access to this file is not allowed")
        if not resolved_path.exists() or not resolved_path.is_file():
            raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
        file_size = resolved_path.stat().st_size
        actual_filename = filename or resolved_path.name
        # Coarse MIME detection; unknown extensions fall back to octet-stream.
        mime_type = "application/octet-stream"
        if actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
            mime_type = "video/mp4"
        elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
            mime_type = "image/jpeg"
        headers = {
            "Content-Disposition": f'attachment; filename="{actual_filename}"',
            "Content-Length": str(file_size),
            "Accept-Ranges": "bytes",
        }

        def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024):
            # 1 MiB chunks keep memory bounded for large video files.
            with open(file_path, "rb") as file:
                while chunk := file.read(chunk_size):
                    yield chunk

        return StreamingResponse(
            file_stream_generator(str(resolved_path)),
            media_type=mime_type,
            headers=headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error occurred while processing file stream response: {e}")
        raise HTTPException(status_code=500, detail="File transfer failed")
@router.get("/download/{file_path:path}")
async def download_file(file_path: str):
services = get_services()
assert services.file_service is not None, "File service is not initialized"
try:
full_path = services.file_service.output_video_dir / file_path
return _stream_file_response(full_path)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while processing file download request: {e}")
raise HTTPException(status_code=500, detail="File download failed")
from fastapi import APIRouter
from .files import router as files_router
from .service_routes import router as service_router
from .tasks import common_router, image_router, video_router
def create_api_router() -> APIRouter:
    """Assemble the versioned API tree: /v1/tasks, /v1/files, /v1/service."""
    api_router = APIRouter()

    tasks = APIRouter(prefix="/v1/tasks", tags=["tasks"])
    tasks.include_router(common_router)
    tasks.include_router(video_router, prefix="/video", tags=["video"])
    tasks.include_router(image_router, prefix="/image", tags=["image"])

    # Backward compatibility: a bare POST /v1/tasks still creates a video
    # task, but the endpoint is marked deprecated in the OpenAPI schema.
    from .tasks.video import create_video_task

    tasks.post("/", response_model_exclude_unset=True, deprecated=True)(create_video_task)

    api_router.include_router(tasks)
    api_router.include_router(files_router, prefix="/v1/files", tags=["files"])
    api_router.include_router(service_router, prefix="/v1/service", tags=["service"])
    return api_router
import asyncio
import threading
import time
from pathlib import Path
from typing import Any, Optional
from fastapi import FastAPI
from loguru import logger
from starlette.responses import RedirectResponse
from ..services import DistributedInferenceService
from ..task_manager import TaskStatus, task_manager
from .deps import ServiceContainer, get_services
from .router import create_api_router
class ApiServer:
    """FastAPI application wrapper that owns the background worker thread.

    Route handlers live in ``api.router``; this class mounts them onto the
    app, pulls pending tasks off the ``task_manager`` queue on a daemon
    thread, and dispatches each task to the video or image generation
    service held by the :class:`ServiceContainer`.
    """

    def __init__(self, max_queue_size: int = 10, app: Optional[FastAPI] = None):
        # Accept an externally created app (e.g. for tests) or build our own.
        self.app = app or FastAPI(title="LightX2V API", version="1.0.0")
        self.max_queue_size = max_queue_size
        # Daemon worker thread; started by initialize_services().
        self.processing_thread = None
        self.stop_processing = threading.Event()
        self._setup_routes()

    def _setup_routes(self):
        """Mount the aggregated /v1 router plus a root redirect to /docs."""

        @self.app.get("/")
        def redirect_to_docs():
            return RedirectResponse(url="/docs")

        api_router = create_api_router()
        self.app.include_router(api_router)

    def _ensure_processing_thread_running(self):
        # (Re)start the worker if it was never started or has died.
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_processing.clear()
            self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
            self.processing_thread.start()
            logger.info("Started task processing thread")

    def _task_processing_loop(self):
        """Poll the queue and run each pending task on this thread's own loop."""
        logger.info("Task processing loop started")
        # The worker thread owns a private asyncio loop so it can drive the
        # async generation services from synchronous thread context.
        asyncio.set_event_loop(asyncio.new_event_loop())
        loop = asyncio.get_event_loop()
        while not self.stop_processing.is_set():
            task_id = task_manager.get_next_pending_task()
            if task_id is None:
                time.sleep(1)  # idle poll interval
                continue
            task_info = task_manager.get_task(task_id)
            if task_info and task_info.status == TaskStatus.PENDING:
                logger.info(f"Processing task {task_id}")
                loop.run_until_complete(self._process_single_task(task_info))
        loop.close()
        logger.info("Task processing loop stopped")

    async def _process_single_task(self, task_info: Any):
        """Run one task end-to-end: lock, route to the right service, record, unlock."""
        services = get_services()
        task_id = task_info.task_id
        message = task_info.message
        lock_acquired = task_manager.acquire_processing_lock(task_id, timeout=1)
        if not lock_acquired:
            logger.error(f"Task {task_id} failed to acquire processing lock")
            task_manager.fail_task(task_id, "Failed to acquire processing lock")
            return
        try:
            task_manager.start_task(task_id)
            if task_info.stop_event.is_set():
                # Cancelled while still queued.
                logger.info(f"Task {task_id} cancelled before processing")
                task_manager.fail_task(task_id, "Task cancelled")
                return
            # Local import — presumably to avoid a circular import with the
            # schema module at startup; confirm before hoisting to file level.
            from ..schema import ImageTaskRequest

            # Dispatch by request type: image requests go to the image
            # service, everything else defaults to video generation.
            if isinstance(message, ImageTaskRequest):
                generation_service = services.image_service
            else:
                generation_service = services.video_service
            result = await generation_service.generate_with_stop_event(message, task_info.stop_event)
            if result:
                task_manager.complete_task(task_id, result.save_result_path)
                logger.info(f"Task {task_id} completed successfully")
            else:
                # Falsy result: either cancelled mid-run or generation failed.
                if task_info.stop_event.is_set():
                    task_manager.fail_task(task_id, "Task cancelled during processing")
                    logger.info(f"Task {task_id} cancelled during processing")
                else:
                    task_manager.fail_task(task_id, "Generation failed")
                    logger.error(f"Task {task_id} generation failed")
        except Exception as e:
            logger.exception(f"Task {task_id} processing failed: {str(e)}")
            task_manager.fail_task(task_id, str(e))
        finally:
            if lock_acquired:
                task_manager.release_processing_lock(task_id)

    def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService):
        """Populate the service container, then start the worker thread."""
        container = ServiceContainer.get_instance()
        container.initialize(cache_dir, inference_service, self.max_queue_size)
        self._ensure_processing_thread_running()

    async def cleanup(self):
        """Stop the worker thread and release file-service resources."""
        self.stop_processing.set()
        if self.processing_thread and self.processing_thread.is_alive():
            self.processing_thread.join(timeout=5)
        services = get_services()
        if services.file_service:
            await services.file_service.cleanup()

    def get_app(self) -> FastAPI:
        # Expose the underlying FastAPI app for the ASGI server.
        return self.app
from fastapi import APIRouter
from ..task_manager import task_manager
from .deps import get_services
router = APIRouter()
@router.get("/status")
async def get_service_status():
return task_manager.get_service_status()
@router.get("/metadata")
async def get_service_metadata():
services = get_services()
assert services.inference_service is not None, "Inference service is not initialized"
return services.inference_service.server_metadata()
from .common import router as common_router
from .image import router as image_router
from .video import router as video_router
__all__ = [
"common_router",
"video_router",
"image_router",
]
import gc
from pathlib import Path
import torch
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from ...schema import StopTaskResponse
from ...task_manager import TaskStatus, task_manager
from ..deps import get_services
router = APIRouter()
def _stream_file_response(file_path: Path, filename: str | None = None) -> StreamingResponse:
    """Stream a generated result file back to the client in chunks.

    Args:
        file_path: Path of the file to stream; must resolve inside the file
            service's output directory.
        filename: Optional download name; defaults to the file's own name.

    Returns:
        A chunked ``StreamingResponse`` with attachment headers.

    Raises:
        HTTPException: 403 when the path escapes the output directory,
            404 when the file is missing, 500 on unexpected errors.
    """
    services = get_services()
    assert services.file_service is not None, "File service is not initialized"
    try:
        resolved_path = file_path.resolve()
        output_root = services.file_service.output_video_dir.resolve()
        # Containment check via Path.relative_to: the previous plain string
        # prefix test (`str(p).startswith(str(root))`) wrongly accepted
        # sibling directories such as "<root>_evil".
        try:
            resolved_path.relative_to(output_root)
        except ValueError:
            raise HTTPException(status_code=403, detail="Access to this file is not allowed")
        if not resolved_path.exists() or not resolved_path.is_file():
            raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
        file_size = resolved_path.stat().st_size
        actual_filename = filename or resolved_path.name
        mime_type = "application/octet-stream"
        if actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
            mime_type = "video/mp4"
        elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
            # NOTE(review): non-JPEG images are also labeled image/jpeg; kept
            # for backward compatibility — consider mimetypes.guess_type.
            mime_type = "image/jpeg"
        headers = {
            "Content-Disposition": f'attachment; filename="{actual_filename}"',
            "Content-Length": str(file_size),
            "Accept-Ranges": "bytes",
        }

        def file_stream_generator(path: str, chunk_size: int = 1024 * 1024):
            # Yield 1 MiB chunks so large videos never load fully into memory.
            with open(path, "rb") as file:
                while chunk := file.read(chunk_size):
                    yield chunk

        return StreamingResponse(
            file_stream_generator(str(resolved_path)),
            media_type=mime_type,
            headers=headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error occurred while processing file stream response: {e}")
        raise HTTPException(status_code=500, detail="File transfer failed")
@router.get("/")
async def list_tasks():
return task_manager.get_all_tasks()
@router.get("/queue/status")
async def get_queue_status():
services = get_services()
service_status = task_manager.get_service_status()
return {
"is_processing": task_manager.is_processing(),
"current_task": service_status.get("current_task"),
"pending_count": task_manager.get_pending_task_count(),
"active_count": task_manager.get_active_task_count(),
"queue_size": services.max_queue_size,
"queue_available": services.max_queue_size - task_manager.get_active_task_count(),
}
@router.get("/{task_id}/status")
async def get_task_status(task_id: str):
status = task_manager.get_task_status(task_id)
if not status:
raise HTTPException(status_code=404, detail="Task not found")
return status
@router.get("/{task_id}/result")
async def get_task_result(task_id: str):
services = get_services()
assert services.video_service is not None, "Video service is not initialized"
assert services.file_service is not None, "File service is not initialized"
try:
task_status = task_manager.get_task_status(task_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
if task_status.get("status") != TaskStatus.COMPLETED.value:
raise HTTPException(status_code=404, detail="Task not completed")
save_result_path = task_status.get("save_result_path")
if not save_result_path:
raise HTTPException(status_code=404, detail="Task result file does not exist")
full_path = Path(save_result_path)
if not full_path.is_absolute():
full_path = services.file_service.output_video_dir / save_result_path
return _stream_file_response(full_path)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while getting task result: {e}")
raise HTTPException(status_code=500, detail="Failed to get task result")
@router.delete("/{task_id}", response_model=StopTaskResponse)
async def stop_task(task_id: str):
try:
if task_manager.cancel_task(task_id):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info(f"Task {task_id} stopped successfully.")
return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
else:
return StopTaskResponse(stop_status="do_nothing", reason="Task not found or already completed.")
except Exception as e:
logger.error(f"Error occurred while stopping task {task_id}: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
@router.delete("/all/running", response_model=StopTaskResponse)
async def stop_all_running_tasks():
try:
task_manager.cancel_all_tasks()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("All tasks stopped successfully.")
return StopTaskResponse(stop_status="success", reason="All tasks stopped successfully.")
except Exception as e:
logger.error(f"Error occurred while stopping all tasks: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
import asyncio
import uuid
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from loguru import logger
from ...schema import ImageTaskRequest, TaskResponse
from ...task_manager import task_manager
from ..deps import get_services, validate_url_async
router = APIRouter()
def _write_file_sync(file_path: Path, content: bytes) -> None:
    """Blocking write of *content* to *file_path*; callers run it via asyncio.to_thread."""
    file_path.write_bytes(content)
@router.post("/", response_model=TaskResponse)
async def create_image_task(message: ImageTaskRequest):
try:
if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"):
if not await validate_url_async(message.image_path):
raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}")
task_id = task_manager.create_task(message)
message.task_id = task_id
return TaskResponse(
task_id=task_id,
task_status="pending",
save_result_path=message.save_result_path,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create image task: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/form", response_model=TaskResponse)
async def create_image_task_form(
image_file: UploadFile = File(None),
prompt: str = Form(default=""),
save_result_path: str = Form(default=""),
use_prompt_enhancer: bool = Form(default=False),
negative_prompt: str = Form(default=""),
infer_steps: int = Form(default=5),
seed: int = Form(default=42),
aspect_ratio: str = Form(default="16:9"),
):
services = get_services()
assert services.file_service is not None, "File service is not initialized"
async def save_file_async(file: UploadFile, target_dir: Path) -> str:
if not file or not file.filename:
return ""
file_extension = Path(file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = target_dir / unique_filename
content = await file.read()
await asyncio.to_thread(_write_file_sync, file_path, content)
return str(file_path)
image_path = ""
if image_file and image_file.filename:
image_path = await save_file_async(image_file, services.file_service.input_image_dir)
message = ImageTaskRequest(
prompt=prompt,
use_prompt_enhancer=use_prompt_enhancer,
negative_prompt=negative_prompt,
image_path=image_path,
save_result_path=save_result_path,
infer_steps=infer_steps,
seed=seed,
aspect_ratio=aspect_ratio,
)
try:
task_id = task_manager.create_task(message)
message.task_id = task_id
return TaskResponse(
task_id=task_id,
task_status="pending",
save_result_path=message.save_result_path,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create image form task: {e}")
raise HTTPException(status_code=500, detail=str(e))
import asyncio
import uuid
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from loguru import logger
from ...schema import TaskResponse, VideoTaskRequest
from ...task_manager import task_manager
from ..deps import get_services, validate_url_async
router = APIRouter()
def _write_file_sync(file_path: Path, content: bytes) -> None:
    """Synchronously dump *content* into *file_path*; offloaded to a thread by callers."""
    Path(file_path).write_bytes(content)
@router.post("/", response_model=TaskResponse)
async def create_video_task(message: VideoTaskRequest):
try:
if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"):
if not await validate_url_async(message.image_path):
raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}")
task_id = task_manager.create_task(message)
message.task_id = task_id
return TaskResponse(
task_id=task_id,
task_status="pending",
save_result_path=message.save_result_path,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create video task: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/form", response_model=TaskResponse)
async def create_video_task_form(
image_file: UploadFile = File(...),
prompt: str = Form(default=""),
save_result_path: str = Form(default=""),
use_prompt_enhancer: bool = Form(default=False),
negative_prompt: str = Form(default=""),
num_fragments: int = Form(default=1),
infer_steps: int = Form(default=5),
target_video_length: int = Form(default=81),
seed: int = Form(default=42),
audio_file: UploadFile = File(None),
video_duration: int = Form(default=5),
):
services = get_services()
assert services.file_service is not None, "File service is not initialized"
async def save_file_async(file: UploadFile, target_dir: Path) -> str:
if not file or not file.filename:
return ""
file_extension = Path(file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = target_dir / unique_filename
content = await file.read()
await asyncio.to_thread(_write_file_sync, file_path, content)
return str(file_path)
image_path = ""
if image_file and image_file.filename:
image_path = await save_file_async(image_file, services.file_service.input_image_dir)
audio_path = ""
if audio_file and audio_file.filename:
audio_path = await save_file_async(audio_file, services.file_service.input_audio_dir)
message = VideoTaskRequest(
prompt=prompt,
use_prompt_enhancer=use_prompt_enhancer,
negative_prompt=negative_prompt,
image_path=image_path,
num_fragments=num_fragments,
save_result_path=save_result_path,
infer_steps=infer_steps,
target_video_length=target_video_length,
seed=seed,
audio_path=audio_path,
video_duration=video_duration,
)
try:
task_id = task_manager.create_task(message)
message.task_id = task_id
return TaskResponse(
task_id=task_id,
task_status="pending",
save_result_path=message.save_result_path,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create video form task: {e}")
raise HTTPException(status_code=500, detail=str(e))
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple
from loguru import logger
def is_base64_audio(data: str) -> bool:
    """Check if a string is a base64-encoded audio"""
    # Data URLs declare their type explicitly.
    if data.startswith("data:audio/"):
        return True
    try:
        # Validate the full payload when its length is base64-plausible.
        if len(data) % 4 == 0:
            base64.b64decode(data, validate=True)
        # Sniff magic bytes from the decoded head of the payload.
        head = base64.b64decode(data[:100])
        mp3_and_container_prefixes = (
            b"ID3",  # MP3 with ID3 tag
            b"\xff\xfb", b"\xff\xf3", b"\xff\xf2",  # raw MP3 frame syncs
            b"OggS",  # Ogg
            b"fLaC",  # FLAC
        )
        if head.startswith(mp3_and_container_prefixes):
            return True
        if head.startswith(b"RIFF") and b"WAVE" in head[:12]:
            return True
        # NOTE(review): real MP4/M4A has 'ftyp' at offset 4, not 0 — kept as-is.
        if head[:4] in (b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"):
            return True
    except Exception as e:
        logger.warning(f"Error checking base64 audio: {e}")
        return False
    return False
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
    """
    Extract base64 data and format from a data URL or plain base64 string
    Returns: (base64_data, format)
    """
    if not data.startswith("data:"):
        return data, None
    match = re.match(r"data:audio/(\w+);base64,(.+)", data)
    if match is None:
        # A data URL we do not recognize: return it untouched, no format.
        return data, None
    return match.group(2), match.group(1)
def save_base64_audio(base64_data: str, output_dir: str) -> str:
    """
    Save a base64-encoded audio to disk and return the file path
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Split a data URL into payload/format; plain base64 has no format hint.
    payload, format_type = base64_data, None
    if base64_data.startswith("data:"):
        m = re.match(r"data:audio/(\w+);base64,(.+)", base64_data)
        if m:
            payload, format_type = m.group(2), m.group(1)
    try:
        audio_bytes = base64.b64decode(payload)
    except Exception as e:
        raise ValueError(f"Invalid base64 data: {e}")
    # Prefer the declared format; otherwise sniff magic bytes, defaulting to mp3.
    if format_type:
        ext = format_type
    elif audio_bytes.startswith((b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2")):
        ext = "mp3"
    elif audio_bytes.startswith(b"OggS"):
        ext = "ogg"
    elif audio_bytes.startswith(b"RIFF") and b"WAVE" in audio_bytes[:12]:
        ext = "wav"
    elif audio_bytes.startswith(b"fLaC"):
        ext = "flac"
    elif audio_bytes[:4] in (b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"):
        ext = "m4a"
    else:
        ext = "mp3"
    file_path = os.path.join(output_dir, f"{uuid.uuid4()}.{ext}")
    with open(file_path, "wb") as f:
        f.write(audio_bytes)
    return file_path
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple
from loguru import logger
def is_base64_image(data: str) -> bool:
    """Check if a string is a base64-encoded image"""
    # Data URLs declare their type explicitly.
    if data.startswith("data:image/"):
        return True
    try:
        # Validate the full payload when its length is base64-plausible.
        if len(data) % 4 == 0:
            base64.b64decode(data, validate=True)
        # Sniff magic bytes from the decoded head of the payload.
        head = base64.b64decode(data[:100])
        if head.startswith((b"\x89PNG\r\n\x1a\n", b"\xff\xd8\xff", b"GIF87a", b"GIF89a")):
            return True
        if head[8:12] == b"WEBP":  # RIFF....WEBP layout
            return True
    except Exception as e:
        logger.warning(f"Error checking base64 image: {e}")
        return False
    return False
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
    """
    Extract base64 data and format from a data URL or plain base64 string
    Returns: (base64_data, format)
    """
    if not data.startswith("data:"):
        return data, None
    match = re.match(r"data:image/(\w+);base64,(.+)", data)
    if match is None:
        # A data URL we do not recognize: return it untouched, no format.
        return data, None
    return match.group(2), match.group(1)
def save_base64_image(base64_data: str, output_dir: str) -> str:
    """
    Save a base64-encoded image to disk and return the file path
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Split a data URL into payload/format; plain base64 has no format hint.
    payload, format_type = base64_data, None
    if base64_data.startswith("data:"):
        m = re.match(r"data:image/(\w+);base64,(.+)", base64_data)
        if m:
            payload, format_type = m.group(2), m.group(1)
    try:
        image_bytes = base64.b64decode(payload)
    except Exception as e:
        raise ValueError(f"Invalid base64 data: {e}")
    # Prefer the declared format; otherwise sniff magic bytes, defaulting to png.
    if format_type:
        ext = format_type
    elif image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        ext = "png"
    elif image_bytes.startswith(b"\xff\xd8\xff"):
        ext = "jpg"
    elif image_bytes.startswith((b"GIF87a", b"GIF89a")):
        ext = "gif"
    elif len(image_bytes) > 12 and image_bytes[8:12] == b"WEBP":
        ext = "webp"
    else:
        ext = "png"
    file_path = os.path.join(output_dir, f"{uuid.uuid4()}.{ext}")
    with open(file_path, "wb") as f:
        f.write(image_bytes)
    return file_path
......@@ -7,14 +7,12 @@ from loguru import logger
from .api import ApiServer
from .config import server_config
from .service import DistributedInferenceService
from .services import DistributedInferenceService
def run_server(args):
"""Run server with torchrun support"""
inference_service = None
try:
# Get rank from environment (set by torchrun)
rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
......@@ -28,14 +26,12 @@ def run_server(args):
if not server_config.validate():
raise RuntimeError("Invalid server configuration")
# Initialize inference service
inference_service = DistributedInferenceService()
if not inference_service.start_distributed_inference(args):
raise RuntimeError("Failed to start distributed inference service")
logger.info(f"Rank {rank}: Inference service started successfully")
if rank == 0:
# Only rank 0 runs the FastAPI server
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
......@@ -47,7 +43,6 @@ def run_server(args):
logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
else:
# Non-rank-0 processes run the worker loop
logger.info(f"Rank {rank}: Starting worker loop")
import asyncio
......
from .audio import AudioHandler, is_base64_audio, save_base64_audio
from .base import MediaHandler
from .image import ImageHandler, is_base64_image, save_base64_image
__all__ = [
"MediaHandler",
"ImageHandler",
"AudioHandler",
"is_base64_image",
"save_base64_image",
"is_base64_audio",
"save_base64_audio",
]
from typing import Dict
from .base import MediaHandler
class AudioHandler(MediaHandler):
    """Singleton media handler for base64-encoded audio payloads."""

    _instance = None

    def __new__(cls):
        # Lazily create and reuse one shared instance per class.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_media_signatures(self) -> Dict[bytes, str]:
        """Map leading byte signatures to audio file extensions."""
        signatures = {
            b"ID3": "mp3",
            b"\xff\xfb": "mp3",
            b"\xff\xf3": "mp3",
            b"\xff\xf2": "mp3",
            b"OggS": "ogg",
            b"fLaC": "flac",
        }
        return signatures

    def get_data_url_prefix(self) -> str:
        """Data URL prefix that identifies audio payloads."""
        return "data:audio/"

    def get_data_url_pattern(self) -> str:
        """Regex that extracts format and payload from an audio data URL."""
        return r"data:audio/(\w+);base64,(.+)"

    def get_default_extension(self) -> str:
        """Fallback extension when no signature matches."""
        return "mp3"

    def is_base64(self, data: str) -> bool:
        """True when *data* looks like base64 audio (data URL or sniffed bytes).

        Extends the base check with WAV (RIFF/WAVE) and MP4-family markers.
        """
        if data.startswith(self.get_data_url_prefix()):
            return True
        try:
            import base64

            if len(data) % 4 == 0:
                base64.b64decode(data, validate=True)
            head = base64.b64decode(data[:100])
            if head.startswith(tuple(self.get_media_signatures())):
                return True
            if head.startswith(b"RIFF") and b"WAVE" in head[:12]:
                return True
            if head[:4] in (b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"):
                return True
        except Exception:
            return False
        return False

    def detect_extension(self, data: bytes) -> str:
        """Infer the audio extension from leading bytes; default when unknown."""
        for signature, ext in self.get_media_signatures().items():
            if data.startswith(signature):
                return ext
        if data.startswith(b"RIFF") and b"WAVE" in data[:12]:
            return "wav"
        if data[:4] in (b"ftyp", b"\x00\x00\x00\x20", b"\x00\x00\x00\x18"):
            return "m4a"
        return self.get_default_extension()
# Module-level singleton; AudioHandler.__new__ always yields the same instance.
_handler = AudioHandler()
def is_base64_audio(data: str) -> bool:
    """Return True when *data* looks like base64-encoded audio (delegates to AudioHandler)."""
    return _handler.is_base64(data)
def save_base64_audio(base64_data: str, output_dir: str) -> str:
    """Decode base64 audio, save it under *output_dir*, and return the file path."""
    return _handler.save_base64(base64_data, output_dir)
import base64
import os
import re
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional, Tuple
from loguru import logger
class MediaHandler(ABC):
    """Template for base64 media handling.

    Subclasses supply byte signatures, a data-URL prefix/pattern, and a
    default extension; this base provides detection, extraction, and saving.
    """

    @abstractmethod
    def get_media_signatures(self) -> Dict[bytes, str]:
        """Return the binary signatures of this media type and their corresponding file extensions."""
        pass

    @abstractmethod
    def get_data_url_prefix(self) -> str:
        """Return the data URL prefix, e.g. 'data:image/' or 'data:audio/'."""
        pass

    @abstractmethod
    def get_data_url_pattern(self) -> str:
        """Return the regex pattern for data URL."""
        pass

    @abstractmethod
    def get_default_extension(self) -> str:
        """Return the default extension for this media type."""
        pass

    def is_base64(self, data: str) -> bool:
        """Heuristically decide whether *data* is base64 for this media type."""
        if data.startswith(self.get_data_url_prefix()):
            return True
        try:
            # Validate the full payload when its length is base64-plausible,
            # then sniff magic bytes from the decoded head.
            if len(data) % 4 == 0:
                base64.b64decode(data, validate=True)
            head = base64.b64decode(data[:100])
            return head.startswith(tuple(self.get_media_signatures()))
        except Exception as e:
            logger.warning(f"Error checking base64 {self.__class__.__name__}: {e}")
            return False

    def extract_base64_data(self, data: str) -> Tuple[str, Optional[str]]:
        """Split a data URL into (payload, format); pass plain base64 through."""
        if data.startswith("data:"):
            m = re.match(self.get_data_url_pattern(), data)
            if m:
                return m.group(2), m.group(1)
        return data, None

    def detect_extension(self, data: bytes) -> str:
        """Pick the extension whose signature prefixes *data*, else the default."""
        matches = (ext for sig, ext in self.get_media_signatures().items() if data.startswith(sig))
        return next(matches, self.get_default_extension())

    def save_base64(self, base64_data: str, output_dir: str) -> str:
        """Decode *base64_data* and write it under *output_dir* with a UUID name.

        Raises:
            ValueError: when the payload is not valid base64.
        """
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        payload, format_type = self.extract_base64_data(base64_data)
        try:
            media_bytes = base64.b64decode(payload)
        except Exception as e:
            raise ValueError(f"Invalid base64 data: {e}")
        ext = format_type if format_type else self.detect_extension(media_bytes)
        file_path = os.path.join(output_dir, f"{uuid.uuid4()}.{ext}")
        with open(file_path, "wb") as f:
            f.write(media_bytes)
        return file_path
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment