Commit 3c778aee authored by PengGao, committed by GitHub

Gp/dev (#310)


Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 32fd1c52
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional

import requests
from loguru import logger


@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str


def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"


def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None


def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)

        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )

        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"

        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError):
                pass
            time.sleep(1)

        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()

    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port

    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(port=port, gpu_id=gpu, model_cls=args.model_cls, task=args.task, model_path=args.model_path, config_json=args.config_json)
        server_configs.append(config)
        current_port = port + 1

    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")

    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)

    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")

    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
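The readiness check the launcher performs can be reused from outside the script. A minimal sketch, assuming only that `/v1/service/status` returns HTTP 200 once a server is up; the URLs are placeholders for whatever the launcher prints:

```python
import time

import requests


def wait_until_ready(base_url: str, timeout_s: int = 600) -> bool:
    """Poll /v1/service/status until the server answers 200 or the timeout expires."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            # Same readiness check start_server() performs after spawning each instance.
            if requests.get(f"{base_url}/v1/service/status", timeout=1).status_code == 200:
                return True
        except requests.RequestException:
            pass
        time.sleep(1)
    return False


if __name__ == "__main__":
    # Hypothetical URLs, as printed by the launcher above.
    for url in ("http://127.0.0.1:8000", "http://127.0.0.1:8001"):
        print(url, "ready" if wait_until_ready(url, timeout_s=10) else "not ready")
```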
@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI

### System Architecture

```mermaid
flowchart TB
    Client[Client] -->|Send API Request| Router[FastAPI Router]

    subgraph API Layer
        Router --> TaskRoutes[Task APIs]
        Router --> FileRoutes[File APIs]
        Router --> ServiceRoutes[Service Status APIs]

        TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
        TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
        TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
        TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
        TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
        TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]

        FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]

        ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
        ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
    end

    subgraph Task Management
        TaskManager[Task Manager]
        TaskQueue[Task Queue]
        TaskStatus[Task Status]
        TaskResult[Task Result]

        CreateTask --> TaskManager
        CreateTaskForm --> TaskManager
        TaskManager --> TaskQueue
        TaskManager --> TaskStatus
        TaskManager --> TaskResult
    end

    subgraph File Service
        FileService[File Service]
        DownloadImage[Download Image]
        DownloadAudio[Download Audio]
        SaveFile[Save File]
        GetOutputPath[Get Output Path]

        FileService --> DownloadImage
        FileService --> DownloadAudio
        FileService --> SaveFile
        FileService --> GetOutputPath
    end

    subgraph Processing Thread
        ProcessingThread[Processing Thread]
        NextTask[Get Next Task]
        ProcessTask[Process Single Task]

        ProcessingThread --> NextTask
        ProcessingThread --> ProcessTask
    end

    subgraph Video Generation Service
        VideoService[Video Service]
        GenerateVideo[Generate Video]

        VideoService --> GenerateVideo
    end

    subgraph Distributed Inference Service
        InferenceService[Distributed Inference Service]
        SubmitTask[Submit Task]
        Worker[Inference Worker Node]
        ProcessRequest[Process Request]
        RunPipeline[Run Inference Pipeline]

        InferenceService --> SubmitTask
        SubmitTask --> Worker
        Worker --> ProcessRequest
        ProcessRequest --> RunPipeline
    end

    %% ====== Connect Modules ======
    TaskQueue --> ProcessingThread
    ProcessTask --> VideoService
    GenerateVideo --> InferenceService
    GetTaskResult --> FileService
    DownloadFile --> FileService
    VideoService --> FileService
    InferenceService --> TaskManager
    TaskManager --> TaskStatus
```
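Read from the client's side, the routes above compose into a simple create/poll/fetch loop. The sketch below is illustrative only: the request fields (`prompt`, `image_url`), the response keys (`task_id`, `status`), and the status strings are assumptions rather than the server's actual schema; only the endpoint paths come from the diagram.

```python
import time

import requests

BASE = "http://localhost:8000"  # assumed server address

# 1. Create a task (field names are illustrative, not the real schema).
task = requests.post(f"{BASE}/v1/tasks/", json={"prompt": "a cat playing piano", "image_url": "https://example.com/cat.png"}).json()
task_id = task["task_id"]  # assumed response key

# 2. Poll the status route from the diagram until the task finishes.
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status").json()
    if status.get("status") in ("COMPLETED", "FAILED"):  # assumed status strings
        break
    time.sleep(2)

# 3. Fetch the result, which is served via the files router.
result = requests.get(f"{BASE}/v1/tasks/{task_id}/result")
print(result.status_code, result.headers.get("content-type"))
```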
## Task Processing Flow

@@ -100,9 +106,9 @@ sequenceDiagram
    participant PT as Processing Thread
    participant VS as VideoService
    participant FS as FileService
    participant DIS as DistributedInferenceService
    participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
    participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)

    C->>API: POST /v1/tasks<br/>(Create Task)
    API->>TM: create_task()
@@ -127,32 +133,54 @@ sequenceDiagram
    else Image is Base64
        VS->>FS: save_base64_image()
        FS-->>VS: image_path
    else Image is local path
        VS->>VS: use existing path
    end

    alt Audio is URL
        VS->>FS: download_audio()
        FS->>FS: HTTP download<br/>with retry
        FS-->>VS: audio_path
    else Audio is Base64
        VS->>FS: save_base64_audio()
        FS-->>VS: audio_path
    else Audio is local path
        VS->>VS: use existing path
    end

    VS->>DIS: submit_task_async(task_data)
    DIS->>TIW0: process_request(task_data)

    Note over TIW0,TIW1: Torchrun-based Distributed Processing
    TIW0->>TIW0: Check if processing
    TIW0->>TIW0: Set processing = True

    alt Multi-GPU Mode (world_size > 1)
        TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
        Note over TIW1: worker_loop() listens for broadcasts
        TIW1->>TIW1: Receive task_data
    end

    par Parallel Inference across all ranks
        TIW0->>TIW0: runner.set_inputs(task_data)
        TIW0->>TIW0: runner.run_pipeline()
    and
        Note over TIW1: If world_size > 1
        TIW1->>TIW1: runner.set_inputs(task_data)
        TIW1->>TIW1: runner.run_pipeline()
    end

    Note over TIW0,TIW1: Synchronization
    alt Multi-GPU Mode
        TIW0->>TIW1: barrier() for sync
        TIW1->>TIW0: barrier() response
    end

    TIW0->>TIW0: Set processing = False
    TIW0->>DIS: Return result (only rank 0)
    TIW1->>TIW1: Return None (non-rank 0)
    DIS-->>VS: TaskResponse

    VS-->>PT: TaskResponse
    PT->>TM: complete_task()<br/>(status: COMPLETED)
...
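The `worker_loop()` referenced in the diagram is not part of this diff, but its role can be sketched: non-rank-0 processes block on a broadcast from rank 0, run the same pipeline, and return nothing. In the sketch below only the names `broadcast_task_data`, `set_inputs`, `run_pipeline`, and `barrier` come from the diagram; the shutdown sentinel and the overall structure are invented for illustration.

```python
# Illustrative only: how a non-rank-0 worker loop could consume broadcasts.
async def worker_loop(dist_manager, runner):
    while True:
        # Rank 0 broadcasts the pickled task data; other ranks block here.
        task_data = dist_manager.broadcast_task_data()  # receive side on non-zero ranks
        if task_data is None:  # assumed shutdown sentinel
            break
        runner.set_inputs(task_data)
        runner.run_pipeline()
        dist_manager.barrier()  # stay in lockstep with rank 0
```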
import argparse

from .main import run_server


def main():
    parser = argparse.ArgumentParser(description="LightX2V Server")
    parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")

    # Parse any additional arguments that might be passed
    args, unknown = parser.parse_known_args()

    # Add any unknown arguments as attributes to args
    # This allows flexibility for model-specific arguments
    for i in range(0, len(unknown), 2):
        if unknown[i].startswith("--"):
            key = unknown[i][2:]
            if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
                value = unknown[i + 1]
                setattr(args, key, value)

    run_server(args)
...
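Because the entry point uses `parse_known_args()`, any extra `--key value` pair is attached to `args` as a string attribute. A self-contained illustration of that pass-through (the extra flags are only examples):

```python
import argparse

parser = argparse.ArgumentParser(description="LightX2V Server")
parser.add_argument("--model_path", type=str, required=True)

# e.g. invoked as: python -m lightx2v.server --model_path /models/x --task i2v --nproc_per_node 4
args, unknown = parser.parse_known_args(["--model_path", "/models/x", "--task", "i2v", "--nproc_per_node", "4"])

# Same loop as the entry point: pair up unknown "--key value" tokens.
for i in range(0, len(unknown), 2):
    if unknown[i].startswith("--") and i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
        setattr(args, unknown[i][2:], unknown[i + 1])

print(args.task, args.nproc_per_node)  # -> i2v 4 (both stored as strings)
```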
@@ -314,7 +314,6 @@ class ApiServer:
            return False

    def _ensure_processing_thread_running(self):
        if self.processing_thread is None or not self.processing_thread.is_alive():
            self.stop_processing.clear()
            self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
@@ -322,9 +321,11 @@ class ApiServer:
            logger.info("Started task processing thread")

    def _task_processing_loop(self):
        logger.info("Task processing loop started")
        asyncio.set_event_loop(asyncio.new_event_loop())
        loop = asyncio.get_event_loop()

        while not self.stop_processing.is_set():
            task_id = task_manager.get_next_pending_task()
@@ -335,12 +336,12 @@ class ApiServer:
            task_info = task_manager.get_task(task_id)
            if task_info and task_info.status == TaskStatus.PENDING:
                logger.info(f"Processing task {task_id}")
                loop.run_until_complete(self._process_single_task(task_info))

        loop.close()
        logger.info("Task processing loop stopped")

    async def _process_single_task(self, task_info: Any):
        assert self.video_service is not None, "Video service is not initialized"

        task_id = task_info.task_id
@@ -360,7 +361,7 @@ class ApiServer:
                task_manager.fail_task(task_id, "Task cancelled")
                return

            result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)

            if result:
                task_manager.complete_task(task_id, result.save_video_path)
...
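The processing-loop change gives the background thread its own event loop so that the now-async `_process_single_task` can be awaited outside FastAPI's loop. The pattern in isolation, as a minimal sketch unrelated to the actual task types:

```python
import asyncio
import threading


async def fake_task(name: str) -> str:
    await asyncio.sleep(0.1)
    return f"{name} done"


def worker() -> None:
    # One dedicated loop per background thread, created and closed by that thread.
    asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    for name in ("task-1", "task-2"):
        print(loop.run_until_complete(fake_task(name)))
    loop.close()


t = threading.Thread(target=worker, daemon=True)
t.start()
t.join()
```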
@@ -11,9 +11,6 @@ class ServerConfig:
    port: int = 8000
    max_queue_size: int = 10

    task_timeout: int = 300
    task_history_limit: int = 1000
@@ -42,31 +39,13 @@ class ServerConfig:
            except ValueError:
                logger.warning(f"Invalid max queue size: {env_queue_size}")

        # MASTER_ADDR is now managed by torchrun, no need to set manually

        if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
            config.cache_dir = env_cache_dir

        return config

    def validate(self) -> bool:
        valid = True
...
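Environment-based overrides remain for the settings torchrun does not own. A small sketch of the `from_env` pattern; `LIGHTX2V_CACHE_DIR` appears in the diff, while the queue-size variable name here is illustrative:

```python
import os
from dataclasses import dataclass


@dataclass
class _Config:
    cache_dir: str = "/tmp/lightx2v"
    max_queue_size: int = 10

    @classmethod
    def from_env(cls) -> "_Config":
        config = cls()
        # Walrus form: read and test the environment value in one expression.
        if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
            config.cache_dir = env_cache_dir
        if env_queue_size := os.environ.get("LIGHTX2V_MAX_QUEUE_SIZE"):  # illustrative variable name
            try:
                config.max_queue_size = int(env_queue_size)
            except ValueError:
                pass  # the real config logs a warning here
        return config


print(_Config.from_env())
```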
@@ -6,8 +6,6 @@ import torch
import torch.distributed as dist
from loguru import logger


class DistributedManager:
    def __init__(self):
@@ -18,29 +16,35 @@ class DistributedManager:
    CHUNK_SIZE = 1024 * 1024

    def init_process_group(self) -> bool:
        """Initialize process group using torchrun environment variables"""
        try:
            # torchrun sets these environment variables automatically
            self.rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = int(os.environ.get("WORLD_SIZE", 1))

            if self.world_size > 1:
                # torchrun handles backend, init_method, rank, and world_size
                # We just need to call init_process_group without parameters
                backend = "nccl" if torch.cuda.is_available() else "gloo"
                dist.init_process_group(backend=backend, init_method="env://")
                logger.info(f"Setup backend: {backend}")

                # Set CUDA device for this rank
                if torch.cuda.is_available():
                    torch.cuda.set_device(self.rank)
                    self.device = f"cuda:{self.rank}"
                else:
                    self.device = "cpu"
            else:
                self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

            self.is_initialized = True
            logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
            return False

    def cleanup(self):
@@ -143,30 +147,3 @@ class DistributedManager:
        task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
        task_data = pickle.loads(task_bytes)
        return task_data
class DistributedWorker:
    def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
        self.rank = rank
        self.world_size = world_size
        self.master_addr = master_addr
        self.master_port = master_port
        self.dist_manager = DistributedManager()

    def init(self) -> bool:
        return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)

    def cleanup(self):
        self.dist_manager.cleanup()

    def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
        self.dist_manager.barrier()
        if self.dist_manager.is_rank_zero():
            result = {"task_id": task_id, "status": status, **kwargs}
            result_queue.put(result)
            logger.info(f"Task {task_id} {status}")


def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
    return DistributedWorker(rank, world_size, master_addr, master_port)
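The chunked transfer that `DistributedManager` performs (pickling task data and moving it `CHUNK_SIZE` bytes at a time, see `_receive_byte_chunks` above) can be pictured with the standalone sketch below. It is not the class's actual implementation; it assumes an already-initialized CPU (gloo) process group, e.g. started via torchrun.

```python
import pickle

import torch
import torch.distributed as dist

CHUNK_SIZE = 1024 * 1024  # mirrors DistributedManager.CHUNK_SIZE


def broadcast_object(obj=None, src: int = 0):
    """Pickle `obj` on rank `src` and broadcast it to all ranks in fixed-size chunks."""
    rank = dist.get_rank()
    if rank == src:
        payload = bytearray(pickle.dumps(obj))
        length = torch.tensor([len(payload)], dtype=torch.long)
    else:
        payload = None
        length = torch.zeros(1, dtype=torch.long)

    # First tell every rank how many bytes are coming.
    dist.broadcast(length, src=src)
    total = int(length.item())
    if rank != src:
        payload = bytearray(total)

    # Then ship the payload CHUNK_SIZE bytes at a time.
    for start in range(0, total, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, total)
        if rank == src:
            chunk = torch.tensor(list(payload[start:end]), dtype=torch.uint8)
        else:
            chunk = torch.zeros(end - start, dtype=torch.uint8)
        dist.broadcast(chunk, src=src)
        if rank != src:
            payload[start:end] = bytes(chunk.tolist())

    return obj if rank == src else pickle.loads(bytes(payload))
```

PyTorch's built-in `torch.distributed.broadcast_object_list` covers the same need without hand-rolled chunking; the explicit version simply makes the length-then-chunks protocol visible.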
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []

        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")

        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")

        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"

        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device

        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()

        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return

        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1

        gpu_count = torch.cuda.device_count()

        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size

        return optimal_size


gpu_manager = GPUManager()
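With the move to torchrun, this helper goes away: each rank now calls `torch.cuda.set_device(LOCAL_RANK)` directly. The removed mapping reduced to `available_gpus[rank % gpu_count]`, so ranks start sharing GPUs once `world_size` exceeds the GPU count, for example:

```python
# Illustrative only: device assignment with 2 visible GPUs and world_size = 4.
available_gpus = [0, 1]
for rank in range(4):
    print(f"rank {rank} -> cuda:{available_gpus[rank % len(available_gpus)]}")
# rank 0 -> cuda:0, rank 1 -> cuda:1, rank 2 -> cuda:0, rank 3 -> cuda:1
```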
import os
import sys
from pathlib import Path
@@ -10,9 +11,14 @@ from .service import DistributedInferenceService


def run_server(args):
    """Run server with torchrun support"""
    inference_service = None
    try:
        # Get rank from environment (set by torchrun)
        rank = int(os.environ.get("LOCAL_RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))

        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

        if hasattr(args, "host") and args.host:
            server_config.host = args.host
@@ -22,28 +28,37 @@ def run_server(args):
        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")

        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")

        logger.info(f"Rank {rank}: Inference service started successfully")

        if rank == 0:
            # Only rank 0 runs the FastAPI server
            cache_dir = Path(server_config.cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)

            api_server = ApiServer(max_queue_size=server_config.max_queue_size)
            api_server.initialize_services(cache_dir, inference_service)
            app = api_server.get_app()

            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
        else:
            # Non-rank-0 processes run the worker loop
            logger.info(f"Rank {rank}: Starting worker loop")
            import asyncio

            asyncio.run(inference_service.run_worker_loop())

    except KeyboardInterrupt:
        logger.info(f"Server rank {rank} interrupted by user")
        if inference_service:
            inference_service.stop_distributed_inference()
    except Exception as e:
        logger.error(f"Server rank {rank} failed: {e}")
        if inference_service:
            inference_service.stop_distributed_inference()
        sys.exit(1)
This diff is collapsed.
#!/bin/bash
# set path and first
lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0,1,2,3

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start multiple servers
torchrun --nproc_per_node 4 -m lightx2v.server \
    --model_cls seko_talk \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
    --port 8000
#!/bin/bash
# set path and first
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8

export CUDA_VISIBLE_DEVICES=0
@@ -11,12 +11,11 @@ source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls seko_talk \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
    --port 8000

echo "Service stopped"