Commit b099ff96 authored by gaclove

refactor: improve server configuration and distributed utilities

- Updated `ServerConfig` to raise a RuntimeError when no free port is found, providing clearer guidance for configuration adjustments.
- Introduced chunked broadcasting and receiving methods in `DistributedManager` to handle large byte data more efficiently.
- Refactored `broadcast_task_data` and `receive_task_data` methods to utilize the new chunking methods for improved readability and performance.
- Enhanced error logging in `image_utils.py` by replacing print statements with logger warnings.
- Cleaned up the `main.py` file by removing unused signal handling code.
parent bab78b8e
@@ -62,9 +62,10 @@ class ServerConfig:
             except OSError:
                 continue
-        import random
-        return str(random.randint(20000, 29999))
+        raise RuntimeError(
+            f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
+            f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
+        )

     def validate(self) -> bool:
         valid = True
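Only the tail of the port probe is visible in this hunk. For orientation, here is a minimal sketch of the enclosing loop, assuming the bind-probe pattern implied by the `except OSError: continue` above; the function name and exact structure are illustrative, not the repo's code:

```python
import socket

def find_free_master_port(master_addr: str, port_range: tuple[int, int]) -> str:
    """Probe candidate ports by binding; return the first one that is free."""
    for port in range(port_range[0], port_range[1]):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind((master_addr, port))  # raises OSError if the port is taken
            return str(port)
        except OSError:
            continue
    # Mirrors the new behavior: fail loudly instead of returning a random,
    # possibly occupied, port.
    raise RuntimeError(
        f"No free port found for master in range {port_range[0]}-{port_range[1] - 1} "
        f"on address {master_addr}. Please adjust 'master_port_range' or free an occupied port."
    )
```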
@@ -16,6 +16,8 @@ class DistributedManager:
         self.world_size = 1
         self.device = "cpu"

+    CHUNK_SIZE = 1024 * 1024
+
     def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
         try:
             os.environ["RANK"] = str(rank)
@@ -61,6 +63,39 @@ class DistributedManager:
     def is_rank_zero(self) -> bool:
         return self.rank == 0

+    def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
+        total_length = len(data_bytes)
+        num_full_chunks = total_length // self.CHUNK_SIZE
+        remaining = total_length % self.CHUNK_SIZE
+        for i in range(num_full_chunks):
+            start_idx = i * self.CHUNK_SIZE
+            end_idx = start_idx + self.CHUNK_SIZE
+            chunk = data_bytes[start_idx:end_idx]
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
+            dist.broadcast(task_tensor, src=0)
+        if remaining:
+            chunk = data_bytes[-remaining:]
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
+            dist.broadcast(task_tensor, src=0)
+
+    def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
+        if total_length <= 0:
+            return b""
+        received = bytearray()
+        remaining = total_length
+        while remaining > 0:
+            chunk_length = min(self.CHUNK_SIZE, remaining)
+            task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
+            dist.broadcast(task_tensor, src=0)
+            received.extend(task_tensor.cpu().numpy())
+            remaining -= chunk_length
+        return bytes(received)
+
     def broadcast_task_data(self, task_data: Optional[Any] = None) -> Optional[Any]:
         if not self.is_initialized:
             return None
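The two helpers must stay in lock-step: every `dist.broadcast` issued by rank 0 needs a matching call with an identically sized tensor on every other rank, which holds because both sides derive the chunk sizes from the same total length sent first. A self-contained sketch of that length-then-chunks protocol (the function name is hypothetical; it assumes an already-initialized process group):

```python
import pickle
import torch
import torch.distributed as dist

CHUNK_SIZE = 1024 * 1024  # 1 MiB, matching the class attribute above

def send_or_receive(obj=None, device=torch.device("cpu")):
    """Rank 0 pickles `obj` and broadcasts it; other ranks return the object."""
    if dist.get_rank() == 0:
        payload = pickle.dumps(obj)
        length = torch.tensor([len(payload)], dtype=torch.int32, device=device)
        dist.broadcast(length, src=0)              # step 1: announce total size
        for i in range(0, len(payload), CHUNK_SIZE):
            chunk = payload[i:i + CHUNK_SIZE]      # step 2: stream fixed-size chunks
            # frombuffer avoids the per-byte list() copy used in the repo code
            t = torch.frombuffer(bytearray(chunk), dtype=torch.uint8).to(device)
            dist.broadcast(t, src=0)
        return obj
    length = torch.zeros(1, dtype=torch.int32, device=device)
    dist.broadcast(length, src=0)
    received, remaining = bytearray(), int(length.item())
    while remaining > 0:
        # receivers size each buffer exactly as the sender did
        buf = torch.empty(min(CHUNK_SIZE, remaining), dtype=torch.uint8, device=device)
        dist.broadcast(buf, src=0)
        received.extend(buf.cpu().numpy().tobytes())
        remaining -= len(buf)
    return pickle.loads(bytes(received))
```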
@@ -88,19 +123,7 @@ class DistributedManager:
                 task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
                 dist.broadcast(task_length, src=0)

-                chunk_size = 1024 * 1024
-                if len(task_bytes) > chunk_size:
-                    num_chunks = (len(task_bytes) + chunk_size - 1) // chunk_size
-                    for i in range(num_chunks):
-                        start_idx = i * chunk_size
-                        end_idx = min((i + 1) * chunk_size, len(task_bytes))
-                        chunk = task_bytes[start_idx:end_idx]
-                        task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(broadcast_device)
-                        dist.broadcast(task_tensor, src=0)
-                else:
-                    task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8).to(broadcast_device)
-                    dist.broadcast(task_tensor, src=0)
+                self._broadcast_byte_chunks(task_bytes, broadcast_device)

                 return task_data
             else:
@@ -113,25 +136,11 @@ class DistributedManager:
                 return None
         else:
             task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
             dist.broadcast(task_length, src=0)
             total_length = int(task_length.item())

-            chunk_size = 1024 * 1024
-            if total_length > chunk_size:
-                task_bytes = bytearray()
-                num_chunks = (total_length + chunk_size - 1) // chunk_size
-                for i in range(num_chunks):
-                    chunk_length = min(chunk_size, total_length - len(task_bytes))
-                    task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(broadcast_device)
-                    dist.broadcast(task_tensor, src=0)
-                    task_bytes.extend(task_tensor.cpu().numpy())
-                task_bytes = bytes(task_bytes)
-            else:
-                task_tensor = torch.empty(total_length, dtype=torch.uint8).to(broadcast_device)
-                dist.broadcast(task_tensor, src=0)
-                task_bytes = bytes(task_tensor.cpu().numpy())
+            task_bytes = self._receive_byte_chunks(total_length, broadcast_device)

             task_data = pickle.loads(task_bytes)
             return task_data
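Since `broadcast_task_data` is a collective call, every rank must reach it together; a typical call site looks like the hedged sketch below (the worker wiring is illustrative, and only rank 0 passes a payload):

```python
import torch.distributed as dist

# Illustrative call site: `manager` is a DistributedManager whose process
# group is already initialized; the task dict is any picklable object.
def run_step(manager):
    if manager.is_rank_zero():
        task = {"prompt": "a red fox", "steps": 30}  # hypothetical payload
        manager.broadcast_task_data(task)            # sends, returns the same object
    else:
        task = manager.broadcast_task_data()         # blocks until rank 0 sends
        print(f"rank {dist.get_rank()} received: {task}")
```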
@@ -5,6 +5,8 @@ import uuid
 from pathlib import Path
 from typing import Optional, Tuple

+from loguru import logger
+

 def is_base64_image(data: str) -> bool:
     """Check if a string is a base64-encoded image"""
@@ -24,7 +26,7 @@ def is_base64_image(data: str) -> bool:
             if decoded[8:12] == b"WEBP":
                 return True
     except Exception as e:
-        print(f"Error checking base64 image: {e}")
+        logger.warning(f"Error checking base64 image: {e}")
         return False
     return False
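The WEBP test reads bytes 8 through 12 because WEBP nests its tag inside a 12-byte RIFF container, unlike the other formats, whose signatures sit at the start of the file. A minimal standalone sketch of the same signature checks (a hypothetical helper, not the project's exact code):

```python
import base64

def sniff_image_bytes(decoded: bytes) -> str | None:
    """Return the image type implied by well-known magic numbers, or None."""
    if decoded.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if decoded.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    if decoded.startswith((b"GIF87a", b"GIF89a")):
        return "gif"
    if decoded[:4] == b"RIFF" and decoded[8:12] == b"WEBP":
        return "webp"
    return None

# Usage: only a small prefix needs decoding; 24 bytes covers every check here.
head = base64.b64decode("iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB")
print(sniff_image_bytes(head))  # -> "png"
```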
@@ -45,7 +47,7 @@ def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
     return data, None


-def save_base64_image(base64_data: str, output_dir: str = "/tmp/flux_kontext_uploads") -> str:
+def save_base64_image(base64_data: str, output_dir: str) -> str:
     """
     Save a base64-encoded image to disk and return the file path
     """
-import asyncio
-import signal
-import sys
 from pathlib import Path
 from typing import Optional

 import uvicorn
 from loguru import logger
@@ -12,87 +9,6 @@ from .config import server_config
 from .service import DistributedInferenceService


-class ServerManager:
-    def __init__(self):
-        self.api_server: Optional[ApiServer] = None
-        self.inference_service: Optional[DistributedInferenceService] = None
-        self.shutdown_event = asyncio.Event()
-
-    async def startup(self, args):
-        logger.info("Starting LightX2V server...")
-        if hasattr(args, "host") and args.host:
-            server_config.host = args.host
-        if hasattr(args, "port") and args.port:
-            server_config.port = args.port
-        if not server_config.validate():
-            raise RuntimeError("Invalid server configuration")
-        self.inference_service = DistributedInferenceService()
-        if not self.inference_service.start_distributed_inference(args):
-            raise RuntimeError("Failed to start distributed inference service")
-        cache_dir = Path(server_config.cache_dir)
-        cache_dir.mkdir(parents=True, exist_ok=True)
-        self.api_server = ApiServer(max_queue_size=server_config.max_queue_size)
-        self.api_server.initialize_services(cache_dir, self.inference_service)
-        logger.info("Server startup completed successfully")
-
-    async def shutdown(self):
-        logger.info("Starting server shutdown...")
-        if self.api_server:
-            await self.api_server.cleanup()
-            logger.info("API server cleaned up")
-        if self.inference_service:
-            self.inference_service.stop_distributed_inference()
-            logger.info("Inference service stopped")
-        logger.info("Server shutdown completed")
-
-    def handle_signal(self, sig, frame):
-        logger.info(f"Received signal {sig}, initiating graceful shutdown...")
-        asyncio.create_task(self.shutdown())
-        self.shutdown_event.set()
-
-    async def run_server(self, args):
-        try:
-            await self.startup(args)
-            assert self.api_server is not None
-            app = self.api_server.get_app()
-            signal.signal(signal.SIGINT, self.handle_signal)
-            signal.signal(signal.SIGTERM, self.handle_signal)
-            logger.info(f"Starting server on {server_config.host}:{server_config.port}")
-            config = uvicorn.Config(
-                app=app,
-                host=server_config.host,
-                port=server_config.port,
-                log_level="info",
-            )
-            server = uvicorn.Server(config)
-            server_task = asyncio.create_task(server.serve())
-            await self.shutdown_event.wait()
-            server.should_exit = True
-            await server_task
-        except Exception as e:
-            logger.error(f"Server error: {e}")
-            raise
-        finally:
-            await self.shutdown()
-
-
 def run_server(args):
     inference_service = None
     try:
@@ -176,7 +176,8 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, ar
     logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")

     while True:
-        task_event.wait(timeout=1.0)
+        if not task_event.wait(timeout=1.0):
+            continue

         if rank == 0:
             if shared_data.get("stop", False):
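The fix leans on the fact that `Event.wait(timeout=...)` returns `True` only if the event was set, and `False` when the timeout expires; the old code ignored the return value, so every one-second timeout fell through into the loop body as if a task had arrived. A standalone sketch of the corrected pattern (shown with `threading.Event`; the worker in the repo uses whatever event object the service hands it):

```python
import threading
import time

task_event = threading.Event()
stop_flag = threading.Event()

def poll_loop():
    while not stop_flag.is_set():
        # wait() returns True only if the event was set before the timeout expired.
        if not task_event.wait(timeout=1.0):
            continue  # idle tick: no task arrived, go back to waiting
        task_event.clear()
        print("processing a task")

worker = threading.Thread(target=poll_loop, daemon=True)
worker.start()
task_event.set()   # hand the worker one task
time.sleep(0.5)    # give it a moment to pick the task up
stop_flag.set()
worker.join(timeout=2.0)
```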
@@ -2,6 +2,7 @@ import base64
 import os
 import threading
 import time
+from typing import Any

 import requests
 from loguru import logger
@@ -15,8 +16,8 @@ def image_to_base64(image_path):
     return base64.b64encode(image_data).decode("utf-8")


-def process_image_path(image_path):
-    """处理image_path:如果是本地路径则转换为base64,如果是HTTP链接则保持不变"""
+def process_image_path(image_path) -> Any | str:
+    """Process image_path: convert to base64 if local path, keep unchanged if HTTP link"""
     if not image_path:
         return image_path
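The rest of the function body falls outside this hunk; below is a plausible sketch of the documented behavior, assuming the module-level `image_to_base64` shown above (a guess at the flow, not the repo's exact code):

```python
def process_image_path(image_path) -> Any | str:
    """Process image_path: convert to base64 if local path, keep unchanged if HTTP link"""
    if not image_path:
        return image_path
    # Remote URLs pass through untouched; only local files get inlined.
    if isinstance(image_path, str) and image_path.startswith(("http://", "https://")):
        return image_path
    if os.path.exists(image_path):
        return image_to_base64(image_path)  # assumed fallback: inline as base64
    return image_path
```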