Commit bab78b8e authored by gaclove

refactor: server api

- Introduced `image_utils.py` with functions to check, extract, and save base64-encoded images.
- Updated `TaskRequest` schema to accept base64 images or URLs in `image_path`.
- Modified `VideoGenerationService` to handle base64 images, saving them appropriately.
- Updated scripts to convert local image paths to base64 before sending requests.
parent fab43a07
#!/usr/bin/env python
import argparse
import atexit
import signal
import sys
from pathlib import Path
import uvicorn
from loguru import logger
sys.path.insert(0, str(Path(__file__).parent.parent))
from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService
def create_signal_handler(inference_service: DistributedInferenceService):
"""Create unified signal handler function"""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, gracefully shutting down...")
try:
if inference_service.is_running:
inference_service.stop_distributed_inference()
except Exception as e:
logger.error(f"Error occurred while shutting down distributed inference service: {str(e)}")
finally:
sys.exit(0)
return signal_handler
from lightx2v.server.main import run_server
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cls",
type=str,
required=True,
choices=[
"wan2.1",
"hunyuan",
"wan2.1_distill",
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"wan2.1_audio",
"wan2.2_moe",
"wan2.2_moe_distill",
],
default="wan2.1",
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser = argparse.ArgumentParser(description="Run LightX2V inference server")
parser.add_argument("--split", action="store_true")
parser.add_argument("--lora_path", type=str, required=False, default=None)
parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
args = parser.parse_args()
logger.info(f"args: {args}")
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
cache_dir = Path(__file__).parent.parent / "server_cache"
inference_service = DistributedInferenceService()
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
api_server = ApiServer()
api_server.initialize_services(cache_dir, inference_service)
signal_handler = create_signal_handler(inference_service)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info("Starting distributed inference service...")
success = inference_service.start_distributed_inference(args)
if not success:
logger.error("Failed to start distributed inference service, exiting program")
sys.exit(1)
atexit.register(inference_service.stop_distributed_inference)
args = parser.parse_args()
try:
logger.info(f"Starting FastAPI server on port: {args.port}")
uvicorn.run(
api_server.get_app(),
host="0.0.0.0",
port=args.port,
reload=False,
workers=1,
)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down service...")
except Exception as e:
logger.error(f"Error occurred while running FastAPI server: {str(e)}")
finally:
inference_service.stop_distributed_inference()
run_server(args)
if __name__ == "__main__":
......
# LightX2V Server
## Overview
The LightX2V server is a distributed video generation service built with FastAPI that processes image-to-video tasks using a multi-process architecture with GPU support. It combines a thread-safe task queue with multi-process distributed inference for high-throughput video generation workloads.
## Architecture
### System Architecture
```mermaid
graph TB
subgraph "Client Layer"
Client[HTTP Client]
end
subgraph "API Layer"
FastAPI[FastAPI Application]
ApiServer[ApiServer]
Router1[Tasks Router<br/>/v1/tasks]
Router2[Files Router<br/>/v1/files]
Router3[Service Router<br/>/v1/service]
end
subgraph "Service Layer"
TaskManager[TaskManager<br/>Thread-safe Task Queue]
FileService[FileService<br/>File I/O & Downloads]
VideoService[VideoGenerationService]
end
subgraph "Processing Layer"
Thread[Processing Thread<br/>Sequential Task Loop]
end
subgraph "Distributed Inference Layer"
DistService[DistributedInferenceService]
SharedData[(Shared Data<br/>mp.Manager.dict)]
TaskEvent[Task Event<br/>mp.Manager.Event]
ResultEvent[Result Event<br/>mp.Manager.Event]
subgraph "Worker Processes"
W0[Worker 0<br/>Master/Rank 0]
W1[Worker 1<br/>Rank 1]
WN[Worker N<br/>Rank N]
end
end
subgraph "Resource Management"
GPUManager[GPUManager<br/>GPU Detection & Allocation]
DistManager[DistributedManager<br/>PyTorch Distributed]
Config[ServerConfig<br/>Configuration]
end
Client -->|HTTP Request| FastAPI
FastAPI --> ApiServer
ApiServer --> Router1
ApiServer --> Router2
ApiServer --> Router3
Router1 -->|Create/Manage Tasks| TaskManager
Router1 -->|Process Tasks| Thread
Router2 -->|File Operations| FileService
Router3 -->|Service Status| TaskManager
Thread -->|Get Pending Tasks| TaskManager
Thread -->|Generate Video| VideoService
VideoService -->|Download Images| FileService
VideoService -->|Submit Task| DistService
DistService -->|Update| SharedData
DistService -->|Signal| TaskEvent
TaskEvent -->|Notify| W0
W0 -->|Broadcast| W1
W0 -->|Broadcast| WN
W0 -->|Update Result| SharedData
W0 -->|Signal| ResultEvent
ResultEvent -->|Notify| DistService
W0 -.->|Uses| GPUManager
W1 -.->|Uses| GPUManager
WN -.->|Uses| GPUManager
W0 -.->|Setup| DistManager
W1 -.->|Setup| DistManager
WN -.->|Setup| DistManager
DistService -.->|Reads| Config
ApiServer -.->|Reads| Config
```
### Components
#### Core Components
| Component | File | Description |
|-----------|------|-------------|
| **ServerManager** | `main.py` | Orchestrates server lifecycle, startup/shutdown sequences |
| **ApiServer** | `api.py` | FastAPI application manager with route registration |
| **TaskManager** | `task_manager.py` | Thread-safe task queue and lifecycle management |
| **FileService** | `service.py` | File I/O, HTTP downloads with retry logic |
| **VideoGenerationService** | `service.py` | Video generation workflow orchestration |
| **DistributedInferenceService** | `service.py` | Multi-process inference management |
| **GPUManager** | `gpu_manager.py` | GPU detection, allocation, and memory management |
| **DistributedManager** | `distributed_utils.py` | PyTorch distributed communication setup |
| **ServerConfig** | `config.py` | Centralized configuration with environment variable support |
## Task Processing Flow
```mermaid
sequenceDiagram
participant C as Client
participant API as API Server
participant TM as TaskManager
participant PT as Processing Thread
participant VS as VideoService
participant FS as FileService
participant DIS as Distributed<br/>Inference Service
participant W0 as Worker 0<br/>(Master)
participant W1 as Worker 1..N
C->>API: POST /v1/tasks<br/>(Create Task)
API->>TM: create_task()
TM->>TM: Generate task_id
TM->>TM: Add to queue<br/>(status: PENDING)
API->>PT: ensure_processing_thread()
API-->>C: TaskResponse<br/>(task_id, status: pending)
Note over PT: Processing Loop
PT->>TM: get_next_pending_task()
TM-->>PT: task_id
PT->>TM: acquire_processing_lock()
PT->>TM: start_task()<br/>(status: PROCESSING)
PT->>VS: generate_video_with_stop_event()
alt Image is URL
VS->>FS: download_image()
FS->>FS: HTTP download<br/>with retry
FS-->>VS: image_path
else Image is Base64
VS->>FS: save_base64_image()
FS-->>VS: image_path
else Image is Upload
VS->>FS: validate_file()
FS-->>VS: image_path
end
VS->>DIS: submit_task(task_data)
DIS->>DIS: shared_data["current_task"] = task_data
DIS->>DIS: task_event.set()
Note over W0,W1: Distributed Processing
W0->>W0: task_event.wait()
W0->>W0: Get task from shared_data
W0->>W1: broadcast_task_data()
par Parallel Inference
W0->>W0: run_pipeline()
and
W1->>W1: run_pipeline()
end
W0->>W0: barrier() for sync
W0->>W0: shared_data["result"] = result
W0->>DIS: result_event.set()
DIS->>DIS: result_event.wait()
DIS->>VS: return result
VS-->>PT: TaskResponse
PT->>TM: complete_task()<br/>(status: COMPLETED)
PT->>TM: release_processing_lock()
Note over C: Client Polling
C->>API: GET /v1/tasks/{task_id}/status
API->>TM: get_task_status()
TM-->>API: status info
API-->>C: Task Status
C->>API: GET /v1/tasks/{task_id}/result
API->>TM: get_task_status()
API->>FS: stream_file_response()
FS-->>API: Video Stream
API-->>C: Video File
```
## Task States
```mermaid
stateDiagram-v2
[*] --> PENDING: create_task()
PENDING --> PROCESSING: start_task()
PROCESSING --> COMPLETED: complete_task()
PROCESSING --> FAILED: fail_task()
PENDING --> CANCELLED: cancel_task()
PROCESSING --> CANCELLED: cancel_task()
COMPLETED --> [*]
FAILED --> [*]
CANCELLED --> [*]
```
## API Endpoints
See the interactive API documentation at `{base_url}/docs` for the full request and response schemas.
### Task Management
- POST `/v1/tasks/`
- POST `/v1/tasks/form`
- GET `/v1/tasks/{task_id}/status`
- GET `/v1/tasks/{task_id}/result`
- DELETE `/v1/tasks/{task_id}`
- DELETE `/v1/tasks/all/running`
- GET `/v1/tasks/`
- GET `/v1/tasks/queue/status`
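For example, tasks can be created and inspected with plain `curl` (host, port, prompt, and image URL below are illustrative; the JSON fields follow the `TaskRequest` schema):

```bash
# Create an i2v task (the response JSON contains task_id)
curl -X POST http://localhost:8000/v1/tasks/ \
  -H "Content-Type: application/json" \
  -d '{"prompt": "A cat playing piano", "image_path": "https://example.com/image.jpg"}'

# Poll its status, then download the result once the status is "completed"
curl http://localhost:8000/v1/tasks/<task_id>/status
curl -o output.mp4 http://localhost:8000/v1/tasks/<task_id>/result
```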
## Configuration
### Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
| `LIGHTX2V_HOST` | Server host address | `0.0.0.0` |
| `LIGHTX2V_PORT` | Server port | `8000` |
| `LIGHTX2V_MAX_QUEUE_SIZE` | Maximum task queue size | `10` |
| `LIGHTX2V_CACHE_DIR` | File cache directory | `<package dir>/server_cache` |
| `LIGHTX2V_TASK_TIMEOUT` | Task processing timeout (seconds) | `300` |
| `LIGHTX2V_HTTP_TIMEOUT` | HTTP download timeout (seconds) | `30` |
| `LIGHTX2V_HTTP_MAX_RETRIES` | HTTP download max retries | `3` |
| `LIGHTX2V_MAX_UPLOAD_SIZE` | Maximum upload file size (bytes) | `500MB` |
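For example, a deployment might export a few of these before launching the server (the values below are illustrative):

```bash
export LIGHTX2V_HOST=0.0.0.0
export LIGHTX2V_PORT=8080
export LIGHTX2V_MAX_QUEUE_SIZE=20
export LIGHTX2V_CACHE_DIR=/data/lightx2v_cache
```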
### Command Line Arguments
```bash
python -m lightx2v.server.main \
--model_path /path/to/model \
--model_cls wan2.1_distill \
--task i2v \
--host 0.0.0.0 \
--port 8000 \
--config_json /path/to/xxx_config.json
```
```bash
python -m lightx2v.server.main \
--model_path /path/to/model \
--model_cls wan2.1_distill \
--task i2v \
--host 0.0.0.0 \
--port 8000 \
--config_json /path/to/xxx_dist_config.json \
--nproc_per_node 2
```
## Key Features
### 1. Distributed Processing
- **Multi-process architecture** for GPU parallelization
- **Master-worker pattern** with rank 0 as coordinator
- **PyTorch distributed** backend (NCCL for GPU, Gloo for CPU)
- **Automatic GPU allocation** across processes
- **Task broadcasting** with chunked pickle serialization (see the sketch below)
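Below is a simplified sketch of the chunked broadcast pattern, modelled on `DistributedManager.broadcast_task_data` but reduced to its essentials; it assumes the process group is already initialized and uses plain CPU tensors (as with the Gloo backend):

```python
import pickle

import torch
import torch.distributed as dist

CHUNK_SIZE = 1024 * 1024  # broadcast payloads in 1 MB chunks


def broadcast_object(obj=None, src=0):
    """Rank `src` pickles `obj` and broadcasts it chunk by chunk; other ranks return the unpickled copy."""
    if dist.get_rank() == src:
        payload = pickle.dumps(obj)
        length = torch.tensor([len(payload)], dtype=torch.int64)
        dist.broadcast(length, src=src)
        for start in range(0, len(payload), CHUNK_SIZE):
            chunk = bytearray(payload[start:start + CHUNK_SIZE])
            dist.broadcast(torch.frombuffer(chunk, dtype=torch.uint8), src=src)
        return obj
    # Receiving ranks first learn the payload size, then collect chunks until complete.
    length = torch.tensor([0], dtype=torch.int64)
    dist.broadcast(length, src=src)
    total = int(length.item())
    received = bytearray()
    while len(received) < total:
        chunk = torch.empty(min(CHUNK_SIZE, total - len(received)), dtype=torch.uint8)
        dist.broadcast(chunk, src=src)
        received.extend(chunk.numpy().tobytes())
    return pickle.loads(bytes(received))
```

Bounding each broadcast to `CHUNK_SIZE` avoids allocating one very large tensor when task payloads embed base64 images.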
### 2. Task Queue Management
- **Thread-safe** task queue with locks
- **Sequential processing** with single processing thread
- **Configurable queue limits** with overflow protection
- **Task prioritization** (FIFO)
- **Automatic cleanup** of old completed tasks
- **Cancellation support** for pending and running tasks
### 3. File Management
- **Multiple input formats**: URL, base64, file upload
- **HTTP downloads** with exponential backoff retry
- **Streaming responses** for large video files
- **Cache management** with automatic cleanup
- **File validation** and format detection
### 4. Resilient Architecture
- **Graceful shutdown** with signal handling
- **Process failure recovery** mechanisms
- **Connection pooling** for HTTP clients
- **Timeout protection** at multiple levels
- **Comprehensive error handling** throughout
### 5. Resource Management
- **GPU memory management** with cache clearing
- **Process lifecycle management**
- **Connection pooling** for efficiency
- **Memory-efficient** streaming for large files
- **Automatic resource cleanup** on shutdown
## Performance Considerations
1. **Single Task Processing**: Tasks are processed sequentially to manage GPU memory effectively
2. **Multi-GPU Support**: Distributes inference across available GPUs for parallelization
3. **Connection Pooling**: Reuses HTTP connections to reduce overhead
4. **Streaming Responses**: Large files are streamed to avoid memory issues
5. **Queue Management**: Automatic task cleanup prevents memory leaks
6. **Process Isolation**: Distributed workers run in separate processes for stability
## Usage Examples
### Client Usage
```python
import httpx
import base64
# Create a task with URL image
response = httpx.post(
"http://localhost:8000/v1/tasks/",
json={
"prompt": "A cat playing piano",
"image_path": "https://example.com/image.jpg",
"use_prompt_enhancer": True,
"seed": 42
}
)
task_id = response.json()["task_id"]
# Create a task with base64 image
with open("image.png", "rb") as f:
image_base64 = base64.b64encode(f.read()).decode()
response = httpx.post(
"http://localhost:8000/v1/tasks/",
json={
"prompt": "A dog dancing",
"image_path": f"data:image/png;base64,{image_base64}"
}
)
# Check task status
status = httpx.get(f"http://localhost:8000/v1/tasks/{task_id}/status")
print(status.json())
# Download result when completed
if status.json()["status"] == "completed":
video = httpx.get(f"http://localhost:8000/v1/tasks/{task_id}/result")
with open("output.mp4", "wb") as f:
f.write(video.content)
```
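The status check above is a one-shot request; in practice a client usually polls until the task reaches a terminal state. A minimal sketch (the poll interval and timeout are arbitrary choices, not server requirements):

```python
import time

import httpx


def wait_for_task(base_url: str, task_id: str, poll_interval: float = 0.5, timeout: float = 600.0) -> dict:
    """Poll the status endpoint until the task completes, fails, or is cancelled."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = httpx.get(f"{base_url}/v1/tasks/{task_id}/status").json()
        if status["status"] in ("completed", "failed", "cancelled"):
            return status
        time.sleep(poll_interval)
    raise TimeoutError(f"Task {task_id} did not finish within {timeout} seconds")


# Usage (base URL is illustrative):
# final = wait_for_task("http://localhost:8000", task_id)
```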
## Monitoring and Debugging
### Logging
The server uses `loguru` for structured logging. Logs include:
- Request/response details
- Task lifecycle events
- Worker process status
- Error traces with context
### Health Checks
- `/v1/service/status` - Overall service health
- `/v1/tasks/queue/status` - Queue status and processing state
- Process monitoring via system tools (htop, nvidia-smi)
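For example (host and port are illustrative):

```bash
curl http://localhost:8000/v1/service/status      # overall service health
curl http://localhost:8000/v1/tasks/queue/status  # queue depth and processing state
```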
### Common Issues
1. **GPU Out of Memory**: Reduce `nproc_per_node` or adjust model batch size
2. **Task Timeout**: Increase `LIGHTX2V_TASK_TIMEOUT` for longer videos
3. **Queue Full**: Increase `LIGHTX2V_MAX_QUEUE_SIZE` or add rate limiting
4. **Port Conflicts**: Change `LIGHTX2V_PORT` or `MASTER_PORT` range
## Security Considerations
1. **Input Validation**: All inputs validated with Pydantic schemas
2. **File Access**: Restricted to cache directory
3. **Resource Limits**: Configurable queue and file size limits
4. **Process Isolation**: Worker processes run with limited permissions
5. **HTTP Security**: Support for proxy and authentication headers
## License
See the main project LICENSE file for licensing information.
import asyncio
import gc
import threading
import time
import uuid
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from urllib.parse import urlparse
import httpx
import torch
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from .schema import (
ServiceStatusResponse,
StopTaskResponse,
TaskRequest,
TaskResponse,
)
from .service import DistributedInferenceService, FileService, VideoGenerationService
from .utils import ServiceStatus
from .task_manager import TaskStatus, task_manager
class ApiServer:
def __init__(self):
self.app = FastAPI(title="LightX2V API", version="1.0.0")
def __init__(self, max_queue_size: int = 10, app: Optional[FastAPI] = None):
self.app = app or FastAPI(title="LightX2V API", version="1.0.0")
self.file_service = None
self.inference_service = None
self.video_service = None
self.thread = None
self.stop_generation_event = threading.Event()
self.max_queue_size = max_queue_size
self.processing_thread = None
self.stop_processing = threading.Event()
# Create routers
self.tasks_router = APIRouter(prefix="/v1/tasks", tags=["tasks"])
self.files_router = APIRouter(prefix="/v1/files", tags=["files"])
self.service_router = APIRouter(prefix="/v1/service", tags=["service"])
......@@ -37,7 +40,6 @@ class ApiServer:
self._setup_routes()
def _setup_routes(self):
"""Setup routes"""
self._setup_task_routes()
self._setup_file_routes()
self._setup_service_routes()
......@@ -48,18 +50,15 @@ class ApiServer:
self.app.include_router(self.service_router)
def _write_file_sync(self, file_path: Path, content: bytes) -> None:
"""同步写入文件到指定路径"""
with open(file_path, "wb") as buffer:
buffer.write(content)
def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
"""Common file streaming response method"""
assert self.file_service is not None, "File service is not initialized"
try:
resolved_path = file_path.resolve()
# Security check: ensure file is within allowed directory
if not str(resolved_path).startswith(str(self.file_service.output_video_dir.resolve())):
raise HTTPException(status_code=403, detail="Access to this file is not allowed")
......@@ -103,24 +102,25 @@ class ApiServer:
async def create_task(message: TaskRequest):
"""Create video generation task"""
try:
task_id = ServiceStatus.start_task(message)
# Use background thread to handle long-running tasks
self.stop_generation_event.clear()
self.thread = threading.Thread(
target=self._process_video_generation,
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
if hasattr(message, "image_path") and message.image_path and message.image_path.startswith("http"):
if not await self._validate_image_url(message.image_path):
raise HTTPException(status_code=400, detail=f"Image URL is not accessible: {message.image_path}")
task_id = task_manager.create_task(message)
message.task_id = task_id
self._ensure_processing_thread_running()
return TaskResponse(
task_id=task_id,
task_status="processing",
task_status="pending",
save_video_path=message.save_video_path,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create task: {e}")
raise HTTPException(status_code=500, detail=str(e))
@self.tasks_router.post("/form", response_model=TaskResponse)
async def create_task_form(
......@@ -136,11 +136,9 @@ class ApiServer:
audio_file: Optional[UploadFile] = File(default=None),
video_duration: int = Form(default=5),
):
"""Create video generation task via form"""
assert self.file_service is not None, "File service is not initialized"
async def save_file_async(file: UploadFile, target_dir: Path) -> str:
"""异步保存文件到指定目录"""
if not file or not file.filename:
return ""
......@@ -177,44 +175,58 @@ class ApiServer:
)
try:
task_id = ServiceStatus.start_task(message)
self.stop_generation_event.clear()
self.thread = threading.Thread(
target=self._process_video_generation,
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
task_id = task_manager.create_task(message)
message.task_id = task_id
self._ensure_processing_thread_running()
return TaskResponse(
task_id=task_id,
task_status="processing",
task_status="pending",
save_video_path=message.save_video_path,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(status_code=503, detail=str(e))
except Exception as e:
logger.error(f"Failed to create form task: {e}")
raise HTTPException(status_code=500, detail=str(e))
@self.tasks_router.get("/", response_model=dict)
async def list_tasks():
"""Get all task list"""
return ServiceStatus.get_all_tasks()
return task_manager.get_all_tasks()
@self.tasks_router.get("/queue/status", response_model=dict)
async def get_queue_status():
service_status = task_manager.get_service_status()
return {
"is_processing": task_manager.is_processing(),
"current_task": service_status.get("current_task"),
"pending_count": task_manager.get_pending_task_count(),
"active_count": task_manager.get_active_task_count(),
"queue_size": self.max_queue_size,
"queue_available": self.max_queue_size - task_manager.get_active_task_count(),
}
@self.tasks_router.get("/{task_id}/status")
async def get_task_status(task_id: str):
"""Get status of specified task"""
return ServiceStatus.get_status_task_id(task_id)
status = task_manager.get_task_status(task_id)
if not status:
raise HTTPException(status_code=404, detail="Task not found")
return status
@self.tasks_router.get("/{task_id}/result")
async def get_task_result(task_id: str):
"""Get result video file of specified task"""
assert self.video_service is not None, "Video service is not initialized"
assert self.file_service is not None, "File service is not initialized"
try:
task_status = ServiceStatus.get_status_task_id(task_id)
task_status = task_manager.get_task_status(task_id)
if not task_status or task_status.get("status") != "completed":
raise HTTPException(status_code=404, detail="Task not completed or does not exist")
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
if task_status.get("status") != TaskStatus.COMPLETED.value:
raise HTTPException(status_code=404, detail="Task not completed")
save_video_path = task_status.get("save_video_path")
if not save_video_path:
......@@ -232,38 +244,37 @@ class ApiServer:
logger.error(f"Error occurred while getting task result: {e}")
raise HTTPException(status_code=500, detail="Failed to get task result")
@self.tasks_router.delete("/running", response_model=StopTaskResponse)
async def stop_running_task():
"""Stop currently running task"""
if self.thread and self.thread.is_alive():
try:
logger.info("Sending stop signal to running task thread...")
self.stop_generation_event.set()
self.thread.join(timeout=5)
if self.thread.is_alive():
logger.warning("Task thread did not stop within the specified time, manual intervention may be required.")
return StopTaskResponse(
stop_status="warning",
reason="Task thread did not stop within the specified time, manual intervention may be required.",
)
else:
self.thread = None
ServiceStatus.clean_stopped_task()
gc.collect()
@self.tasks_router.delete("/{task_id}", response_model=StopTaskResponse)
async def stop_task(task_id: str):
try:
if task_manager.cancel_task(task_id):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Task stopped successfully.")
return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
except Exception as e:
logger.error(f"Error occurred while stopping task: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
else:
return StopTaskResponse(stop_status="do_nothing", reason="No running task found.")
logger.info(f"Task {task_id} stopped successfully.")
return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
else:
return StopTaskResponse(stop_status="do_nothing", reason="Task not found or already completed.")
except Exception as e:
logger.error(f"Error occurred while stopping task {task_id}: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
@self.tasks_router.delete("/all/running", response_model=StopTaskResponse)
async def stop_all_running_tasks():
try:
task_manager.cancel_all_tasks()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("All tasks stopped successfully.")
return StopTaskResponse(stop_status="success", reason="All tasks stopped successfully.")
except Exception as e:
logger.error(f"Error occurred while stopping all tasks: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
def _setup_file_routes(self):
@self.files_router.get("/download/{file_path:path}")
async def download_file(file_path: str):
"""Download file"""
assert self.file_service is not None, "File service is not initialized"
try:
......@@ -276,36 +287,111 @@ class ApiServer:
raise HTTPException(status_code=500, detail="File download failed")
def _setup_service_routes(self):
@self.service_router.get("/status", response_model=ServiceStatusResponse)
@self.service_router.get("/status", response_model=dict)
async def get_service_status():
"""Get service status"""
return ServiceStatus.get_status_service()
return task_manager.get_service_status()
@self.service_router.get("/metadata", response_model=dict)
async def get_service_metadata():
"""Get service metadata"""
assert self.inference_service is not None, "Inference service is not initialized"
return self.inference_service.server_metadata()
def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event):
async def _validate_image_url(self, image_url: str) -> bool:
if not image_url or not image_url.startswith("http"):
return True
try:
parsed_url = urlparse(image_url)
if not parsed_url.scheme or not parsed_url.netloc:
return False
timeout = httpx.Timeout(connect=5.0, read=5.0)
async with httpx.AsyncClient(verify=False, timeout=timeout) as client:
response = await client.head(image_url, follow_redirects=True)
return response.status_code < 400
except Exception as e:
logger.warning(f"URL validation failed for {image_url}: {str(e)}")
return False
def _ensure_processing_thread_running(self):
"""Ensure the processing thread is running."""
if self.processing_thread is None or not self.processing_thread.is_alive():
self.stop_processing.clear()
self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
self.processing_thread.start()
logger.info("Started task processing thread")
def _task_processing_loop(self):
"""Main loop that processes tasks from the queue one by one."""
logger.info("Task processing loop started")
while not self.stop_processing.is_set():
task_id = task_manager.get_next_pending_task()
if task_id is None:
time.sleep(1)
continue
task_info = task_manager.get_task(task_id)
if task_info and task_info.status == TaskStatus.PENDING:
logger.info(f"Processing task {task_id}")
self._process_single_task(task_info)
logger.info("Task processing loop stopped")
def _process_single_task(self, task_info: Any):
"""Process a single task."""
assert self.video_service is not None, "Video service is not initialized"
task_id = task_info.task_id
message = task_info.message
lock_acquired = task_manager.acquire_processing_lock(task_id, timeout=1)
if not lock_acquired:
logger.error(f"Task {task_id} failed to acquire processing lock")
task_manager.fail_task(task_id, "Failed to acquire processing lock")
return
try:
if stop_event.is_set():
logger.info(f"Task {message.task_id} received stop signal, terminating")
ServiceStatus.record_failed_task(message, error="Task stopped")
task_manager.start_task(task_id)
if task_info.stop_event.is_set():
logger.info(f"Task {task_id} cancelled before processing")
task_manager.fail_task(task_id, "Task cancelled")
return
# Use video generation service to process task
result = asyncio.run(self.video_service.generate_video(message))
result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
if result:
task_manager.complete_task(task_id, result.save_video_path)
logger.info(f"Task {task_id} completed successfully")
else:
if task_info.stop_event.is_set():
task_manager.fail_task(task_id, "Task cancelled during processing")
logger.info(f"Task {task_id} cancelled during processing")
else:
task_manager.fail_task(task_id, "Generation failed")
logger.error(f"Task {task_id} generation failed")
except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
logger.error(f"Task {task_id} processing failed: {str(e)}")
task_manager.fail_task(task_id, str(e))
finally:
if lock_acquired:
task_manager.release_processing_lock(task_id)
def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService):
self.file_service = FileService(cache_dir)
self.inference_service = inference_service
self.video_service = VideoGenerationService(self.file_service, inference_service)
async def cleanup(self):
self.stop_processing.set()
if self.processing_thread and self.processing_thread.is_alive():
self.processing_thread.join(timeout=5)
if self.file_service:
await self.file_service.cleanup()
def get_app(self) -> FastAPI:
return self.app
import os
from dataclasses import dataclass
from pathlib import Path
from loguru import logger
@dataclass
class ServerConfig:
host: str = "0.0.0.0"
port: int = 8000
max_queue_size: int = 10
master_addr: str = "127.0.0.1"
master_port_range: tuple = (29500, 29600)
task_timeout: int = 300
task_history_limit: int = 1000
http_timeout: int = 30
http_max_retries: int = 3
cache_dir: str = str(Path(__file__).parent.parent / "server_cache")
max_upload_size: int = 500 * 1024 * 1024 # 500MB
@classmethod
def from_env(cls) -> "ServerConfig":
config = cls()
if env_host := os.environ.get("LIGHTX2V_HOST"):
config.host = env_host
if env_port := os.environ.get("LIGHTX2V_PORT"):
try:
config.port = int(env_port)
except ValueError:
logger.warning(f"Invalid port in environment: {env_port}")
if env_queue_size := os.environ.get("LIGHTX2V_MAX_QUEUE_SIZE"):
try:
config.max_queue_size = int(env_queue_size)
except ValueError:
logger.warning(f"Invalid max queue size: {env_queue_size}")
if env_master_addr := os.environ.get("MASTER_ADDR"):
config.master_addr = env_master_addr
if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
config.cache_dir = env_cache_dir
return config
def find_free_master_port(self) -> str:
import socket
for port in range(self.master_port_range[0], self.master_port_range[1]):
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((self.master_addr, port))
logger.info(f"Found free port for master: {port}")
return str(port)
except OSError:
continue
import random
return str(random.randint(20000, 29999))
def validate(self) -> bool:
valid = True
if self.max_queue_size <= 0:
logger.error("max_queue_size must be positive")
valid = False
if self.task_timeout <= 0:
logger.error("task_timeout must be positive")
valid = False
return valid
server_config = ServerConfig.from_env()
import os
import pickle
from typing import Any, Optional
import torch
import torch.distributed as dist
from loguru import logger
from .gpu_manager import gpu_manager
class DistributedManager:
def __init__(self):
self.is_initialized = False
self.rank = 0
self.world_size = 1
self.device = "cpu"
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
try:
......@@ -18,10 +23,12 @@ class DistributedManager:
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
dist.init_process_group(backend="nccl", init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
logger.info(f"Setup backend: {backend}")
if torch.cuda.is_available(): # type: ignore
torch.cuda.set_device(rank)
self.device = gpu_manager.set_device_for_rank(rank, world_size)
self.is_initialized = True
self.rank = rank
......@@ -46,55 +53,86 @@ class DistributedManager:
def barrier(self):
if self.is_initialized:
dist.barrier()
if torch.cuda.is_available() and dist.get_backend() == "nccl":
dist.barrier(device_ids=[torch.cuda.current_device()])
else:
dist.barrier()
def is_rank_zero(self) -> bool:
return self.rank == 0
def broadcast_task_data(self, task_data=None): # type: ignore
def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
if not self.is_initialized:
return None
try:
backend = dist.get_backend() if dist.is_initialized() else "gloo"
except Exception:
backend = "gloo"
if backend == "gloo":
broadcast_device = torch.device("cpu")
else:
broadcast_device = torch.device(self.device if self.device != "cpu" else "cpu")
if self.is_rank_zero():
if task_data is None:
stop_signal = torch.tensor([1], dtype=torch.int32, device=f"cuda:{self.rank}")
stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
else:
stop_signal = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(stop_signal, src=0)
if task_data is not None:
import pickle
task_bytes = pickle.dumps(task_data)
task_length = torch.tensor([len(task_bytes)], dtype=torch.int32, device=f"cuda:{self.rank}")
task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
dist.broadcast(task_length, src=0)
task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8, device=f"cuda:{self.rank}")
dist.broadcast(task_tensor, src=0)
chunk_size = 1024 * 1024
if len(task_bytes) > chunk_size:
num_chunks = (len(task_bytes) + chunk_size - 1) // chunk_size
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, len(task_bytes))
chunk = task_bytes[start_idx:end_idx]
task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
else:
task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
return task_data
else:
return None
else:
stop_signal = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(stop_signal, src=0)
if stop_signal.item() == 1:
return None
else:
task_length = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(task_length, src=0)
task_tensor = torch.empty(int(task_length.item()), dtype=torch.uint8, device=f"cuda:{self.rank}")
dist.broadcast(task_tensor, src=0)
total_length = int(task_length.item())
chunk_size = 1024 * 1024
if total_length > chunk_size:
task_bytes = bytearray()
num_chunks = (total_length + chunk_size - 1) // chunk_size
for i in range(num_chunks):
chunk_length = min(chunk_size, total_length - len(task_bytes))
task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
task_bytes.extend(task_tensor.cpu().numpy())
task_bytes = bytes(task_bytes)
else:
task_tensor = torch.empty(total_length, dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
task_bytes = bytes(task_tensor.cpu().numpy())
import pickle
task_bytes = bytes(task_tensor.cpu().numpy())
task_data = pickle.loads(task_bytes)
return task_data
......@@ -113,7 +151,6 @@ class DistributedWorker:
self.dist_manager.cleanup()
def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
# Synchronize all processes
self.dist_manager.barrier()
if self.dist_manager.is_rank_zero():
......
import os
from typing import List, Optional, Tuple
import torch
from loguru import logger
class GPUManager:
def __init__(self):
self.available_gpus = self._detect_gpus()
self.gpu_count = len(self.available_gpus)
def _detect_gpus(self) -> List[int]:
if not torch.cuda.is_available():
logger.warning("No CUDA devices available, will use CPU")
return []
gpu_count = torch.cuda.device_count()
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if cuda_visible:
try:
visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
return list(range(len(visible_devices)))
except ValueError:
logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")
available_gpus = list(range(gpu_count))
logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
return available_gpus
def get_device_for_rank(self, rank: int, world_size: int) -> str:
if not self.available_gpus:
logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
return "cpu"
if self.gpu_count == 1:
device = f"cuda:{self.available_gpus[0]}"
logger.info(f"Rank {rank}: Using single GPU {device}")
return device
if self.gpu_count >= world_size:
gpu_id = self.available_gpus[rank % self.gpu_count]
device = f"cuda:{gpu_id}"
logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
return device
else:
gpu_id = self.available_gpus[rank % self.gpu_count]
device = f"cuda:{gpu_id}"
logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
return device
def set_device_for_rank(self, rank: int, world_size: int) -> str:
device = self.get_device_for_rank(rank, world_size)
if device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
torch.cuda.set_device(gpu_id)
logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
return device
def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
if not torch.cuda.is_available():
return (0, 0)
if device and device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
else:
gpu_id = torch.cuda.current_device()
try:
used = torch.cuda.memory_allocated(gpu_id)
total = torch.cuda.get_device_properties(gpu_id).total_memory
return (used, total)
except Exception as e:
logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
return (0, 0)
def clear_cache(self, device: Optional[str] = None):
if not torch.cuda.is_available():
return
if device and device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
with torch.cuda.device(gpu_id):
torch.cuda.empty_cache()
torch.cuda.synchronize()
else:
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info(f"GPU cache cleared for device: {device or 'current'}")
@staticmethod
def get_optimal_world_size(requested_world_size: int) -> int:
if not torch.cuda.is_available():
logger.warning("No GPUs available, using single process")
return 1
gpu_count = torch.cuda.device_count()
if requested_world_size <= 0:
optimal_size = gpu_count
logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
elif requested_world_size > gpu_count:
logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
optimal_size = requested_world_size
else:
optimal_size = requested_world_size
return optimal_size
gpu_manager = GPUManager()
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple
def is_base64_image(data: str) -> bool:
"""Check if a string is a base64-encoded image"""
if data.startswith("data:image/"):
return True
try:
if len(data) % 4 == 0:
base64.b64decode(data, validate=True)
decoded = base64.b64decode(data[:100])
if decoded.startswith(b"\x89PNG\r\n\x1a\n"):
return True
if decoded.startswith(b"\xff\xd8\xff"):
return True
if decoded.startswith(b"GIF87a") or decoded.startswith(b"GIF89a"):
return True
if decoded[8:12] == b"WEBP":
return True
except Exception as e:
print(f"Error checking base64 image: {e}")
return False
return False
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
"""
Extract base64 data and format from a data URL or plain base64 string
Returns: (base64_data, format)
"""
if data.startswith("data:"):
match = re.match(r"data:image/(\w+);base64,(.+)", data)
if match:
format_type = match.group(1)
base64_data = match.group(2)
return base64_data, format_type
return data, None
def save_base64_image(base64_data: str, output_dir: str = "/tmp/flux_kontext_uploads") -> str:
"""
Save a base64-encoded image to disk and return the file path
"""
Path(output_dir).mkdir(parents=True, exist_ok=True)
data, format_type = extract_base64_data(base64_data)
file_id = str(uuid.uuid4())
try:
image_data = base64.b64decode(data)
except Exception as e:
raise ValueError(f"Invalid base64 data: {e}")
if format_type:
ext = format_type
else:
if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
ext = "png"
elif image_data.startswith(b"\xff\xd8\xff"):
ext = "jpg"
elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
ext = "gif"
elif len(image_data) > 12 and image_data[8:12] == b"WEBP":
ext = "webp"
else:
ext = "png"
file_path = os.path.join(output_dir, f"{file_id}.{ext}")
with open(file_path, "wb") as f:
f.write(image_data)
return file_path
import asyncio
import signal
import sys
from pathlib import Path
from typing import Optional
import uvicorn
from loguru import logger
from .api import ApiServer
from .config import server_config
from .service import DistributedInferenceService
class ServerManager:
def __init__(self):
self.api_server: Optional[ApiServer] = None
self.inference_service: Optional[DistributedInferenceService] = None
self.shutdown_event = asyncio.Event()
async def startup(self, args):
logger.info("Starting LightX2V server...")
if hasattr(args, "host") and args.host:
server_config.host = args.host
if hasattr(args, "port") and args.port:
server_config.port = args.port
if not server_config.validate():
raise RuntimeError("Invalid server configuration")
self.inference_service = DistributedInferenceService()
if not self.inference_service.start_distributed_inference(args):
raise RuntimeError("Failed to start distributed inference service")
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
self.api_server = ApiServer(max_queue_size=server_config.max_queue_size)
self.api_server.initialize_services(cache_dir, self.inference_service)
logger.info("Server startup completed successfully")
async def shutdown(self):
logger.info("Starting server shutdown...")
if self.api_server:
await self.api_server.cleanup()
logger.info("API server cleaned up")
if self.inference_service:
self.inference_service.stop_distributed_inference()
logger.info("Inference service stopped")
logger.info("Server shutdown completed")
def handle_signal(self, sig, frame):
logger.info(f"Received signal {sig}, initiating graceful shutdown...")
asyncio.create_task(self.shutdown())
self.shutdown_event.set()
async def run_server(self, args):
try:
await self.startup(args)
assert self.api_server is not None
app = self.api_server.get_app()
signal.signal(signal.SIGINT, self.handle_signal)
signal.signal(signal.SIGTERM, self.handle_signal)
logger.info(f"Starting server on {server_config.host}:{server_config.port}")
config = uvicorn.Config(
app=app,
host=server_config.host,
port=server_config.port,
log_level="info",
)
server = uvicorn.Server(config)
server_task = asyncio.create_task(server.serve())
await self.shutdown_event.wait()
server.should_exit = True
await server_task
except Exception as e:
logger.error(f"Server error: {e}")
raise
finally:
await self.shutdown()
def run_server(args):
inference_service = None
try:
logger.info("Starting LightX2V server...")
if hasattr(args, "host") and args.host:
server_config.host = args.host
if hasattr(args, "port") and args.port:
server_config.port = args.port
if not server_config.validate():
raise RuntimeError("Invalid server configuration")
inference_service = DistributedInferenceService()
if not inference_service.start_distributed_inference(args):
raise RuntimeError("Failed to start distributed inference service")
logger.info("Inference service started successfully")
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
api_server = ApiServer(max_queue_size=server_config.max_queue_size)
api_server.initialize_services(cache_dir, inference_service)
app = api_server.get_app()
logger.info(f"Starting server on {server_config.host}:{server_config.port}")
uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
except KeyboardInterrupt:
logger.info("Server interrupted by user")
if inference_service:
inference_service.stop_distributed_inference()
except Exception as e:
logger.error(f"Server failed: {e}")
if inference_service:
inference_service.stop_distributed_inference()
sys.exit(1)
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from ..utils.generate_task_id import generate_task_id
......@@ -11,7 +8,7 @@ class TaskRequest(BaseModel):
prompt: str = Field("", description="Generation prompt")
use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
negative_prompt: str = Field("", description="Negative prompt")
image_path: str = Field("", description="Input image path")
image_path: str = Field("", description="Base64 encoded image or URL")
num_fragments: int = Field(1, description="Number of fragments")
save_video_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
infer_steps: int = Field(5, description="Inference steps")
......@@ -22,7 +19,6 @@ class TaskRequest(BaseModel):
def __init__(self, **data):
super().__init__(**data)
# If save_video_path is empty, use task_id.mp4
if not self.save_video_path:
self.save_video_path = f"{self.task_id}.mp4"
......@@ -40,21 +36,6 @@ class TaskResponse(BaseModel):
save_video_path: str
class TaskResultResponse(BaseModel):
status: str
task_status: str
filename: Optional[str] = None
file_size: Optional[int] = None
download_url: Optional[str] = None
message: str
class ServiceStatusResponse(BaseModel):
service_status: str
task_id: Optional[str] = None
start_time: Optional[datetime] = None
class StopTaskResponse(BaseModel):
stop_status: str
reason: str
This diff is collapsed.
import threading
import uuid
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
from loguru import logger
class TaskStatus(Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class TaskInfo:
task_id: str
status: TaskStatus
message: Any
start_time: datetime = field(default_factory=datetime.now)
end_time: Optional[datetime] = None
error: Optional[str] = None
save_video_path: Optional[str] = None
stop_event: threading.Event = field(default_factory=threading.Event)
thread: Optional[threading.Thread] = None
class TaskManager:
def __init__(self, max_queue_size: int = 100):
self.max_queue_size = max_queue_size
self._tasks: OrderedDict[str, TaskInfo] = OrderedDict()
self._lock = threading.RLock()
self._processing_lock = threading.Lock()
self._current_processing_task: Optional[str] = None
self.total_tasks = 0
self.completed_tasks = 0
self.failed_tasks = 0
def create_task(self, message: Any) -> str:
with self._lock:
if hasattr(message, "task_id") and message.task_id in self._tasks:
raise RuntimeError(f"Task ID {message.task_id} already exists")
active_tasks = sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])
if active_tasks >= self.max_queue_size:
raise RuntimeError(f"Task queue is full (max {self.max_queue_size} tasks)")
task_id = getattr(message, "task_id", str(uuid.uuid4()))
task_info = TaskInfo(task_id=task_id, status=TaskStatus.PENDING, message=message, save_video_path=getattr(message, "save_video_path", None))
self._tasks[task_id] = task_info
self.total_tasks += 1
self._cleanup_old_tasks()
return task_id
def start_task(self, task_id: str) -> TaskInfo:
with self._lock:
if task_id not in self._tasks:
raise KeyError(f"Task {task_id} not found")
task = self._tasks[task_id]
task.status = TaskStatus.PROCESSING
task.start_time = datetime.now()
self._tasks.move_to_end(task_id)
return task
def complete_task(self, task_id: str, save_video_path: Optional[str] = None):
with self._lock:
if task_id not in self._tasks:
logger.warning(f"Task {task_id} not found for completion")
return
task = self._tasks[task_id]
task.status = TaskStatus.COMPLETED
task.end_time = datetime.now()
if save_video_path:
task.save_video_path = save_video_path
self.completed_tasks += 1
def fail_task(self, task_id: str, error: str):
with self._lock:
if task_id not in self._tasks:
logger.warning(f"Task {task_id} not found for failure")
return
task = self._tasks[task_id]
task.status = TaskStatus.FAILED
task.end_time = datetime.now()
task.error = error
self.failed_tasks += 1
def cancel_task(self, task_id: str) -> bool:
with self._lock:
if task_id not in self._tasks:
return False
task = self._tasks[task_id]
if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
return False
task.stop_event.set()
task.status = TaskStatus.CANCELLED
task.end_time = datetime.now()
task.error = "Task cancelled by user"
if task.thread and task.thread.is_alive():
task.thread.join(timeout=5)
return True
def cancel_all_tasks(self):
with self._lock:
for task_id, task in list(self._tasks.items()):
if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
self.cancel_task(task_id)
def get_task(self, task_id: str) -> Optional[TaskInfo]:
with self._lock:
return self._tasks.get(task_id)
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
task = self.get_task(task_id)
if not task:
return None
return {"task_id": task.task_id, "status": task.status.value, "start_time": task.start_time, "end_time": task.end_time, "error": task.error, "save_video_path": task.save_video_path}
def get_all_tasks(self):
with self._lock:
return {task_id: self.get_task_status(task_id) for task_id in self._tasks}
def get_active_task_count(self) -> int:
with self._lock:
return sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])
def get_pending_task_count(self) -> int:
with self._lock:
return sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)
def is_processing(self) -> bool:
with self._lock:
return self._current_processing_task is not None
def acquire_processing_lock(self, task_id: str, timeout: Optional[float] = None) -> bool:
acquired = self._processing_lock.acquire(timeout=timeout if timeout else False)
if acquired:
with self._lock:
self._current_processing_task = task_id
logger.info(f"Task {task_id} acquired processing lock")
return acquired
def release_processing_lock(self, task_id: str):
with self._lock:
if self._current_processing_task == task_id:
self._current_processing_task = None
try:
self._processing_lock.release()
logger.info(f"Task {task_id} released processing lock")
except RuntimeError as e:
logger.warning(f"Task {task_id} tried to release lock but failed: {e}")
def get_next_pending_task(self) -> Optional[str]:
with self._lock:
for task_id, task in self._tasks.items():
if task.status == TaskStatus.PENDING:
return task_id
return None
def get_service_status(self) -> Dict[str, Any]:
with self._lock:
active_tasks = [task_id for task_id, task in self._tasks.items() if task.status == TaskStatus.PROCESSING]
pending_count = sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)
return {
"service_status": "busy" if self._current_processing_task else "idle",
"current_task": self._current_processing_task,
"active_tasks": active_tasks,
"pending_tasks": pending_count,
"queue_size": self.max_queue_size,
"total_tasks": self.total_tasks,
"completed_tasks": self.completed_tasks,
"failed_tasks": self.failed_tasks,
}
def _cleanup_old_tasks(self, keep_count: int = 1000):
if len(self._tasks) <= keep_count:
return
completed_tasks = [(task_id, task) for task_id, task in self._tasks.items() if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]]
completed_tasks.sort(key=lambda x: x[1].end_time or x[1].start_time)
remove_count = len(self._tasks) - keep_count
for task_id, _ in completed_tasks[:remove_count]:
del self._tasks[task_id]
logger.debug(f"Cleaned up old task: {task_id}")
task_manager = TaskManager()
......@@ -2,9 +2,6 @@ import base64
import io
import signal
import sys
import threading
from datetime import datetime
from typing import Optional
import psutil
import torch
......@@ -43,81 +40,6 @@ class TaskStatusMessage(BaseModel):
task_id: str
class ServiceStatus:
_lock = threading.Lock()
_current_task = None
_result_store = {}
@classmethod
def start_task(cls, message):
with cls._lock:
if cls._current_task is not None:
raise RuntimeError("Service busy")
if message.task_id in cls._result_store:
raise RuntimeError(f"Task ID {message.task_id} already exists")
cls._current_task = {"message": message, "start_time": datetime.now()}
return message.task_id
@classmethod
def complete_task(cls, message):
with cls._lock:
if cls._current_task:
cls._result_store[message.task_id] = {
"success": True,
"message": message,
"start_time": cls._current_task["start_time"],
"completion_time": datetime.now(),
"save_video_path": message.save_video_path,
}
cls._current_task = None
@classmethod
def record_failed_task(cls, message, error: Optional[str] = None):
with cls._lock:
if cls._current_task:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
cls._current_task = None
@classmethod
def clean_stopped_task(cls):
with cls._lock:
if cls._current_task:
message = cls._current_task["message"]
error = "Task stopped by user"
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
cls._current_task = None
@classmethod
def get_status_task_id(cls, task_id: str):
with cls._lock:
if cls._current_task and cls._current_task["message"].task_id == task_id:
return {"status": "processing", "task_id": task_id}
if task_id in cls._result_store:
result = cls._result_store[task_id]
return {
"status": "completed" if result["success"] else "failed",
"task_id": task_id,
"success": result["success"],
"start_time": result["start_time"],
"completion_time": result.get("completion_time"),
"error": result.get("error"),
"save_video_path": result.get("save_video_path"),
}
return {"status": "not_found", "task_id": task_id}
@classmethod
def get_status_service(cls):
with cls._lock:
if cls._current_task:
return {"service_status": "busy", "task_id": cls._current_task["message"].task_id, "start_time": cls._current_task["start_time"]}
return {"service_status": "idle"}
@classmethod
def get_all_tasks(cls):
with cls._lock:
return cls._result_store
class TensorTransporter:
def __init__(self):
self.buffer = io.BytesIO()
......
import base64
import requests
from loguru import logger
def image_to_base64(image_path):
"""Convert an image file to base64 string"""
with open(image_path, "rb") as f:
image_data = f.read()
return base64.b64encode(image_data).decode("utf-8")
if __name__ == "__main__":
url = "http://localhost:8000/v1/tasks/"
message = {
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "assets/inputs/imgs/img_0.jpg", # 图片地址
"image_path": image_to_base64("assets/inputs/imgs/img_0.jpg"), # 图片地址
}
logger.info(f"message: {message}")
......
import base64
import os
import threading
import time
......@@ -6,10 +8,34 @@ from loguru import logger
from tqdm import tqdm
def image_to_base64(image_path):
"""Convert an image file to base64 string"""
with open(image_path, "rb") as f:
image_data = f.read()
return base64.b64encode(image_data).decode("utf-8")
def process_image_path(image_path):
"""处理image_path:如果是本地路径则转换为base64,如果是HTTP链接则保持不变"""
if not image_path:
return image_path
if image_path.startswith(("http://", "https://")):
return image_path
if os.path.exists(image_path):
return image_to_base64(image_path)
else:
logger.warning(f"Image path not found: {image_path}")
return image_path
def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock):
"""Send task to server and monitor until completion"""
try:
# Step 1: Send task and get task_id
if "image_path" in message and message["image_path"]:
message["image_path"] = process_image_path(message["image_path"])
response = requests.post(f"{url}/v1/tasks/", json=message)
response_data = response.json()
task_id = response_data.get("task_id")
......@@ -38,7 +64,6 @@ def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock)
complete_bar.update(1) # Still update progress even if failed
return False
else:
# Task still running, wait and check again
time.sleep(0.5)
except Exception as e:
......@@ -91,7 +116,8 @@ def process_tasks_async(messages, available_urls, show_progress=True):
logger.info(f"Sending {len(messages)} tasks to available servers...")
# Create completion progress bar
complete_bar = None
complete_lock = None
if show_progress:
complete_bar = tqdm(total=len(messages), desc="Completing tasks")
complete_lock = threading.Lock() # Thread-safe updates to completion bar
......@@ -101,7 +127,7 @@ def process_tasks_async(messages, available_urls, show_progress=True):
server_url = find_idle_server(available_urls)
# Create and start thread for sending and monitoring task
thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar if show_progress else None, complete_lock if show_progress else None))
thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar, complete_lock))
thread.daemon = False
thread.start()
active_threads.append(thread)
......@@ -114,7 +140,7 @@ def process_tasks_async(messages, available_urls, show_progress=True):
thread.join()
# Close completion bar
if show_progress:
if complete_bar:
complete_bar.close()
logger.info("All tasks processing completed!")
......
import base64
from loguru import logger
from post_multi_servers import get_available_urls, process_tasks_async
def image_to_base64(image_path):
"""Convert an image file to base64 string"""
with open(image_path, "rb") as f:
image_data = f.read()
return base64.b64encode(image_data).decode("utf-8")
if __name__ == "__main__":
urls = [f"http://localhost:{port}" for port in range(8000, 8008)]
img_prompts = {
......@@ -11,7 +21,7 @@ if __name__ == "__main__":
messages = []
for i, (image_path, prompt) in enumerate(img_prompts.items()):
messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_video_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_video_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
logger.info(f"urls: {urls}")
......
......@@ -9,6 +9,9 @@ export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export ENABLE_GRAPH_MODE=false
export TORCH_CUDA_ARCH_LIST="9.0"
# Start API server with distributed inference service
python -m lightx2v.api_server \
--model_cls wan2.1_distill \
......