Commit b099ff96 authored by gaclove

refactor: improve server configuration and distributed utilities

- Updated `ServerConfig` to raise a RuntimeError when no free port is found, providing clearer guidance for configuration adjustments.
- Introduced chunked broadcasting and receiving methods in `DistributedManager` to handle large byte data more efficiently.
- Refactored `broadcast_task_data` and `receive_task_data` methods to utilize the new chunking methods for improved readability and performance.
- Enhanced error logging in `image_utils.py` by replacing print statements with logger warnings.
- Cleaned up the `main.py` file by removing unused signal handling code.
parent bab78b8e
@@ -13,7 +13,7 @@ graph TB
    subgraph "Client Layer"
        Client[HTTP Client]
    end

    subgraph "API Layer"
        FastAPI[FastAPI Application]
        ApiServer[ApiServer]
@@ -21,71 +21,71 @@ graph TB
        Router2[Files Router<br/>/v1/files]
        Router3[Service Router<br/>/v1/service]
    end

    subgraph "Service Layer"
        TaskManager[TaskManager<br/>Thread-safe Task Queue]
        FileService[FileService<br/>File I/O & Downloads]
        VideoService[VideoGenerationService]
    end

    subgraph "Processing Layer"
        Thread[Processing Thread<br/>Sequential Task Loop]
    end

    subgraph "Distributed Inference Layer"
        DistService[DistributedInferenceService]
        SharedData[(Shared Data<br/>mp.Manager.dict)]
        TaskEvent[Task Event<br/>mp.Manager.Event]
        ResultEvent[Result Event<br/>mp.Manager.Event]

        subgraph "Worker Processes"
            W0[Worker 0<br/>Master/Rank 0]
            W1[Worker 1<br/>Rank 1]
            WN[Worker N<br/>Rank N]
        end
    end

    subgraph "Resource Management"
        GPUManager[GPUManager<br/>GPU Detection & Allocation]
        DistManager[DistributedManager<br/>PyTorch Distributed]
        Config[ServerConfig<br/>Configuration]
    end

    Client -->|HTTP Request| FastAPI
    FastAPI --> ApiServer
    ApiServer --> Router1
    ApiServer --> Router2
    ApiServer --> Router3

    Router1 -->|Create/Manage Tasks| TaskManager
    Router1 -->|Process Tasks| Thread
    Router2 -->|File Operations| FileService
    Router3 -->|Service Status| TaskManager

    Thread -->|Get Pending Tasks| TaskManager
    Thread -->|Generate Video| VideoService

    VideoService -->|Download Images| FileService
    VideoService -->|Submit Task| DistService

    DistService -->|Update| SharedData
    DistService -->|Signal| TaskEvent
    TaskEvent -->|Notify| W0

    W0 -->|Broadcast| W1
    W0 -->|Broadcast| WN
    W0 -->|Update Result| SharedData
    W0 -->|Signal| ResultEvent
    ResultEvent -->|Notify| DistService

    W0 -.->|Uses| GPUManager
    W1 -.->|Uses| GPUManager
    WN -.->|Uses| GPUManager

    W0 -.->|Setup| DistManager
    W1 -.->|Setup| DistManager
    WN -.->|Setup| DistManager

    DistService -.->|Reads| Config
    ApiServer -.->|Reads| Config
```
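The hand-off shown between DistributedInferenceService and Worker 0 is a plain shared-dict-plus-events pattern built on `mp.Manager`. The sketch below illustrates that pattern in isolation; the key names `current_task` and `result` follow the diagram, while the function and field names around them are illustrative rather than the service's actual code.

```python
import multiprocessing as mp


def worker_loop(shared, task_event, result_event):
    """Illustrative worker: wait for a task signal, 'process' it, publish the result."""
    while True:
        if not task_event.wait(timeout=1.0):
            continue  # timed out, nothing submitted yet
        task_event.clear()
        task = shared.get("current_task")
        if task is None or task.get("stop"):
            break
        shared["result"] = {"task_id": task["task_id"], "status": "completed"}
        result_event.set()


if __name__ == "__main__":
    manager = mp.Manager()
    shared = manager.dict()
    task_event = manager.Event()
    result_event = manager.Event()

    worker = mp.Process(target=worker_loop, args=(shared, task_event, result_event))
    worker.start()

    # Submit one task the way the diagram shows DistributedInferenceService doing it.
    shared["current_task"] = {"task_id": "demo", "prompt": "hello"}
    task_event.set()
    result_event.wait()
    print(shared["result"])

    # Ask the worker to stop and wait for it to exit.
    shared["current_task"] = {"task_id": "stop", "stop": True}
    task_event.set()
    worker.join()
```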
@@ -119,23 +119,23 @@ sequenceDiagram
    participant DIS as Distributed<br/>Inference Service
    participant W0 as Worker 0<br/>(Master)
    participant W1 as Worker 1..N

    C->>API: POST /v1/tasks<br/>(Create Task)
    API->>TM: create_task()
    TM->>TM: Generate task_id
    TM->>TM: Add to queue<br/>(status: PENDING)
    API->>PT: ensure_processing_thread()
    API-->>C: TaskResponse<br/>(task_id, status: pending)

    Note over PT: Processing Loop
    PT->>TM: get_next_pending_task()
    TM-->>PT: task_id
    PT->>TM: acquire_processing_lock()
    PT->>TM: start_task()<br/>(status: PROCESSING)
    PT->>VS: generate_video_with_stop_event()

    alt Image is URL
        VS->>FS: download_image()
        FS->>FS: HTTP download<br/>with retry
@@ -147,39 +147,39 @@ sequenceDiagram
        VS->>FS: validate_file()
        FS-->>VS: image_path
    end

    VS->>DIS: submit_task(task_data)
    DIS->>DIS: shared_data["current_task"] = task_data
    DIS->>DIS: task_event.set()

    Note over W0,W1: Distributed Processing
    W0->>W0: task_event.wait()
    W0->>W0: Get task from shared_data
    W0->>W1: broadcast_task_data()

    par Parallel Inference
        W0->>W0: run_pipeline()
    and
        W1->>W1: run_pipeline()
    end

    W0->>W0: barrier() for sync
    W0->>W0: shared_data["result"] = result
    W0->>DIS: result_event.set()

    DIS->>DIS: result_event.wait()
    DIS->>VS: return result
    VS-->>PT: TaskResponse
    PT->>TM: complete_task()<br/>(status: COMPLETED)
    PT->>TM: release_processing_lock()

    Note over C: Client Polling
    C->>API: GET /v1/tasks/{task_id}/status
    API->>TM: get_task_status()
    TM-->>API: status info
    API-->>C: Task Status

    C->>API: GET /v1/tasks/{task_id}/result
    API->>TM: get_task_status()
    API->>FS: stream_file_response()
...
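From the client's side, the flow above reduces to create-then-poll. A minimal sketch against the endpoints named in the diagram; the base URL, request body, and any response fields beyond `task_id` and `status` are assumptions, not the server's documented schema.

```python
import time

import requests

BASE_URL = "http://localhost:8000"  # host/port are deployment-specific assumptions

# Create a task; the body fields here are placeholders for the real request schema.
resp = requests.post(
    f"{BASE_URL}/v1/tasks",
    json={"prompt": "a cat", "image_path": "https://example.com/cat.png"},
)
resp.raise_for_status()
task_id = resp.json()["task_id"]

# Poll the status endpoint until the task leaves the pending/processing states.
while True:
    status = requests.get(f"{BASE_URL}/v1/tasks/{task_id}/status").json()
    if status.get("status") in ("completed", "failed"):
        break
    time.sleep(2)

# Stream the result file once the task has completed.
if status.get("status") == "completed":
    with requests.get(f"{BASE_URL}/v1/tasks/{task_id}/result", stream=True) as r:
        r.raise_for_status()
        with open("output.mp4", "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
```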
@@ -62,9 +62,10 @@ class ServerConfig:
            except OSError:
                continue

-        import random
-        return str(random.randint(20000, 29999))
+        raise RuntimeError(
+            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
+            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
+        )

    def validate(self) -> bool:
        valid = True
...
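Only the tail of the port-scan method appears in this hunk. For orientation, here is a sketch of the likely shape of that method, under the assumption that it tries to bind each port in `master_port_range` and skips the ones that fail; the class and method names are illustrative, while the raised error matches the diff.

```python
import socket
from dataclasses import dataclass


@dataclass
class ServerConfigSketch:
    """Sketch only; field names mirror those referenced in the diff context."""

    master_addr: str = "127.0.0.1"
    master_port_range: tuple = (29500, 29600)

    def find_free_master_port(self) -> str:  # hypothetical name for the method in the hunk
        for port in range(self.master_port_range[0], self.master_port_range[1]):
            try:
                # A successful bind means the port is currently free.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.bind((self.master_addr, port))
                return str(port)
            except OSError:
                continue

        raise RuntimeError(
            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
        )
```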
@@ -16,6 +16,8 @@ class DistributedManager:
        self.world_size = 1
        self.device = "cpu"

+    CHUNK_SIZE = 1024 * 1024
+
    def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
        try:
            os.environ["RANK"] = str(rank)
@@ -61,6 +63,39 @@ class DistributedManager:
    def is_rank_zero(self) -> bool:
        return self.rank == 0

+    def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
+        total_length = len(data_bytes)
+        num_full_chunks = total_length // self.CHUNK_SIZE
+        remaining = total_length % self.CHUNK_SIZE
+
+        for i in range(num_full_chunks):
+            start_idx = i * self.CHUNK_SIZE
+            end_idx = start_idx + self.CHUNK_SIZE
+            chunk = data_bytes[start_idx:end_idx]
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
+            dist.broadcast(task_tensor, src=0)
+
+        if remaining:
+            chunk = data_bytes[-remaining:]
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
+            dist.broadcast(task_tensor, src=0)
+
+    def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
+        if total_length <= 0:
+            return b""
+
+        received = bytearray()
+        remaining = total_length
+        while remaining > 0:
+            chunk_length = min(self.CHUNK_SIZE, remaining)
+            task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
+            dist.broadcast(task_tensor, src=0)
+            received.extend(task_tensor.cpu().numpy())
+            remaining -= chunk_length
+
+        return bytes(received)
+
    def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
        if not self.is_initialized:
            return None
@@ -88,19 +123,7 @@ class DistributedManager:
            task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
            dist.broadcast(task_length, src=0)

-            chunk_size = 1024 * 1024
-            if len(task_bytes) > chunk_size:
-                num_chunks = (len(task_bytes) + chunk_size - 1) // chunk_size
-                for i in range(num_chunks):
-                    start_idx = i * chunk_size
-                    end_idx = min((i + 1) * chunk_size, len(task_bytes))
-                    chunk = task_bytes[start_idx:end_idx]
-                    task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(broadcast_device)
-                    dist.broadcast(task_tensor, src=0)
-            else:
-                task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8).to(broadcast_device)
-                dist.broadcast(task_tensor, src=0)
+            self._broadcast_byte_chunks(task_bytes, broadcast_device)

            return task_data
        else:
@@ -113,25 +136,11 @@ class DistributedManager:
                return None
            else:
                task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)

            dist.broadcast(task_length, src=0)
            total_length = int(task_length.item())

-            chunk_size = 1024 * 1024
-            if total_length > chunk_size:
-                task_bytes = bytearray()
-                num_chunks = (total_length + chunk_size - 1) // chunk_size
-                for i in range(num_chunks):
-                    chunk_length = min(chunk_size, total_length - len(task_bytes))
-                    task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(broadcast_device)
-                    dist.broadcast(task_tensor, src=0)
-                    task_bytes.extend(task_tensor.cpu().numpy())
-                task_bytes = bytes(task_bytes)
-            else:
-                task_tensor = torch.empty(total_length, dtype=torch.uint8).to(broadcast_device)
-                dist.broadcast(task_tensor, src=0)
-                task_bytes = bytes(task_tensor.cpu().numpy())
+            task_bytes = self._receive_byte_chunks(total_length, broadcast_device)

            task_data = pickle.loads(task_bytes)
            return task_data
...
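The refactored protocol is: rank 0 broadcasts a one-element int32 length tensor, then both sides exchange the pickled payload in fixed 1 MiB uint8 chunks, the sender via `_broadcast_byte_chunks` and the receivers via `_receive_byte_chunks`. Rank 0 calls `broadcast_task_data(task_data)` and gets the same object back; the other ranks call it with no argument and receive the unpickled copy. The snippet below is only a sanity check of the chunk arithmetic, showing that sender and receivers always issue the same number of `dist.broadcast` calls and therefore stay in lockstep.

```python
CHUNK_SIZE = 1024 * 1024  # mirrors DistributedManager.CHUNK_SIZE


def num_chunk_broadcasts(n: int) -> int:
    """Broadcast count for an n-byte payload: full chunks plus one trailing partial chunk."""
    return n // CHUNK_SIZE + (1 if n % CHUNK_SIZE else 0)


# The receiver consumes min(CHUNK_SIZE, remaining) bytes per broadcast, which is
# exactly ceil(n / CHUNK_SIZE) iterations, matching the sender's count.
for n in (0, 1, CHUNK_SIZE, CHUNK_SIZE + 1, 5 * CHUNK_SIZE + 123):
    assert num_chunk_broadcasts(n) == -(-n // CHUNK_SIZE)  # ceiling division
    print(n, "bytes ->", num_chunk_broadcasts(n), "chunk broadcast(s)")
```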
@@ -5,6 +5,8 @@ import uuid
from pathlib import Path
from typing import Optional, Tuple

+from loguru import logger
+

def is_base64_image(data: str) -> bool:
    """Check if a string is a base64-encoded image"""
@@ -24,7 +26,7 @@ def is_base64_image(data: str) -> bool:
        if decoded[8:12] == b"WEBP":
            return True
    except Exception as e:
-        print(f"Error checking base64 image: {e}")
+        logger.warning(f"Error checking base64 image: {e}")
        return False
    return False
@@ -45,7 +47,7 @@ def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
    return data, None


-def save_base64_image(base64_data: str, output_dir: str = "/tmp/flux_kontext_uploads") -> str:
+def save_base64_image(base64_data: str, output_dir: str) -> str:
    """
    Save a base64-encoded image to disk and return the file path
    """
...
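`save_base64_image` no longer defaults to `/tmp/flux_kontext_uploads`, so callers must now pass the upload directory explicitly, for example the server's cache directory. A hedged usage sketch; the import path, `example.png`, and the cache directory below are placeholders.

```python
import base64

# Assumed import path; the commit message only names the module image_utils.py.
from image_utils import is_base64_image, save_base64_image

# Encode a local file for illustration; example.png is a placeholder.
with open("example.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

if is_base64_image(encoded):
    # output_dir is now required; it previously defaulted to /tmp/flux_kontext_uploads.
    saved_path = save_base64_image(encoded, output_dir="/path/to/cache_dir")
    print(saved_path)
```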
-import asyncio
-import signal
import sys
from pathlib import Path
-from typing import Optional

import uvicorn
from loguru import logger
@@ -12,87 +9,6 @@ from .config import server_config
from .service import DistributedInferenceService


-class ServerManager:
-    def __init__(self):
-        self.api_server: Optional[ApiServer] = None
-        self.inference_service: Optional[DistributedInferenceService] = None
-        self.shutdown_event = asyncio.Event()
-
-    async def startup(self, args):
-        logger.info("Starting LightX2V server...")
-
-        if hasattr(args, "host") and args.host:
-            server_config.host = args.host
-        if hasattr(args, "port") and args.port:
-            server_config.port = args.port
-
-        if not server_config.validate():
-            raise RuntimeError("Invalid server configuration")
-
-        self.inference_service = DistributedInferenceService()
-        if not self.inference_service.start_distributed_inference(args):
-            raise RuntimeError("Failed to start distributed inference service")
-
-        cache_dir = Path(server_config.cache_dir)
-        cache_dir.mkdir(parents=True, exist_ok=True)
-
-        self.api_server = ApiServer(max_queue_size=server_config.max_queue_size)
-        self.api_server.initialize_services(cache_dir, self.inference_service)
-
-        logger.info("Server startup completed successfully")
-
-    async def shutdown(self):
-        logger.info("Starting server shutdown...")
-
-        if self.api_server:
-            await self.api_server.cleanup()
-            logger.info("API server cleaned up")
-
-        if self.inference_service:
-            self.inference_service.stop_distributed_inference()
-            logger.info("Inference service stopped")
-
-        logger.info("Server shutdown completed")
-
-    def handle_signal(self, sig, frame):
-        logger.info(f"Received signal {sig}, initiating graceful shutdown...")
-        asyncio.create_task(self.shutdown())
-        self.shutdown_event.set()
-
-    async def run_server(self, args):
-        try:
-            await self.startup(args)
-
-            assert self.api_server is not None
-            app = self.api_server.get_app()
-
-            signal.signal(signal.SIGINT, self.handle_signal)
-            signal.signal(signal.SIGTERM, self.handle_signal)
-
-            logger.info(f"Starting server on {server_config.host}:{server_config.port}")
-
-            config = uvicorn.Config(
-                app=app,
-                host=server_config.host,
-                port=server_config.port,
-                log_level="info",
-            )
-            server = uvicorn.Server(config)
-
-            server_task = asyncio.create_task(server.serve())
-            await self.shutdown_event.wait()
-
-            server.should_exit = True
-            await server_task
-
-        except Exception as e:
-            logger.error(f"Server error: {e}")
-            raise
-        finally:
-            await self.shutdown()
-
-
def run_server(args):
    inference_service = None
    try:
...
@@ -176,7 +176,8 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, ar
    logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")

    while True:
-        task_event.wait(timeout=1.0)
+        if not task_event.wait(timeout=1.0):
+            continue

        if rank == 0:
            if shared_data.get("stop", False):
...
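`Event.wait(timeout=1.0)` returns `False` when it times out, so the added guard makes idle workers skip the rest of the loop body instead of falling through once per second; only a real `set()` lets an iteration proceed. The guard in isolation:

```python
from multiprocessing import Event

task_event = Event()  # the real worker uses an mp.Manager().Event() shared across processes

# No one has called set(): wait() returns False after about a second, so the loop would continue.
print(task_event.wait(timeout=1.0))  # False

task_event.set()
# The event is set: wait() returns True immediately and the iteration proceeds.
print(task_event.wait(timeout=1.0))  # True
```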
@@ -2,6 +2,7 @@ import base64
import os
import threading
import time
+from typing import Any

import requests
from loguru import logger
@@ -15,8 +16,8 @@ def image_to_base64(image_path):
    return base64.b64encode(image_data).decode("utf-8")


-def process_image_path(image_path):
-    """处理image_path:如果是本地路径则转换为base64,如果是HTTP链接则保持不变"""
+def process_image_path(image_path) -> Any | str:
+    """Process image_path: convert to base64 if local path, keep unchanged if HTTP link"""
    if not image_path:
        return image_path
...
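`process_image_path` keeps its behavior (HTTP links pass through, local paths are converted to base64, falsy input is returned unchanged); the return annotation and the English docstring are the only changes. A small usage sketch with a placeholder import path:

```python
# Placeholder import path for the client-side helper shown in the diff above.
from client_image_utils import process_image_path

# HTTP(S) links and falsy values pass through unchanged.
print(process_image_path("https://example.com/cat.png"))
print(process_image_path(""))

# A local file path is read and returned as a base64 string (the file must exist).
encoded = process_image_path("/path/to/local/cat.png")
print(type(encoded), len(encoded))
```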