Commit 3c778aee authored by PengGao, committed by GitHub

Gp/dev (#310)


Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 32fd1c52
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional
import requests
from loguru import logger
@dataclass
class ServerConfig:
port: int
gpu_id: int
model_cls: str
task: str
model_path: str
config_json: str
def get_node_ip() -> str:
"""Get the IP address of the current node"""
try:
# Create a UDP socket
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Connect to an external address (no actual connection needed)
s.connect(("8.8.8.8", 80))
# Get local IP
ip = s.getsockname()[0]
s.close()
return ip
except Exception as e:
logger.error(f"Failed to get IP address: {e}")
return "localhost"
def is_port_in_use(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
def find_available_port(start_port: int) -> Optional[int]:
"""Find an available port starting from start_port"""
port = start_port
while port < start_port + 1000: # Try up to 1000 ports
if not is_port_in_use(port):
return port
port += 1
return None
def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
"""Start a single server instance"""
try:
# Set GPU
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)
# Start server
process = subprocess.Popen(
[
"python",
"-m",
"lightx2v.api_server",
"--model_cls",
config.model_cls,
"--task",
config.task,
"--model_path",
config.model_path,
"--config_json",
config.config_json,
"--port",
str(config.port),
],
env=env,
)
# Wait for server to start, up to 600 seconds
node_ip = get_node_ip()
service_url = f"http://{node_ip}:{config.port}/v1/service/status"
# Check once per second, up to 600 times
for _ in range(600):
try:
response = requests.get(service_url, timeout=1)
if response.status_code == 200:
return process, f"http://{node_ip}:{config.port}"
except (requests.RequestException, ConnectionError):
pass
time.sleep(1)
# If timeout, terminate the process
logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
process.terminate()
return None
except Exception as e:
logger.error(f"Failed to start server: {e}")
return None
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
parser.add_argument("--model_cls", type=str, required=True, help="Model class")
parser.add_argument("--task", type=str, required=True, help="Task type")
parser.add_argument("--model_path", type=str, required=True, help="Model path")
parser.add_argument("--config_json", type=str, required=True, help="Config file path")
args = parser.parse_args()
# Prepare configurations for all servers on this node
server_configs = []
current_port = args.start_port
# Create configs for each GPU on this node
for gpu in range(args.num_gpus):
port = find_available_port(current_port)
if port is None:
logger.error(f"Cannot find available port starting from {current_port}")
continue
config = ServerConfig(port=port, gpu_id=gpu, model_cls=args.model_cls, task=args.task, model_path=args.model_path, config_json=args.config_json)
server_configs.append(config)
current_port = port + 1
# Start all servers in parallel
processes = []
urls = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
future_to_config = {executor.submit(start_server, config): config for config in server_configs}
for future in concurrent.futures.as_completed(future_to_config):
config = future_to_config[future]
try:
result = future.result()
if result:
process, url = result
processes.append(process)
urls.append(url)
logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
else:
logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
except Exception as e:
logger.error(f"Error occurred while starting server: {e}")
# Print all server URLs
logger.info("\nAll server URLs:")
for url in urls:
logger.info(url)
# Print node information
node_ip = get_node_ip()
logger.info(f"\nCurrent node IP: {node_ip}")
logger.info(f"Number of servers started: {len(urls)}")
try:
# Wait for all processes
for process in processes:
process.wait()
except KeyboardInterrupt:
logger.info("Received interrupt signal, shutting down all servers...")
for process in processes:
process.terminate()
if __name__ == "__main__":
main()
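The launcher above only prints the per-GPU base URLs and considers a server up once its `/v1/service/status` probe returns 200; choosing which server to send work to is left to the caller. A minimal client-side sketch under that assumption (the URLs below are placeholders for whatever the launcher printed, and only the HTTP status code is inspected):

```python
import itertools

import requests

# Base URLs printed by the launcher above (one api_server per GPU) -- replace with your own.
urls = ["http://10.0.0.5:8000", "http://10.0.0.5:8001"]


def is_healthy(base_url: str) -> bool:
    """A server counts as healthy when its status route answers 200."""
    try:
        return requests.get(f"{base_url}/v1/service/status", timeout=2).status_code == 200
    except requests.RequestException:
        return False


rotation = itertools.cycle([u for u in urls if is_healthy(u)])
print(next(rotation))  # round-robin: call next(rotation) before each task submission
```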
......@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI
### System Architecture
```mermaid
flowchart TB
    Client[Client] -->|Send API Request| Router[FastAPI Router]

    subgraph API Layer
        Router --> TaskRoutes[Task APIs]
        Router --> FileRoutes[File APIs]
        Router --> ServiceRoutes[Service Status APIs]

        TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
        TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
        TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
        TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
        TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
        TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]

        FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]

        ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
        ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
    end

    subgraph Task Management
        TaskManager[Task Manager]
        TaskQueue[Task Queue]
        TaskStatus[Task Status]
        TaskResult[Task Result]

        CreateTask --> TaskManager
        CreateTaskForm --> TaskManager
        TaskManager --> TaskQueue
        TaskManager --> TaskStatus
        TaskManager --> TaskResult
    end

    subgraph File Service
        FileService[File Service]
        DownloadImage[Download Image]
        DownloadAudio[Download Audio]
        SaveFile[Save File]
        GetOutputPath[Get Output Path]

        FileService --> DownloadImage
        FileService --> DownloadAudio
        FileService --> SaveFile
        FileService --> GetOutputPath
    end

    subgraph Processing Thread
        ProcessingThread[Processing Thread]
        NextTask[Get Next Task]
        ProcessTask[Process Single Task]

        ProcessingThread --> NextTask
        ProcessingThread --> ProcessTask
    end

    subgraph Video Generation Service
        VideoService[Video Service]
        GenerateVideo[Generate Video]

        VideoService --> GenerateVideo
    end

    subgraph Distributed Inference Service
        InferenceService[Distributed Inference Service]
        SubmitTask[Submit Task]
        Worker[Inference Worker Node]
        ProcessRequest[Process Request]
        RunPipeline[Run Inference Pipeline]

        InferenceService --> SubmitTask
        SubmitTask --> Worker
        Worker --> ProcessRequest
        ProcessRequest --> RunPipeline
    end

    %% ====== Connect Modules ======
    TaskQueue --> ProcessingThread
    ProcessTask --> VideoService
    GenerateVideo --> InferenceService
    GetTaskResult --> FileService
    DownloadFile --> FileService
    VideoService --> FileService
    InferenceService --> TaskManager
    TaskManager --> TaskStatus
```
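For orientation, a client round trip against the routes in this diagram might look like the sketch below. The paths follow the router labels above; the request fields and status strings are assumptions inferred from how `TaskRequest` and `TaskStatus` are used elsewhere in this change, so check `schema.py` for the authoritative shapes.

```python
import time

import requests

BASE = "http://localhost:8000"

# Field names mirror the TaskRequest usage later in this diff (image_path / audio_path /
# save_video_path); treat them as illustrative rather than the authoritative schema.
payload = {
    "image_path": "https://example.com/face.jpg",
    "audio_path": "https://example.com/speech.wav",
    "save_video_path": "result.mp4",
}

resp = requests.post(f"{BASE}/v1/tasks/", json=payload, timeout=10)
resp.raise_for_status()
task_id = resp.json()["task_id"]  # assumes the create response echoes a task_id

# Poll the status route until the task leaves the queue, then fetch the result.
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status", timeout=10).json()
    if status.get("status") in ("COMPLETED", "FAILED", "CANCELLED"):
        break
    time.sleep(2)

print(requests.get(f"{BASE}/v1/tasks/{task_id}/result", timeout=10).json())
```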
## Task Processing Flow
......@@ -100,9 +106,9 @@ sequenceDiagram
participant PT as Processing Thread
participant VS as VideoService
participant FS as FileService
participant DIS as DistributedInferenceService
participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)
C->>API: POST /v1/tasks<br/>(Create Task)
API->>TM: create_task()
......@@ -127,32 +133,54 @@ sequenceDiagram
else Image is Base64
VS->>FS: save_base64_image()
FS-->>VS: image_path
else Image is Upload
VS->>FS: validate_file()
FS-->>VS: image_path
else Image is local path
VS->>VS: use existing path
end
alt Audio is URL
VS->>FS: download_audio()
FS->>FS: HTTP download<br/>with retry
FS-->>VS: audio_path
else Audio is Base64
VS->>FS: save_base64_audio()
FS-->>VS: audio_path
else Audio is local path
VS->>VS: use existing path
end
VS->>DIS: submit_task_async(task_data)
DIS->>TIW0: process_request(task_data)
Note over TIW0,TIW1: Torchrun-based Distributed Processing
TIW0->>TIW0: Check if processing
TIW0->>TIW0: Set processing = True
alt Multi-GPU Mode (world_size > 1)
TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
Note over TIW1: worker_loop() listens for broadcasts
TIW1->>TIW1: Receive task_data
end
par Parallel Inference across all ranks
TIW0->>TIW0: runner.set_inputs(task_data)
TIW0->>TIW0: runner.run_pipeline()
and
Note over TIW1: If world_size > 1
TIW1->>TIW1: runner.set_inputs(task_data)
TIW1->>TIW1: runner.run_pipeline()
end
Note over TIW0,TIW1: Synchronization
alt Multi-GPU Mode
TIW0->>TIW1: barrier() for sync
TIW1->>TIW0: barrier() response
end
TIW0->>TIW0: Set processing = False
TIW0->>DIS: Return result (only rank 0)
TIW1->>TIW1: Return None (non-rank 0)
DIS-->>VS: TaskResponse
VS-->>PT: TaskResponse
PT->>TM: complete_task()<br/>(status: COMPLETED)
......
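The `broadcast_task_data()` and `barrier()` steps above follow the standard `torch.distributed` pattern: rank 0 broadcasts the task payload, every rank runs the same pipeline, and all ranks synchronize before rank 0 reports the result. A minimal standalone sketch of that pattern, using `broadcast_object_list` rather than the chunked-pickle transfer implemented in `DistributedManager`:

```python
import os

import torch
import torch.distributed as dist


# Launch with: torchrun --nproc_per_node 2 sketch.py
def main() -> None:
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://")
    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    # Rank 0 supplies the task dict; every other rank receives rank 0's copy.
    payload = [{"task_id": "demo", "save_video_path": "out.mp4"} if rank == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    task_data = payload[0]

    # ... each rank would call runner.set_inputs(task_data) / runner.run_pipeline() here ...

    dist.barrier()  # all ranks synchronize before rank 0 reports the result
    if rank == 0:
        print(f"task {task_data['task_id']} finished")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```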
#!/usr/bin/env python
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from lightx2v.server.main import run_server
from .main import run_server
def main():
parser = argparse.ArgumentParser(description="Run LightX2V inference server")
parser = argparse.ArgumentParser(description="LightX2V Server")
# Model arguments
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
args = parser.parse_args()
# Parse any additional arguments that might be passed
args, unknown = parser.parse_known_args()
# Add any unknown arguments as attributes to args
# This allows flexibility for model-specific arguments
for i in range(0, len(unknown), 2):
if unknown[i].startswith("--"):
key = unknown[i][2:]
if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
value = unknown[i + 1]
setattr(args, key, value)
# Run the server
run_server(args)
......
......@@ -314,7 +314,6 @@ class ApiServer:
return False
def _ensure_processing_thread_running(self):
"""Ensure the processing thread is running."""
if self.processing_thread is None or not self.processing_thread.is_alive():
self.stop_processing.clear()
self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
......@@ -322,9 +321,11 @@ class ApiServer:
logger.info("Started task processing thread")
def _task_processing_loop(self):
"""Main loop that processes tasks from the queue one by one."""
logger.info("Task processing loop started")
asyncio.set_event_loop(asyncio.new_event_loop())
loop = asyncio.get_event_loop()
while not self.stop_processing.is_set():
task_id = task_manager.get_next_pending_task()
......@@ -335,12 +336,12 @@ class ApiServer:
task_info = task_manager.get_task(task_id)
if task_info and task_info.status == TaskStatus.PENDING:
logger.info(f"Processing task {task_id}")
self._process_single_task(task_info)
loop.run_until_complete(self._process_single_task(task_info))
loop.close()
logger.info("Task processing loop stopped")
def _process_single_task(self, task_info: Any):
"""Process a single task."""
async def _process_single_task(self, task_info: Any):
assert self.video_service is not None, "Video service is not initialized"
task_id = task_info.task_id
......@@ -360,7 +361,7 @@ class ApiServer:
task_manager.fail_task(task_id, "Task cancelled")
return
result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)
if result:
task_manager.complete_task(task_id, result.save_video_path)
......
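The processing thread above creates and owns its own asyncio event loop so it can drive the now-async `_process_single_task` without touching uvicorn's loop. Stripped of the task-manager details, the pattern reduces to roughly this (illustrative names only):

```python
import asyncio
import threading


async def process(item: str) -> None:
    # Stand-in for _process_single_task(); the real loop awaits the video service here.
    await asyncio.sleep(0.1)
    print(f"processed {item}")


def worker(stop: threading.Event, queue: list[str]) -> None:
    # The worker thread owns a private event loop, independent of the server's main loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    while not stop.is_set() and queue:
        loop.run_until_complete(process(queue.pop(0)))
    loop.close()


stop = threading.Event()
t = threading.Thread(target=worker, args=(stop, ["task-1", "task-2"]), daemon=True)
t.start()
t.join()
```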
......@@ -11,9 +11,6 @@ class ServerConfig:
port: int = 8000
max_queue_size: int = 10
master_addr: str = "127.0.0.1"
master_port_range: tuple = (29500, 29600)
task_timeout: int = 300
task_history_limit: int = 1000
......@@ -42,31 +39,13 @@ class ServerConfig:
except ValueError:
logger.warning(f"Invalid max queue size: {env_queue_size}")
if env_master_addr := os.environ.get("MASTER_ADDR"):
config.master_addr = env_master_addr
# MASTER_ADDR is now managed by torchrun, no need to set manually
if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
config.cache_dir = env_cache_dir
return config
def find_free_master_port(self) -> str:
import socket
for port in range(self.master_port_range[0], self.master_port_range[1]):
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((self.master_addr, port))
logger.info(f"Found free port for master: {port}")
return str(port)
except OSError:
continue
raise RuntimeError(
f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
)
def validate(self) -> bool:
valid = True
......
......@@ -6,8 +6,6 @@ import torch
import torch.distributed as dist
from loguru import logger
from .gpu_manager import gpu_manager
class DistributedManager:
def __init__(self):
......@@ -18,29 +16,35 @@ class DistributedManager:
CHUNK_SIZE = 1024 * 1024
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
def init_process_group(self) -> bool:
"""Initialize process group using torchrun environment variables"""
try:
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
logger.info(f"Setup backend: {backend}")
self.device = gpu_manager.set_device_for_rank(rank, world_size)
# torchrun sets these environment variables automatically
self.rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
if self.world_size > 1:
# torchrun handles backend, init_method, rank, and world_size
# We just need to call init_process_group without parameters
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method="env://")
logger.info(f"Setup backend: {backend}")
# Set CUDA device for this rank
if torch.cuda.is_available():
torch.cuda.set_device(self.rank)
self.device = f"cuda:{self.rank}"
else:
self.device = "cpu"
else:
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.is_initialized = True
self.rank = rank
self.world_size = world_size
logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
return True
except Exception as e:
logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
return False
def cleanup(self):
......@@ -143,30 +147,3 @@ class DistributedManager:
task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
task_data = pickle.loads(task_bytes)
return task_data
class DistributedWorker:
def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.dist_manager = DistributedManager()
def init(self) -> bool:
return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
def cleanup(self):
self.dist_manager.cleanup()
def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
self.dist_manager.barrier()
if self.dist_manager.is_rank_zero():
result = {"task_id": task_id, "status": status, **kwargs}
result_queue.put(result)
logger.info(f"Task {task_id} {status}")
def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
return DistributedWorker(rank, world_size, master_addr, master_port)
import os
from typing import List, Optional, Tuple
import torch
from loguru import logger
class GPUManager:
def __init__(self):
self.available_gpus = self._detect_gpus()
self.gpu_count = len(self.available_gpus)
def _detect_gpus(self) -> List[int]:
if not torch.cuda.is_available():
logger.warning("No CUDA devices available, will use CPU")
return []
gpu_count = torch.cuda.device_count()
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if cuda_visible:
try:
visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
return list(range(len(visible_devices)))
except ValueError:
logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")
available_gpus = list(range(gpu_count))
logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
return available_gpus
def get_device_for_rank(self, rank: int, world_size: int) -> str:
if not self.available_gpus:
logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
return "cpu"
if self.gpu_count == 1:
device = f"cuda:{self.available_gpus[0]}"
logger.info(f"Rank {rank}: Using single GPU {device}")
return device
if self.gpu_count >= world_size:
gpu_id = self.available_gpus[rank % self.gpu_count]
device = f"cuda:{gpu_id}"
logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
return device
else:
gpu_id = self.available_gpus[rank % self.gpu_count]
device = f"cuda:{gpu_id}"
logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
return device
def set_device_for_rank(self, rank: int, world_size: int) -> str:
device = self.get_device_for_rank(rank, world_size)
if device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
torch.cuda.set_device(gpu_id)
logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
return device
def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
if not torch.cuda.is_available():
return (0, 0)
if device and device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
else:
gpu_id = torch.cuda.current_device()
try:
used = torch.cuda.memory_allocated(gpu_id)
total = torch.cuda.get_device_properties(gpu_id).total_memory
return (used, total)
except Exception as e:
logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
return (0, 0)
def clear_cache(self, device: Optional[str] = None):
if not torch.cuda.is_available():
return
if device and device.startswith("cuda:"):
gpu_id = int(device.split(":")[1])
with torch.cuda.device(gpu_id):
torch.cuda.empty_cache()
torch.cuda.synchronize()
else:
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info(f"GPU cache cleared for device: {device or 'current'}")
@staticmethod
def get_optimal_world_size(requested_world_size: int) -> int:
if not torch.cuda.is_available():
logger.warning("No GPUs available, using single process")
return 1
gpu_count = torch.cuda.device_count()
if requested_world_size <= 0:
optimal_size = gpu_count
logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
elif requested_world_size > gpu_count:
logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
optimal_size = requested_world_size
else:
optimal_size = requested_world_size
return optimal_size
gpu_manager = GPUManager()
import os
import sys
from pathlib import Path
......@@ -10,9 +11,14 @@ from .service import DistributedInferenceService
def run_server(args):
"""Run server with torchrun support"""
inference_service = None
try:
logger.info("Starting LightX2V server...")
# Get rank from environment (set by torchrun)
rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")
if hasattr(args, "host") and args.host:
server_config.host = args.host
......@@ -22,28 +28,37 @@ def run_server(args):
if not server_config.validate():
raise RuntimeError("Invalid server configuration")
# Initialize inference service
inference_service = DistributedInferenceService()
if not inference_service.start_distributed_inference(args):
raise RuntimeError("Failed to start distributed inference service")
logger.info("Inference service started successfully")
logger.info(f"Rank {rank}: Inference service started successfully")
if rank == 0:
# Only rank 0 runs the FastAPI server
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
api_server = ApiServer(max_queue_size=server_config.max_queue_size)
api_server.initialize_services(cache_dir, inference_service)
api_server = ApiServer(max_queue_size=server_config.max_queue_size)
api_server.initialize_services(cache_dir, inference_service)
app = api_server.get_app()
app = api_server.get_app()
logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
else:
# Non-rank-0 processes run the worker loop
logger.info(f"Rank {rank}: Starting worker loop")
import asyncio
logger.info(f"Starting server on {server_config.host}:{server_config.port}")
uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
asyncio.run(inference_service.run_worker_loop())
except KeyboardInterrupt:
logger.info("Server interrupted by user")
logger.info(f"Server rank {rank} interrupted by user")
if inference_service:
inference_service.stop_distributed_inference()
except Exception as e:
logger.error(f"Server failed: {e}")
logger.error(f"Server rank {rank} failed: {e}")
if inference_service:
inference_service.stop_distributed_inference()
sys.exit(1)
import asyncio
import threading
import time
import json
import os
import uuid
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional
from urllib.parse import urlparse
import httpx
import torch.multiprocessing as mp
import torch
from loguru import logger
from ..infer import init_runner
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .config import server_config
from .distributed_utils import create_distributed_worker
from .distributed_utils import DistributedManager
from .image_utils import is_base64_image, save_base64_image
from .schema import TaskRequest, TaskResponse
mp.set_start_method("spawn", force=True)
class FileService:
def __init__(self, cache_dir: Path):
......@@ -196,280 +193,191 @@ class FileService:
self._http_client = None
def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, shared_data, task_event, result_event):
task_data = None
worker = None
class TorchrunInferenceWorker:
"""Worker class for torchrun-based distributed inference"""
try:
logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")
def __init__(self):
self.rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.runner = None
self.dist_manager = DistributedManager()
self.request_queue = asyncio.Queue() if self.rank == 0 else None
self.processing = False # Track if currently processing a request
def init(self, args) -> bool:
"""Initialize the worker with model and distributed setup"""
try:
# Initialize distributed process group using torchrun env vars
if self.world_size > 1:
if not self.dist_manager.init_process_group():
raise RuntimeError("Failed to initialize distributed process group")
else:
# Single GPU mode
self.dist_manager.rank = 0
self.dist_manager.world_size = 1
self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.dist_manager.is_initialized = False
worker = create_distributed_worker(rank, world_size, master_addr, master_port)
if not worker.init():
raise RuntimeError(f"Rank {rank} distributed environment initialization failed")
# Initialize model
config = set_config(args)
if self.rank == 0:
logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")
config = set_config(args)
logger.info(f"Rank {rank} config: {config}")
self.runner = init_runner(config)
logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
runner = init_runner(config)
logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
return True
while True:
if not task_event.wait(timeout=1.0):
continue
except Exception as e:
logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
return False
if rank == 0:
if shared_data.get("stop", False):
logger.info(f"Process {rank} received stop signal, exiting inference service")
worker.dist_manager.broadcast_task_data(None)
break
async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single inference request
task_data = shared_data.get("current_task")
if task_data:
worker.dist_manager.broadcast_task_data(task_data)
shared_data["current_task"] = None
try:
task_event.clear()
except Exception:
pass
else:
continue
Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
The async wrapper allows FastAPI to handle other requests while this runs.
"""
try:
# Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
if self.world_size > 1 and self.rank == 0:
task_data = self.dist_manager.broadcast_task_data(task_data)
# Run inference directly - torchrun handles the parallelization
# Using asyncio.to_thread would be risky with NCCL operations
# Instead, we rely on FastAPI's async handling and queue management
self.runner.set_inputs(task_data)
self.runner.run_pipeline()
# Small yield to allow other async operations if needed
await asyncio.sleep(0)
# Synchronize all ranks
if self.world_size > 1:
self.dist_manager.barrier()
# Only rank 0 returns the result
if self.rank == 0:
return {
"task_id": task_data["task_id"],
"status": "success",
"save_video_path": task_data.get("video_path", task_data["save_video_path"]),
"message": "Inference completed",
}
else:
return None
except Exception as e:
logger.error(f"Rank {self.rank} inference failed: {str(e)}")
if self.world_size > 1:
self.dist_manager.barrier()
if self.rank == 0:
return {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Inference failed: {str(e)}",
}
else:
task_data = worker.dist_manager.broadcast_task_data()
return None
async def worker_loop(self):
"""Non-rank-0 workers: Listen for broadcast tasks"""
while True:
try:
task_data = self.dist_manager.broadcast_task_data()
if task_data is None:
logger.info(f"Process {rank} received stop signal, exiting inference service")
logger.info(f"Rank {self.rank} received stop signal")
break
if task_data is not None:
logger.info(f"Process {rank} received inference task: {task_data['task_id']}")
try:
runner.set_inputs(task_data) # type: ignore
runner.run_pipeline()
worker.dist_manager.barrier()
if rank == 0:
# Only rank 0 updates the result
shared_data["result"] = {
"task_id": task_data["task_id"],
"status": "success",
"save_video_path": task_data.get("video_path", task_data["save_video_path"]), # Return original path for API
"message": "Inference completed",
}
result_event.set()
logger.info(f"Task {task_data['task_id']} success")
except Exception as e:
logger.exception(f"Process {rank} error occurred while processing task: {str(e)}")
worker.dist_manager.barrier()
if rank == 0:
# Only rank 0 updates the result
shared_data["result"] = {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Inference failed: {str(e)}",
}
result_event.set()
logger.info(f"Task {task_data.get('task_id', 'unknown')} failed")
except KeyboardInterrupt:
logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
except Exception as e:
logger.exception(f"Distributed inference service process {rank} startup failed: {str(e)}")
if rank == 0:
shared_data["result"] = {
"task_id": "startup",
"status": "startup_failed",
"error": str(e),
"message": f"Inference service startup failed: {str(e)}",
}
result_event.set()
finally:
try:
if worker:
worker.cleanup()
except Exception as e:
logger.debug(f"Error cleaning up worker for rank {rank}: {e}")
await self.process_request(task_data)
except Exception as e:
logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
continue
def cleanup(self):
self.dist_manager.cleanup()
class DistributedInferenceService:
def __init__(self):
self.manager = None
self.shared_data = None
self.task_event = None
self.result_event = None
self.processes = []
self.worker = None
self.is_running = False
self.args = None
def start_distributed_inference(self, args) -> bool:
if hasattr(args, "lora_path") and args.lora_path:
args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
delattr(args, "lora_path")
if hasattr(args, "lora_strength"):
delattr(args, "lora_strength")
self.args = args
if self.is_running:
logger.warning("Distributed inference service is already running")
return True
nproc_per_node = args.nproc_per_node
if nproc_per_node <= 0:
logger.error("nproc_per_node must be greater than 0")
return False
try:
master_addr = server_config.master_addr
master_port = server_config.find_free_master_port()
logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")
# Create shared data structures
self.manager = mp.Manager()
self.shared_data = self.manager.dict()
self.task_event = self.manager.Event()
self.result_event = self.manager.Event()
# Initialize shared data
self.shared_data["current_task"] = None
self.shared_data["result"] = None
self.shared_data["stop"] = False
for rank in range(nproc_per_node):
p = mp.Process(
target=_distributed_inference_worker,
args=(
rank,
nproc_per_node,
master_addr,
master_port,
args,
self.shared_data,
self.task_event,
self.result_event,
),
daemon=False, # Changed to False for proper cleanup
)
p.start()
self.processes.append(p)
self.worker = TorchrunInferenceWorker()
if not self.worker.init(args):
raise RuntimeError("Worker initialization failed")
self.is_running = True
logger.info(f"Distributed inference service started successfully with {nproc_per_node} processes")
logger.info(f"Rank {self.worker.rank} inference service started successfully")
return True
except Exception as e:
logger.exception(f"Error occurred while starting distributed inference service: {str(e)}")
logger.error(f"Error starting inference service: {str(e)}")
self.stop_distributed_inference()
return False
def stop_distributed_inference(self):
assert self.task_event, "Task event is not initialized"
assert self.result_event, "Result event is not initialized"
if not self.is_running:
return
try:
logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")
if self.shared_data is not None:
self.shared_data["stop"] = True
self.task_event.set()
for p in self.processes:
try:
p.join(timeout=10)
if p.is_alive():
logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
p.terminate()
p.join(timeout=5)
except Exception as e:
logger.warning(f"Error terminating process {p.pid}: {e}")
logger.info("All distributed inference service processes have stopped")
if self.worker:
self.worker.cleanup()
logger.info("Inference service stopped")
except Exception as e:
logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
logger.error(f"Error stopping inference service: {str(e)}")
finally:
# Clean up resources
self.processes = []
self.manager = None
self.shared_data = None
self.task_event = None
self.result_event = None
self.worker = None
self.is_running = False
def submit_task(self, task_data: dict) -> bool:
assert self.task_event, "Task event is not initialized"
assert self.result_event, "Result event is not initialized"
if not self.is_running or not self.shared_data:
logger.error("Distributed inference service is not started")
return False
try:
self.result_event.clear()
self.shared_data["result"] = None
self.shared_data["current_task"] = task_data
self.task_event.set() # Signal workers
return True
except Exception as e:
logger.error(f"Failed to submit task: {str(e)}")
return False
def wait_for_result(self, task_id: str, timeout: Optional[int] = None) -> Optional[dict]:
assert self.task_event, "Task event is not initialized"
assert self.result_event, "Result event is not initialized"
if timeout is None:
timeout = server_config.task_timeout
if not self.is_running or not self.shared_data:
async def submit_task_async(self, task_data: dict) -> Optional[dict]:
if not self.is_running or not self.worker:
logger.error("Inference service is not started")
return None
if self.result_event.wait(timeout=timeout):
result = self.shared_data.get("result")
if result and result.get("task_id") == task_id:
self.shared_data["current_task"] = None
self.task_event.clear()
return result
return None
def wait_for_result_with_stop(self, task_id: str, stop_event: threading.Event, timeout: Optional[int] = None) -> Optional[dict]:
if timeout is None:
timeout = server_config.task_timeout
if not self.is_running or not self.shared_data:
if self.worker.rank != 0:
return None
assert self.task_event, "Task event is not initialized"
assert self.result_event, "Result event is not initialized"
start_time = time.time()
while time.time() - start_time < timeout:
if stop_event.is_set():
logger.info(f"Task {task_id} stop event triggered during wait")
self.shared_data["current_task"] = None
self.task_event.clear()
return None
if self.result_event.wait(timeout=0.5):
result = self.shared_data.get("result")
if result and result.get("task_id") == task_id:
self.shared_data["current_task"] = None
self.task_event.clear()
return result
return None
try:
if self.worker.processing:
# If we want to support queueing, we can add the task to queue
# For now, we'll process sequentially
logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
self.worker.processing = True
result = await self.worker.process_request(task_data)
self.worker.processing = False
return result
except Exception as e:
self.worker.processing = False
logger.error(f"Failed to process task: {str(e)}")
return {
"task_id": task_data.get("task_id", "unknown"),
"status": "failed",
"error": str(e),
"message": f"Task processing failed: {str(e)}",
}
def server_metadata(self):
assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
return {"nproc_per_node": self.args.nproc_per_node, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
async def run_worker_loop(self):
"""Run the worker loop for non-rank-0 processes"""
if self.worker and self.worker.rank != 0:
await self.worker.worker_loop()
class VideoGenerationService:
......@@ -478,6 +386,7 @@ class VideoGenerationService:
self.inference_service = inference_service
async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
"""Generate video using torchrun-based inference"""
try:
task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
task_data["task_id"] = message.task_id
......@@ -496,6 +405,8 @@ class VideoGenerationService:
else:
task_data["image_path"] = message.image_path
logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")
if "audio_path" in message.model_fields_set and message.audio_path:
if message.audio_path.startswith("http"):
audio_path = await self.file_service.download_audio(message.audio_path)
......@@ -506,20 +417,19 @@ class VideoGenerationService:
else:
task_data["audio_path"] = message.audio_path
logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")
actual_save_path = self.file_service.get_output_path(message.save_video_path)
task_data["save_video_path"] = str(actual_save_path)
task_data["video_path"] = message.save_video_path
if not self.inference_service.submit_task(task_data):
raise RuntimeError("Distributed inference service is not started")
result = self.inference_service.wait_for_result_with_stop(message.task_id, stop_event, timeout=300)
result = await self.inference_service.submit_task_async(task_data)
if result is None:
if stop_event.is_set():
logger.info(f"Task {message.task_id} cancelled during processing")
return None
raise RuntimeError("Task processing timeout")
raise RuntimeError("Task processing failed")
if result.get("status") == "success":
return TaskResponse(
......
#!/bin/bash
# set path first
lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0,1,2,3

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

torchrun --nproc_per_node 4 -m lightx2v.server \
--model_cls seko_talk \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
--port 8000
#!/bin/bash
# set path first
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0

source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
--model_cls seko_talk \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
--port 8000
echo "Service stopped"