Commit b099ff96 authored by gaclove

refactor: improve server configuration and distributed utilities

- Updated `ServerConfig` to raise a RuntimeError when no free port is found, providing clearer guidance for configuration adjustments.
- Introduced chunked broadcasting and receiving methods in `DistributedManager` to handle large byte data more efficiently.
- Refactored `broadcast_task_data` and `receive_task_data` methods to utilize the new chunking methods for improved readability and performance.
- Enhanced error logging in `image_utils.py` by replacing print statements with logger warnings.
- Cleaned up the `main.py` file by removing unused signal handling code.
parent bab78b8e
......@@ -13,7 +13,7 @@ graph TB
subgraph "Client Layer"
Client[HTTP Client]
end
subgraph "API Layer"
FastAPI[FastAPI Application]
ApiServer[ApiServer]
......@@ -21,71 +21,71 @@ graph TB
Router2[Files Router<br/>/v1/files]
Router3[Service Router<br/>/v1/service]
end
subgraph "Service Layer"
TaskManager[TaskManager<br/>Thread-safe Task Queue]
FileService[FileService<br/>File I/O & Downloads]
VideoService[VideoGenerationService]
end
subgraph "Processing Layer"
Thread[Processing Thread<br/>Sequential Task Loop]
end
subgraph "Distributed Inference Layer"
DistService[DistributedInferenceService]
SharedData[(Shared Data<br/>mp.Manager.dict)]
TaskEvent[Task Event<br/>mp.Manager.Event]
ResultEvent[Result Event<br/>mp.Manager.Event]
subgraph "Worker Processes"
W0[Worker 0<br/>Master/Rank 0]
W1[Worker 1<br/>Rank 1]
WN[Worker N<br/>Rank N]
end
end
subgraph "Resource Management"
GPUManager[GPUManager<br/>GPU Detection & Allocation]
DistManager[DistributedManager<br/>PyTorch Distributed]
Config[ServerConfig<br/>Configuration]
end
Client -->|HTTP Request| FastAPI
FastAPI --> ApiServer
ApiServer --> Router1
ApiServer --> Router2
ApiServer --> Router3
Router1 -->|Create/Manage Tasks| TaskManager
Router1 -->|Process Tasks| Thread
Router2 -->|File Operations| FileService
Router3 -->|Service Status| TaskManager
Thread -->|Get Pending Tasks| TaskManager
Thread -->|Generate Video| VideoService
VideoService -->|Download Images| FileService
VideoService -->|Submit Task| DistService
DistService -->|Update| SharedData
DistService -->|Signal| TaskEvent
TaskEvent -->|Notify| W0
W0 -->|Broadcast| W1
W0 -->|Broadcast| WN
W0 -->|Update Result| SharedData
W0 -->|Signal| ResultEvent
ResultEvent -->|Notify| DistService
W0 -.->|Uses| GPUManager
W1 -.->|Uses| GPUManager
WN -.->|Uses| GPUManager
W0 -.->|Setup| DistManager
W1 -.->|Setup| DistManager
WN -.->|Setup| DistManager
DistService -.->|Reads| Config
ApiServer -.->|Reads| Config
```
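The Distributed Inference Layer above hands work to the workers through a manager-backed dict plus a pair of events. A minimal, self-contained sketch of that hand-off pattern (`shared_data`, `task_event`, and `result_event` mirror the diagram; the worker body and field names are illustrative assumptions, not the project's code):

```python
import multiprocessing as mp


def worker(shared_data, task_event, result_event):
    # Worker loop: wait for a task signal, "process" the task, publish the result.
    while True:
        if not task_event.wait(timeout=1.0):
            continue  # timed out; loop again so a stop flag can also be noticed
        task_event.clear()
        if shared_data.get("stop"):
            break
        task = shared_data.get("current_task")
        if task is None:
            continue
        shared_data["result"] = {"task_id": task["task_id"], "status": "completed"}
        result_event.set()


if __name__ == "__main__":
    manager = mp.Manager()
    shared_data = manager.dict()
    task_event = manager.Event()
    result_event = manager.Event()

    p = mp.Process(target=worker, args=(shared_data, task_event, result_event))
    p.start()

    # Service side: publish a task, signal the worker, wait for the result.
    shared_data["current_task"] = {"task_id": "demo", "prompt": "a cat"}
    task_event.set()
    result_event.wait()
    print(shared_data["result"])

    shared_data["stop"] = True
    task_event.set()
    p.join()
```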
......@@ -119,23 +119,23 @@ sequenceDiagram
participant DIS as Distributed<br/>Inference Service
participant W0 as Worker 0<br/>(Master)
participant W1 as Worker 1..N
C->>API: POST /v1/tasks<br/>(Create Task)
API->>TM: create_task()
TM->>TM: Generate task_id
TM->>TM: Add to queue<br/>(status: PENDING)
API->>PT: ensure_processing_thread()
API-->>C: TaskResponse<br/>(task_id, status: pending)
Note over PT: Processing Loop
PT->>TM: get_next_pending_task()
TM-->>PT: task_id
PT->>TM: acquire_processing_lock()
PT->>TM: start_task()<br/>(status: PROCESSING)
PT->>VS: generate_video_with_stop_event()
alt Image is URL
VS->>FS: download_image()
FS->>FS: HTTP download<br/>with retry
......@@ -147,39 +147,39 @@ sequenceDiagram
VS->>FS: validate_file()
FS-->>VS: image_path
end
VS->>DIS: submit_task(task_data)
DIS->>DIS: shared_data["current_task"] = task_data
DIS->>DIS: task_event.set()
Note over W0,W1: Distributed Processing
W0->>W0: task_event.wait()
W0->>W0: Get task from shared_data
W0->>W1: broadcast_task_data()
par Parallel Inference
W0->>W0: run_pipeline()
and
W1->>W1: run_pipeline()
end
W0->>W0: barrier() for sync
W0->>W0: shared_data["result"] = result
W0->>DIS: result_event.set()
DIS->>DIS: result_event.wait()
DIS->>VS: return result
VS-->>PT: TaskResponse
PT->>TM: complete_task()<br/>(status: COMPLETED)
PT->>TM: release_processing_lock()
Note over C: Client Polling
C->>API: GET /v1/tasks/{task_id}/status
API->>TM: get_task_status()
TM-->>API: status info
API-->>C: Task Status
C->>API: GET /v1/tasks/{task_id}/result
API->>TM: get_task_status()
API->>FS: stream_file_response()
......
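From the client's side, the sequence above reduces to create, poll, fetch. A hedged sketch of that loop with `requests` (endpoint paths follow the diagram; the payload and response field names such as `task_id` and `status` are assumptions):

```python
import time

import requests

BASE = "http://localhost:8000"  # assumed host/port

# Create a task; the payload fields are illustrative.
resp = requests.post(f"{BASE}/v1/tasks", json={"prompt": "a cat surfing a wave"})
resp.raise_for_status()
task_id = resp.json()["task_id"]

# Poll the status endpoint until the task settles.
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status").json()["status"]
    if status in ("completed", "failed"):
        break
    time.sleep(2)

# Stream the result to disk once it is ready.
if status == "completed":
    with requests.get(f"{BASE}/v1/tasks/{task_id}/result", stream=True) as r:
        r.raise_for_status()
        with open("output.mp4", "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
```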
......@@ -62,9 +62,10 @@ class ServerConfig:
except OSError:
continue
import random
return str(random.randint(20000, 29999))
raise RuntimeError(
f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
)
def validate(self) -> bool:
valid = True
......
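For context, the surrounding loop scans `master_port_range` by trying to bind each candidate port, and with this change it raises instead of falling back to a random port outside the configured range. A minimal sketch of that scan (attribute names follow the diff; the socket details are assumptions):

```python
import socket


def find_free_master_port(master_addr: str, master_port_range: tuple) -> str:
    """Return the first bindable port in [start, end) or raise, mirroring the new behavior."""
    start, end = master_port_range
    for port in range(start, end):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind((master_addr, port))
                return str(port)
        except OSError:
            continue  # port busy, try the next one
    raise RuntimeError(
        f"No free port found for master in range {start}-{end - 1} on address {master_addr}. "
        "Please adjust 'master_port_range' or free an occupied port."
    )
```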
......@@ -16,6 +16,8 @@ class DistributedManager:
self.world_size = 1
self.device = "cpu"
CHUNK_SIZE = 1024 * 1024
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
try:
os.environ["RANK"] = str(rank)
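The environment-variable handshake begun above is the standard PyTorch `env://` rendezvous: export rank, world size, and the master address/port, then call `torch.distributed.init_process_group`. A sketch under that assumption (the backend choice is illustrative):

```python
import os

import torch
import torch.distributed as dist


def init_process_group_sketch(rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
    try:
        # The env:// init method reads these four variables.
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://", rank=rank, world_size=world_size)
        return True
    except Exception:
        return False
```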
......@@ -61,6 +63,39 @@ class DistributedManager:
def is_rank_zero(self) -> bool:
return self.rank == 0
def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
total_length = len(data_bytes)
num_full_chunks = total_length // self.CHUNK_SIZE
remaining = total_length % self.CHUNK_SIZE
for i in range(num_full_chunks):
start_idx = i * self.CHUNK_SIZE
end_idx = start_idx + self.CHUNK_SIZE
chunk = data_bytes[start_idx:end_idx]
task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
dist.broadcast(task_tensor, src=0)
if remaining:
chunk = data_bytes[-remaining:]
task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
dist.broadcast(task_tensor, src=0)
def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
if total_length <= 0:
return b""
received = bytearray()
remaining = total_length
while remaining > 0:
chunk_length = min(self.CHUNK_SIZE, remaining)
task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
dist.broadcast(task_tensor, src=0)
received.extend(task_tensor.cpu().numpy())
remaining -= chunk_length
return bytes(received)
def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
if not self.is_initialized:
return None
......@@ -88,19 +123,7 @@ class DistributedManager:
task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
dist.broadcast(task_length, src=0)
chunk_size = 1024 * 1024
if len(task_bytes) > chunk_size:
num_chunks = (len(task_bytes) + chunk_size - 1) // chunk_size
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, len(task_bytes))
chunk = task_bytes[start_idx:end_idx]
task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
else:
task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
self._broadcast_byte_chunks(task_bytes, broadcast_device)
return task_data
else:
......@@ -113,25 +136,11 @@ class DistributedManager:
return None
else:
task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
dist.broadcast(task_length, src=0)
dist.broadcast(task_length, src=0)
total_length = int(task_length.item())
chunk_size = 1024 * 1024
if total_length > chunk_size:
task_bytes = bytearray()
num_chunks = (total_length + chunk_size - 1) // chunk_size
for i in range(num_chunks):
chunk_length = min(chunk_size, total_length - len(task_bytes))
task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
task_bytes.extend(task_tensor.cpu().numpy())
task_bytes = bytes(task_bytes)
else:
task_tensor = torch.empty(total_length, dtype=torch.uint8).to(broadcast_device)
dist.broadcast(task_tensor, src=0)
task_bytes = bytes(task_tensor.cpu().numpy())
task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
task_data = pickle.loads(task_bytes)
return task_data
......
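Putting both halves together, the protocol is: rank 0 broadcasts one int32 length tensor, then every rank walks the payload in fixed `CHUNK_SIZE` steps, so sender and receivers issue the same number of equally sized broadcasts. A self-contained round-trip sketch on the `gloo` backend (two local processes; a fixed local port is assumed, and this is not the project's test harness):

```python
import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

CHUNK_SIZE = 1024 * 1024


def _run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # assumed free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        payload = pickle.dumps({"prompt": "a cat", "blob": bytes(3 * CHUNK_SIZE + 17)})
        dist.broadcast(torch.tensor([len(payload)], dtype=torch.int32), src=0)
        for start in range(0, len(payload), CHUNK_SIZE):  # full chunks, then the tail
            chunk = bytearray(payload[start:start + CHUNK_SIZE])
            dist.broadcast(torch.frombuffer(chunk, dtype=torch.uint8), src=0)
    else:
        length = torch.zeros(1, dtype=torch.int32)
        dist.broadcast(length, src=0)
        remaining, received = int(length.item()), bytearray()
        while remaining > 0:
            chunk = torch.empty(min(CHUNK_SIZE, remaining), dtype=torch.uint8)
            dist.broadcast(chunk, src=0)
            received.extend(chunk.numpy().tobytes())
            remaining -= chunk.numel()
        task = pickle.loads(bytes(received))
        print(f"rank {rank} received {task['prompt']!r} with {len(task['blob'])} blob bytes")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_run, args=(2,), nprocs=2)
```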
......@@ -5,6 +5,8 @@ import uuid
from pathlib import Path
from typing import Optional, Tuple
from loguru import logger
def is_base64_image(data: str) -> bool:
"""Check if a string is a base64-encoded image"""
......@@ -24,7 +26,7 @@ def is_base64_image(data: str) -> bool:
if decoded[8:12] == b"WEBP":
return True
except Exception as e:
print(f"Error checking base64 image: {e}")
logger.warning(f"Error checking base64 image: {e}")
return False
return False
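The branch above keys off well-known file signatures in the decoded bytes. For reference, a hedged sketch of the same signature test on raw bytes (the PNG/JPEG/GIF/WEBP values are the standard magic numbers; how the project maps them is an assumption):

```python
def sniff_image_format(decoded: bytes):
    """Best-guess image format from magic bytes, or None if unrecognized."""
    if decoded.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if decoded.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    if decoded.startswith((b"GIF87a", b"GIF89a")):
        return "gif"
    # WEBP is a RIFF container with "WEBP" at bytes 8..12, as checked in the diff above.
    if decoded[:4] == b"RIFF" and decoded[8:12] == b"WEBP":
        return "webp"
    return None
```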
......@@ -45,7 +47,7 @@ def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
return data, None
def save_base64_image(base64_data: str, output_dir: str = "/tmp/flux_kontext_uploads") -> str:
def save_base64_image(base64_data: str, output_dir: str) -> str:
"""
Save a base64-encoded image to disk and return the file path
"""
......
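With the `/tmp/flux_kontext_uploads` default removed, callers must now pass the destination directory explicitly. A plausible, self-contained sketch of what the function does with it (not the project's exact implementation; error handling and extension sniffing omitted):

```python
import base64
import uuid
from pathlib import Path


def save_base64_image_sketch(base64_data: str, output_dir: str) -> str:
    """Decode a base64 image, write it under output_dir, and return the file path."""
    # Strip an optional data-URI prefix such as "data:image/png;base64,".
    if base64_data.startswith("data:") and "," in base64_data:
        base64_data = base64_data.split(",", 1)[1]
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)  # the caller now owns the location
    file_path = out_dir / f"{uuid.uuid4().hex}.png"
    file_path.write_bytes(base64.b64decode(base64_data))
    return str(file_path)
```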
import asyncio
import signal
import sys
from pathlib import Path
from typing import Optional
import uvicorn
from loguru import logger
......@@ -12,87 +9,6 @@ from .config import server_config
from .service import DistributedInferenceService
class ServerManager:
def __init__(self):
self.api_server: Optional[ApiServer] = None
self.inference_service: Optional[DistributedInferenceService] = None
self.shutdown_event = asyncio.Event()
async def startup(self, args):
logger.info("Starting LightX2V server...")
if hasattr(args, "host") and args.host:
server_config.host = args.host
if hasattr(args, "port") and args.port:
server_config.port = args.port
if not server_config.validate():
raise RuntimeError("Invalid server configuration")
self.inference_service = DistributedInferenceService()
if not self.inference_service.start_distributed_inference(args):
raise RuntimeError("Failed to start distributed inference service")
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
self.api_server = ApiServer(max_queue_size=server_config.max_queue_size)
self.api_server.initialize_services(cache_dir, self.inference_service)
logger.info("Server startup completed successfully")
async def shutdown(self):
logger.info("Starting server shutdown...")
if self.api_server:
await self.api_server.cleanup()
logger.info("API server cleaned up")
if self.inference_service:
self.inference_service.stop_distributed_inference()
logger.info("Inference service stopped")
logger.info("Server shutdown completed")
def handle_signal(self, sig, frame):
logger.info(f"Received signal {sig}, initiating graceful shutdown...")
asyncio.create_task(self.shutdown())
self.shutdown_event.set()
async def run_server(self, args):
try:
await self.startup(args)
assert self.api_server is not None
app = self.api_server.get_app()
signal.signal(signal.SIGINT, self.handle_signal)
signal.signal(signal.SIGTERM, self.handle_signal)
logger.info(f"Starting server on {server_config.host}:{server_config.port}")
config = uvicorn.Config(
app=app,
host=server_config.host,
port=server_config.port,
log_level="info",
)
server = uvicorn.Server(config)
server_task = asyncio.create_task(server.serve())
await self.shutdown_event.wait()
server.should_exit = True
await server_task
except Exception as e:
logger.error(f"Server error: {e}")
raise
finally:
await self.shutdown()
def run_server(args):
inference_service = None
try:
......
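The removed handler called `asyncio.create_task` from a plain `signal.signal` callback; when a running event loop does need to react to signals, the usual pattern is `loop.add_signal_handler`, which schedules the callback inside the loop. A minimal, Unix-only sketch of that pattern (illustrative only, not part of the remaining `run_server(args)`):

```python
import asyncio
import signal


async def main() -> None:
    shutdown_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    # Register handlers on the running loop instead of installing raw C-level handlers.
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, shutdown_event.set)

    print("running; press Ctrl+C to stop")
    await shutdown_event.wait()
    print("shutting down cleanly")


if __name__ == "__main__":
    asyncio.run(main())
```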
......@@ -176,7 +176,8 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, ar
logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
while True:
task_event.wait(timeout=1.0)
if not task_event.wait(timeout=1.0):
continue
if rank == 0:
if shared_data.get("stop", False):
......
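The check matters because `Event.wait(timeout=...)` returns `False` on timeout rather than raising, so without it the loop fell through into the task-handling path once a second even when nothing had been submitted. A tiny demonstration of the return value (shown with `threading.Event`; the manager-backed event proxy behaves the same way):

```python
import threading

event = threading.Event()

# Not set yet: wait() times out and returns False.
assert event.wait(timeout=0.1) is False

event.set()
# Already set: wait() returns True immediately.
assert event.wait(timeout=0.1) is True
```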
......@@ -2,6 +2,7 @@ import base64
import os
import threading
import time
from typing import Any
import requests
from loguru import logger
......@@ -15,8 +16,8 @@ def image_to_base64(image_path):
return base64.b64encode(image_data).decode("utf-8")
def process_image_path(image_path):
"""处理image_path:如果是本地路径则转换为base64,如果是HTTP链接则保持不变"""
def process_image_path(image_path) -> Any | str:
"""Process image_path: convert to base64 if local path, keep unchanged if HTTP link"""
if not image_path:
return image_path
......
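Based on the docstring, `process_image_path` passes HTTP(S) URLs through untouched and inlines local files as base64. A plausible sketch of that branch (not the project's exact implementation; error handling omitted):

```python
import base64
import os


def process_image_path_sketch(image_path: str) -> str:
    """Keep remote URLs as-is; inline local files as base64."""
    if not image_path:
        return image_path
    if image_path.startswith(("http://", "https://")):
        return image_path  # remote images are downloaded later by the file service
    if os.path.isfile(image_path):
        with open(image_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
    return image_path  # unrecognized input passes through unchanged
```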