Commit 3c778aee authored by PengGao, committed by GitHub

Gp/dev (#310)


Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 32fd1c52
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import requests
from loguru import logger


@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str


def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"

        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError):
                pass
            time.sleep(1)

        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(port=port, gpu_id=gpu, model_cls=args.model_cls, task=args.task, model_path=args.model_path, config_json=args.config_json)
        server_configs.append(config)
        current_port = port + 1

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)

    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
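As a reference for how the launcher's output might be consumed, here is a minimal sketch that polls each reported URL via the same /v1/service/status endpoint the launcher itself uses for readiness checks. The URL list is a placeholder assumption, not output captured from a real run.

```python
# Hedged sketch: check multi-server health via /v1/service/status,
# the endpoint the launcher above already polls. URLs are assumptions.
import requests

urls = ["http://10.0.0.5:8000", "http://10.0.0.5:8001"]  # assumed launcher output

for url in urls:
    try:
        response = requests.get(f"{url}/v1/service/status", timeout=2)
        print(url, "->", response.status_code)
    except requests.RequestException as exc:
        print(url, "-> unreachable:", exc)
```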
@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI

### System Architecture

```mermaid
flowchart TB
    Client[Client] -->|Send API Request| Router[FastAPI Router]

    subgraph API Layer
        Router --> TaskRoutes[Task APIs]
        Router --> FileRoutes[File APIs]
        Router --> ServiceRoutes[Service Status APIs]

        TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
        TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
        TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
        TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
        TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
        TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]

        FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]

        ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
        ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
    end

    subgraph Task Management
        TaskManager[Task Manager]
        TaskQueue[Task Queue]
        TaskStatus[Task Status]
        TaskResult[Task Result]

        CreateTask --> TaskManager
        CreateTaskForm --> TaskManager
        TaskManager --> TaskQueue
        TaskManager --> TaskStatus
        TaskManager --> TaskResult
    end

    subgraph File Service
        FileService[File Service]
        DownloadImage[Download Image]
        DownloadAudio[Download Audio]
        SaveFile[Save File]
        GetOutputPath[Get Output Path]

        FileService --> DownloadImage
        FileService --> DownloadAudio
        FileService --> SaveFile
        FileService --> GetOutputPath
    end

    subgraph Processing Thread
        ProcessingThread[Processing Thread]
        NextTask[Get Next Task]
        ProcessTask[Process Single Task]

        ProcessingThread --> NextTask
        ProcessingThread --> ProcessTask
    end

    subgraph Video Generation Service
        VideoService[Video Service]
        GenerateVideo[Generate Video]

        VideoService --> GenerateVideo
    end

    subgraph Distributed Inference Service
        InferenceService[Distributed Inference Service]
        SubmitTask[Submit Task]
        Worker[Inference Worker Node]
        ProcessRequest[Process Request]
        RunPipeline[Run Inference Pipeline]

        InferenceService --> SubmitTask
        SubmitTask --> Worker
        Worker --> ProcessRequest
        ProcessRequest --> RunPipeline
    end

    %% ====== Connect Modules ======
    TaskQueue --> ProcessingThread
    ProcessTask --> VideoService
    GenerateVideo --> InferenceService
    GetTaskResult --> FileService
    DownloadFile --> FileService
    VideoService --> FileService
    InferenceService --> TaskManager
    TaskManager --> TaskStatus
```
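For orientation, a minimal client-side sketch of the Client-to-Task-APIs path in the diagram above: create a task, then poll its status. The endpoint paths come from the diagram; the request fields mirror the TaskRequest fields used later in service.py (image_path, save_video_path), and the response keys are assumptions rather than a documented schema.

```python
# Hedged sketch of the task lifecycle shown in the architecture diagram.
# Field names and response keys are assumptions, not taken from this commit.
import time

import requests

BASE = "http://localhost:8000"  # assumed server address

payload = {
    "image_path": "https://example.com/input.jpg",  # assumed TaskRequest fields
    "save_video_path": "output.mp4",
}

task = requests.post(f"{BASE}/v1/tasks/", json=payload).json()
task_id = task["task_id"]  # assumed response key

for _ in range(300):  # poll for up to ~10 minutes
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status").json()
    if status.get("status") in ("COMPLETED", "FAILED"):
        print(status)
        break
    time.sleep(2)
```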
## Task Processing Flow

@@ -100,9 +106,9 @@ sequenceDiagram
    participant PT as Processing Thread
    participant VS as VideoService
    participant FS as FileService
    participant DIS as DistributedInferenceService
    participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
    participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)

    C->>API: POST /v1/tasks<br/>(Create Task)
    API->>TM: create_task()
@@ -127,32 +133,54 @@ sequenceDiagram
    else Image is Base64
        VS->>FS: save_base64_image()
        FS-->>VS: image_path
    else Image is local path
        VS->>VS: use existing path
    end

    alt Audio is URL
        VS->>FS: download_audio()
        FS->>FS: HTTP download<br/>with retry
        FS-->>VS: audio_path
    else Audio is Base64
        VS->>FS: save_base64_audio()
        FS-->>VS: audio_path
    else Audio is local path
        VS->>VS: use existing path
    end

    VS->>DIS: submit_task_async(task_data)
    DIS->>TIW0: process_request(task_data)

    Note over TIW0,TIW1: Torchrun-based Distributed Processing
    TIW0->>TIW0: Check if processing
    TIW0->>TIW0: Set processing = True

    alt Multi-GPU Mode (world_size > 1)
        TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
        Note over TIW1: worker_loop() listens for broadcasts
        TIW1->>TIW1: Receive task_data
    end

    par Parallel Inference across all ranks
        TIW0->>TIW0: runner.set_inputs(task_data)
        TIW0->>TIW0: runner.run_pipeline()
    and
        Note over TIW1: If world_size > 1
        TIW1->>TIW1: runner.set_inputs(task_data)
        TIW1->>TIW1: runner.run_pipeline()
    end

    Note over TIW0,TIW1: Synchronization
    alt Multi-GPU Mode
        TIW0->>TIW1: barrier() for sync
        TIW1->>TIW0: barrier() response
    end

    TIW0->>TIW0: Set processing = False
    TIW0->>DIS: Return result (only rank 0)
    TIW1->>TIW1: Return None (non-rank 0)
    DIS-->>VS: TaskResponse

    VS-->>PT: TaskResponse
    PT->>TM: complete_task()<br/>(status: COMPLETED)
...
import argparse

from .main import run_server


def main():
    parser = argparse.ArgumentParser(description="LightX2V Server")
    parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")

    # Parse any additional arguments that might be passed
    args, unknown = parser.parse_known_args()

    # Add any unknown arguments as attributes to args
    # This allows flexibility for model-specific arguments
    for i in range(0, len(unknown), 2):
        if unknown[i].startswith("--"):
            key = unknown[i][2:]
            if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
                value = unknown[i + 1]
                setattr(args, key, value)

    run_server(args)
...
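To illustrate the parse_known_args handling above, here is a small self-contained sketch: any extra `--key value` pair lands as an attribute on args, which is how flags such as --lora_path (later converted into lora_configs in service.py) can pass through without being declared on the parser. The argument values are invented for demonstration.

```python
# Hedged, standalone reproduction of the unknown-argument handling above.
# The paths and values are invented for demonstration only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)

args, unknown = parser.parse_known_args(
    ["--model_path", "/models/demo", "--lora_path", "/loras/style.safetensors"]
)

# Copy each undeclared "--key value" pair onto args as an attribute
for i in range(0, len(unknown), 2):
    if unknown[i].startswith("--") and i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
        setattr(args, unknown[i][2:], unknown[i + 1])

print(args.lora_path)  # -> /loras/style.safetensors
```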
@@ -314,7 +314,6 @@ class ApiServer:
            return False

    def _ensure_processing_thread_running(self):
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_processing.clear()
            self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
@@ -322,9 +321,11 @@ class ApiServer:
            logger.info("Started task processing thread")

    def _task_processing_loop(self):
        logger.info("Task processing loop started")

        asyncio.set_event_loop(asyncio.new_event_loop())
        loop = asyncio.get_event_loop()

        while not self.stop_processing.is_set():
            task_id = task_manager.get_next_pending_task()
@@ -335,12 +336,12 @@ class ApiServer:
            task_info = task_manager.get_task(task_id)
            if task_info and task_info.status == TaskStatus.PENDING:
                logger.info(f"Processing task {task_id}")
                loop.run_until_complete(self._process_single_task(task_info))

        loop.close()
        logger.info("Task processing loop stopped")

    async def _process_single_task(self, task_info: Any):
        assert self.video_service is not None, "Video service is not initialized"

        task_id = task_info.task_id
@@ -360,7 +361,7 @@ class ApiServer:
                task_manager.fail_task(task_id, "Task cancelled")
                return

            result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)

            if result:
                task_manager.complete_task(task_id, result.save_video_path)
...
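The change above gives the processing thread its own long-lived event loop instead of calling asyncio.run() for every task. A minimal standalone sketch of that pattern follows; the coroutine and thread names are illustrative stand-ins, not repository code.

```python
# Hedged sketch of the per-thread event loop pattern used by _task_processing_loop.
import asyncio
import threading


async def handle(task_id: int) -> str:
    await asyncio.sleep(0)  # stand-in for awaiting the video service
    return f"task-{task_id} done"


def processing_thread() -> None:
    # Create one event loop for the lifetime of the worker thread
    asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    try:
        for task_id in range(3):
            print(loop.run_until_complete(handle(task_id)))
    finally:
        loop.close()


t = threading.Thread(target=processing_thread, daemon=True)
t.start()
t.join()
```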
@@ -11,9 +11,6 @@ class ServerConfig:
    port: int = 8000
    max_queue_size: int = 10

    task_timeout: int = 300
    task_history_limit: int = 1000
@@ -42,31 +39,13 @@ class ServerConfig:
            except ValueError:
                logger.warning(f"Invalid max queue size: {env_queue_size}")

        # MASTER_ADDR is now managed by torchrun, no need to set manually

        if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
            config.cache_dir = env_cache_dir

        return config

    def validate(self) -> bool:
        valid = True
...
@@ -6,8 +6,6 @@ import torch
import torch.distributed as dist
from loguru import logger


class DistributedManager:
    def __init__(self):
@@ -18,29 +16,35 @@ class DistributedManager:
    CHUNK_SIZE = 1024 * 1024

    def init_process_group(self) -> bool:
        """Initialize process group using torchrun environment variables"""
        try:
            # torchrun sets these environment variables automatically
            self.rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = int(os.environ.get("WORLD_SIZE", 1))

            if self.world_size > 1:
                # torchrun handles backend, init_method, rank, and world_size
                # We just need to call init_process_group without parameters
                backend = "nccl" if torch.cuda.is_available() else "gloo"
                dist.init_process_group(backend=backend, init_method="env://")
                logger.info(f"Setup backend: {backend}")

                # Set CUDA device for this rank
                if torch.cuda.is_available():
                    torch.cuda.set_device(self.rank)
                    self.device = f"cuda:{self.rank}"
                else:
                    self.device = "cpu"
            else:
                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

            self.is_initialized = True
            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
@@ -143,30 +147,3 @@ class DistributedManager:
        task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
        task_data = pickle.loads(task_bytes)
        return task_data
class DistributedWorker:
    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.dist_manager = DistributedManager()

    def init(self) -> bool:
        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)

    def cleanup(self):
        self.dist_manager.cleanup()

    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
        self.dist_manager.barrier()

        if self.dist_manager.is_rank_zero():
            result = {"task_id": task_id, "status": status, **kwargs}
            result_queue.put(result)
            logger.info(f"Task {task_id} {status}")


def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
    return DistributedWorker(rank, world_size, master_addr, master_port)
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []

        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")

        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")

        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"

        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device

        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()

        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1

        gpu_count = torch.cuda.device_count()

        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size

        return optimal_size


gpu_manager = GPUManager()
import os
import sys
from pathlib import Path
@@ -10,9 +11,14 @@ from .service import DistributedInferenceService


def run_server(args):
    """Run server with torchrun support"""
    inference_service = None
    try:
        # Get rank from environment (set by torchrun)
        rank = int(os.environ.get("LOCAL_RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))

        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

        if hasattr(args, "host") and args.host:
            server_config.host = args.host
@@ -22,11 +28,14 @@ def run_server(args):
        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")

        # Initialize inference service
        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")

        logger.info(f"Rank {rank}: Inference service started successfully")

        if rank == 0:
            # Only rank 0 runs the FastAPI server
            cache_dir = Path(server_config.cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)
@@ -35,15 +44,21 @@ def run_server(args):
            app = api_server.get_app()

            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
        else:
            # Non-rank-0 processes run the worker loop
            logger.info(f"Rank {rank}: Starting worker loop")
            import asyncio

            asyncio.run(inference_service.run_worker_loop())

    except KeyboardInterrupt:
        logger.info(f"Server rank {rank} interrupted by user")
        if inference_service:
            inference_service.stop_distributed_inference()
    except Exception as e:
        logger.error(f"Server rank {rank} failed: {e}")
        if inference_service:
            inference_service.stop_distributed_inference()
        sys.exit(1)
import asyncio
import json
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import httpx
import torch
from loguru import logger

from ..infer import init_runner
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .distributed_utils import DistributedManager
from .image_utils import is_base64_image, save_base64_image
from .schema import TaskRequest, TaskResponse


class FileService:
    def __init__(self, cache_dir: Path):
@@ -196,280 +193,191 @@ class FileService:
        self._http_client = None


class TorchrunInferenceWorker:
    """Worker class for torchrun-based distributed inference"""

    def __init__(self):
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.request_queue = asyncio.Queue() if self.rank == 0 else None
        self.processing = False  # Track if currently processing a request

    def init(self, args) -> bool:
        """Initialize the worker with model and distributed setup"""
        try:
            # Initialize distributed process group using torchrun env vars
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single GPU mode
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            # Initialize model
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")

            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")

            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single inference request

        Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
        The async wrapper allows FastAPI to handle other requests while this runs.
        """
        try:
            # Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            # Run inference directly - torchrun handles the parallelization
            # Using asyncio.to_thread would be risky with NCCL operations
            # Instead, we rely on FastAPI's async handling and queue management
            self.runner.set_inputs(task_data)
            self.runner.run_pipeline()

            # Small yield to allow other async operations if needed
            await asyncio.sleep(0)

            # Synchronize all ranks
            if self.world_size > 1:
                self.dist_manager.barrier()

            # Only rank 0 returns the result
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_video_path": task_data.get("video_path", task_data["save_video_path"]),
                    "message": "Inference completed",
                }
            else:
                return None

        except Exception as e:
            logger.error(f"Rank {self.rank} inference failed: {str(e)}")
            if self.world_size > 1:
                self.dist_manager.barrier()
            if self.rank == 0:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            else:
                return None

    async def worker_loop(self):
        """Non-rank-0 workers: Listen for broadcast tasks"""
        while True:
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break
                await self.process_request(task_data)
            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
        self.dist_manager.cleanup()


class DistributedInferenceService:
    def __init__(self):
        self.worker = None
        self.is_running = False
        self.args = None

    def start_distributed_inference(self, args) -> bool:
        if hasattr(args, "lora_path") and args.lora_path:
            args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
            delattr(args, "lora_path")
            if hasattr(args, "lora_strength"):
                delattr(args, "lora_strength")

        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True

        try:
            self.worker = TorchrunInferenceWorker()
            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")

            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True

        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        if not self.is_running:
            return

        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None

        if self.worker.rank != 0:
            return None

        try:
            if self.worker.processing:
                # If we want to support queueing, we can add the task to queue
                # For now, we'll process sequentially
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")

            self.worker.processing = True
            result = await self.worker.process_request(task_data)
            self.worker.processing = False
            return result

        except Exception as e:
            self.worker.processing = False
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
        assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

    async def run_worker_loop(self):
        """Run the worker loop for non-rank-0 processes"""
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()


class VideoGenerationService:
@@ -478,6 +386,7 @@ class VideoGenerationService:
        self.inference_service = inference_service

    async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
        """Generate video using torchrun-based inference"""
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id
@@ -496,6 +405,8 @@ class VideoGenerationService:
            else:
                task_data["image_path"] = message.image_path

            logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")

            if "audio_path" in message.model_fields_set and message.audio_path:
                if message.audio_path.startswith("http"):
                    audio_path = await self.file_service.download_audio(message.audio_path)
@@ -506,20 +417,19 @@ class VideoGenerationService:
                else:
                    task_data["audio_path"] = message.audio_path

                logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")

            actual_save_path = self.file_service.get_output_path(message.save_video_path)
            task_data["save_video_path"] = str(actual_save_path)
            task_data["video_path"] = message.save_video_path

            result = await self.inference_service.submit_task_async(task_data)

            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")

            if result.get("status") == "success":
                return TaskResponse(
...
#!/bin/bash

# set path and first
lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0,1,2,3

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start multiple servers
torchrun --nproc_per_node 4 -m lightx2v.server \
    --model_cls seko_talk \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
    --port 8000
#!/bin/bash

# set path and first
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0
@@ -11,12 +11,11 @@ source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls seko_talk \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
    --port 8000

echo "Service stopped"