Commit 3c778aee authored by PengGao, committed by GitHub

Gp/dev (#310)


Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 32fd1c52
import argparse
import concurrent.futures
import os
import socket
import subprocess
import time
from dataclasses import dataclass
from typing import Optional
import requests
from loguru import logger
@dataclass
class ServerConfig:
    port: int
    gpu_id: int
    model_cls: str
    task: str
    model_path: str
    config_json: str
def get_node_ip() -> str:
    """Get the IP address of the current node"""
    try:
        # Create a UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Connect to an external address (no actual connection needed)
        s.connect(("8.8.8.8", 80))
        # Get local IP
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception as e:
        logger.error(f"Failed to get IP address: {e}")
        return "localhost"
def is_port_in_use(port: int) -> bool:
    """Check if a port is in use"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def find_available_port(start_port: int) -> Optional[int]:
    """Find an available port starting from start_port"""
    port = start_port
    while port < start_port + 1000:  # Try up to 1000 ports
        if not is_port_in_use(port):
            return port
        port += 1
    return None
def start_server(config: ServerConfig) -> Optional[tuple[subprocess.Popen, str]]:
    """Start a single server instance"""
    try:
        # Set GPU
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)
        # Start server
        process = subprocess.Popen(
            [
                "python",
                "-m",
                "lightx2v.api_server",
                "--model_cls",
                config.model_cls,
                "--task",
                config.task,
                "--model_path",
                config.model_path,
                "--config_json",
                config.config_json,
                "--port",
                str(config.port),
            ],
            env=env,
        )
        # Wait for server to start, up to 600 seconds
        node_ip = get_node_ip()
        service_url = f"http://{node_ip}:{config.port}/v1/service/status"
        # Check once per second, up to 600 times
        for _ in range(600):
            try:
                response = requests.get(service_url, timeout=1)
                if response.status_code == 200:
                    return process, f"http://{node_ip}:{config.port}"
            except (requests.RequestException, ConnectionError):
                pass
            time.sleep(1)
        # If timeout, terminate the process
        logger.error(f"Server startup timeout: port={config.port}, gpu={config.gpu_id}")
        process.terminate()
        return None
    except Exception as e:
        logger.error(f"Failed to start server: {e}")
        return None
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_gpus", type=int, required=True, help="Number of GPUs to use")
    parser.add_argument("--start_port", type=int, required=True, help="Starting port number")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class")
    parser.add_argument("--task", type=str, required=True, help="Task type")
    parser.add_argument("--model_path", type=str, required=True, help="Model path")
    parser.add_argument("--config_json", type=str, required=True, help="Config file path")
    args = parser.parse_args()
    # Prepare configurations for all servers on this node
    server_configs = []
    current_port = args.start_port
    # Create configs for each GPU on this node
    for gpu in range(args.num_gpus):
        port = find_available_port(current_port)
        if port is None:
            logger.error(f"Cannot find available port starting from {current_port}")
            continue
        config = ServerConfig(port=port, gpu_id=gpu, model_cls=args.model_cls, task=args.task, model_path=args.model_path, config_json=args.config_json)
        server_configs.append(config)
        current_port = port + 1
    # Start all servers in parallel
    processes = []
    urls = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(server_configs)) as executor:
        future_to_config = {executor.submit(start_server, config): config for config in server_configs}
        for future in concurrent.futures.as_completed(future_to_config):
            config = future_to_config[future]
            try:
                result = future.result()
                if result:
                    process, url = result
                    processes.append(process)
                    urls.append(url)
                    logger.info(f"Server started successfully: {url} (GPU: {config.gpu_id})")
                else:
                    logger.error(f"Failed to start server: port={config.port}, gpu={config.gpu_id}")
            except Exception as e:
                logger.error(f"Error occurred while starting server: {e}")
    # Print all server URLs
    logger.info("\nAll server URLs:")
    for url in urls:
        logger.info(url)
    # Print node information
    node_ip = get_node_ip()
    logger.info(f"\nCurrent node IP: {node_ip}")
    logger.info(f"Number of servers started: {len(urls)}")
    try:
        # Wait for all processes
        for process in processes:
            process.wait()
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down all servers...")
        for process in processes:
            process.terminate()


if __name__ == "__main__":
    main()
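The launcher above only starts one server per GPU and prints the resulting URLs; it does not route any requests itself. A minimal round-robin dispatcher over those printed URLs might look like the sketch below; the `/v1/tasks/` payload fields are assumptions, not the actual request schema.

```python
# Hypothetical round-robin dispatcher over the URLs printed by the launcher above.
import itertools

import requests

# URLs as logged by the launcher, one per GPU-backed server.
urls = [
    "http://10.0.0.5:8000",
    "http://10.0.0.5:8001",
]
server_cycle = itertools.cycle(urls)


def submit(prompt: str) -> dict:
    """Send one generation request to the next server in the rotation."""
    base = next(server_cycle)
    # Payload fields are illustrative only; consult the API schema for real field names.
    resp = requests.post(f"{base}/v1/tasks/", json={"prompt": prompt}, timeout=30)
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    print(submit("a red panda skiing"))
```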
......@@ -9,85 +9,91 @@ The LightX2V server is a distributed video generation service built with FastAPI
### System Architecture
```mermaid
graph TB
subgraph "Client Layer"
Client[HTTP Client]
end
flowchart TB
Client[Client] -->|Send API Request| Router[FastAPI Router]
subgraph "API Layer"
FastAPI[FastAPI Application]
ApiServer[ApiServer]
Router1[Tasks Router<br/>/v1/tasks]
Router2[Files Router<br/>/v1/files]
Router3[Service Router<br/>/v1/service]
end
subgraph API Layer
Router --> TaskRoutes[Task APIs]
Router --> FileRoutes[File APIs]
Router --> ServiceRoutes[Service Status APIs]
subgraph "Service Layer"
TaskManager[TaskManager<br/>Thread-safe Task Queue]
FileService[FileService<br/>File I/O & Downloads]
VideoService[VideoGenerationService]
end
TaskRoutes --> CreateTask["POST /v1/tasks/ - Create Task"]
TaskRoutes --> CreateTaskForm["POST /v1/tasks/form - Form Create"]
TaskRoutes --> ListTasks["GET /v1/tasks/ - List Tasks"]
TaskRoutes --> GetTaskStatus["GET /v1/tasks/id/status - Get Status"]
TaskRoutes --> GetTaskResult["GET /v1/tasks/id/result - Get Result"]
TaskRoutes --> StopTask["DELETE /v1/tasks/id - Stop Task"]
subgraph "Processing Layer"
Thread[Processing Thread<br/>Sequential Task Loop]
end
FileRoutes --> DownloadFile["GET /v1/files/download/path - Download File"]
subgraph "Distributed Inference Layer"
DistService[DistributedInferenceService]
SharedData[(Shared Data<br/>mp.Manager.dict)]
TaskEvent[Task Event<br/>mp.Manager.Event]
ResultEvent[Result Event<br/>mp.Manager.Event]
subgraph "Worker Processes"
W0[Worker 0<br/>Master/Rank 0]
W1[Worker 1<br/>Rank 1]
WN[Worker N<br/>Rank N]
end
ServiceRoutes --> GetServiceStatus["GET /v1/service/status - Service Status"]
ServiceRoutes --> GetServiceMetadata["GET /v1/service/metadata - Metadata"]
end
subgraph "Resource Management"
GPUManager[GPUManager<br/>GPU Detection & Allocation]
DistManager[DistributedManager<br/>PyTorch Distributed]
Config[ServerConfig<br/>Configuration]
subgraph Task Management
TaskManager[Task Manager]
TaskQueue[Task Queue]
TaskStatus[Task Status]
TaskResult[Task Result]
CreateTask --> TaskManager
CreateTaskForm --> TaskManager
TaskManager --> TaskQueue
TaskManager --> TaskStatus
TaskManager --> TaskResult
end
Client -->|HTTP Request| FastAPI
FastAPI --> ApiServer
ApiServer --> Router1
ApiServer --> Router2
ApiServer --> Router3
Router1 -->|Create/Manage Tasks| TaskManager
Router1 -->|Process Tasks| Thread
Router2 -->|File Operations| FileService
Router3 -->|Service Status| TaskManager
Thread -->|Get Pending Tasks| TaskManager
Thread -->|Generate Video| VideoService
subgraph File Service
FileService[File Service]
DownloadImage[Download Image]
DownloadAudio[Download Audio]
SaveFile[Save File]
GetOutputPath[Get Output Path]
FileService --> DownloadImage
FileService --> DownloadAudio
FileService --> SaveFile
FileService --> GetOutputPath
end
VideoService -->|Download Images| FileService
VideoService -->|Submit Task| DistService
subgraph Processing Thread
ProcessingThread[Processing Thread]
NextTask[Get Next Task]
ProcessTask[Process Single Task]
DistService -->|Update| SharedData
DistService -->|Signal| TaskEvent
TaskEvent -->|Notify| W0
W0 -->|Broadcast| W1
W0 -->|Broadcast| WN
ProcessingThread --> NextTask
ProcessingThread --> ProcessTask
end
W0 -->|Update Result| SharedData
W0 -->|Signal| ResultEvent
ResultEvent -->|Notify| DistService
subgraph Video Generation Service
VideoService[Video Service]
GenerateVideo[Generate Video]
W0 -.->|Uses| GPUManager
W1 -.->|Uses| GPUManager
WN -.->|Uses| GPUManager
VideoService --> GenerateVideo
end
W0 -.->|Setup| DistManager
W1 -.->|Setup| DistManager
WN -.->|Setup| DistManager
subgraph Distributed Inference Service
InferenceService[Distributed Inference Service]
SubmitTask[Submit Task]
Worker[Inference Worker Node]
ProcessRequest[Process Request]
RunPipeline[Run Inference Pipeline]
InferenceService --> SubmitTask
SubmitTask --> Worker
Worker --> ProcessRequest
ProcessRequest --> RunPipeline
end
DistService -.->|Reads| Config
ApiServer -.->|Reads| Config
%% ====== Connect Modules ======
TaskQueue --> ProcessingThread
ProcessTask --> VideoService
GenerateVideo --> InferenceService
GetTaskResult --> FileService
DownloadFile --> FileService
VideoService --> FileService
InferenceService --> TaskManager
TaskManager --> TaskStatus
```
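As a concrete illustration of the API layer in the diagram, a minimal client could interact with the task endpoints as sketched below. Only the endpoint paths are taken from the diagram; the JSON payload and response field names are assumptions.

```python
# Minimal client sketch for the task API shown in the architecture diagram.
# Payload and response field names are assumptions.
import time

import requests

BASE = "http://localhost:8000"

# Create a task (illustrative payload).
resp = requests.post(f"{BASE}/v1/tasks/", json={"prompt": "a cat playing piano"})
resp.raise_for_status()
task_id = resp.json().get("task_id")  # field name is an assumption

# Poll the task status until it reaches a terminal state.
while True:
    status = requests.get(f"{BASE}/v1/tasks/{task_id}/status").json()
    if status.get("status") in ("COMPLETED", "FAILED"):
        break
    time.sleep(2)

# Retrieve the result, e.g. the saved video path or a download link.
result = requests.get(f"{BASE}/v1/tasks/{task_id}/result")
print(result.json())
```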
## Task Processing Flow
......@@ -100,9 +106,9 @@ sequenceDiagram
participant PT as Processing Thread
participant VS as VideoService
participant FS as FileService
participant DIS as Distributed<br/>Inference Service
participant W0 as Worker 0<br/>(Master)
participant W1 as Worker 1..N
participant DIS as DistributedInferenceService
participant TIW0 as TorchrunInferenceWorker<br/>(Rank 0)
participant TIW1 as TorchrunInferenceWorker<br/>(Rank 1..N)
C->>API: POST /v1/tasks<br/>(Create Task)
API->>TM: create_task()
......@@ -127,32 +133,54 @@ sequenceDiagram
else Image is Base64
VS->>FS: save_base64_image()
FS-->>VS: image_path
else Image is Upload
VS->>FS: validate_file()
FS-->>VS: image_path
else Image is local path
VS->>VS: use existing path
end
alt Audio is URL
VS->>FS: download_audio()
FS->>FS: HTTP download<br/>with retry
FS-->>VS: audio_path
else Audio is Base64
VS->>FS: save_base64_audio()
FS-->>VS: audio_path
else Audio is local path
VS->>VS: use existing path
end
VS->>DIS: submit_task(task_data)
DIS->>DIS: shared_data["current_task"] = task_data
DIS->>DIS: task_event.set()
VS->>DIS: submit_task_async(task_data)
DIS->>TIW0: process_request(task_data)
Note over W0,W1: Distributed Processing
W0->>W0: task_event.wait()
W0->>W0: Get task from shared_data
W0->>W1: broadcast_task_data()
Note over TIW0,TIW1: Torchrun-based Distributed Processing
TIW0->>TIW0: Check if processing
TIW0->>TIW0: Set processing = True
par Parallel Inference
W0->>W0: run_pipeline()
alt Multi-GPU Mode (world_size > 1)
TIW0->>TIW1: broadcast_task_data()<br/>(via DistributedManager)
Note over TIW1: worker_loop() listens for broadcasts
TIW1->>TIW1: Receive task_data
end
par Parallel Inference across all ranks
TIW0->>TIW0: runner.set_inputs(task_data)
TIW0->>TIW0: runner.run_pipeline()
and
W1->>W1: run_pipeline()
Note over TIW1: If world_size > 1
TIW1->>TIW1: runner.set_inputs(task_data)
TIW1->>TIW1: runner.run_pipeline()
end
Note over TIW0,TIW1: Synchronization
alt Multi-GPU Mode
TIW0->>TIW1: barrier() for sync
TIW1->>TIW0: barrier() response
end
W0->>W0: barrier() for sync
W0->>W0: shared_data["result"] = result
W0->>DIS: result_event.set()
TIW0->>TIW0: Set processing = False
TIW0->>DIS: Return result (only rank 0)
TIW1->>TIW1: Return None (non-rank 0)
DIS->>DIS: result_event.wait()
DIS->>VS: return result
DIS-->>VS: TaskResponse
VS-->>PT: TaskResponse
PT->>TM: complete_task()<br/>(status: COMPLETED)
......
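To make the torchrun-based flow above more concrete, the sketch below shows roughly how rank 0 could broadcast a task to the other ranks and how the non-zero ranks loop waiting for work. The names `broadcast_task_data`, `set_inputs`, and `run_pipeline` mirror the diagram; the function bodies and signatures are illustrative only, not the actual implementation.

```python
# Rough sketch of the rank-0 / worker split shown in the sequence diagram.
# dist_manager and runner stand in for the real DistributedManager and model
# runner objects; their actual interfaces may differ.
import torch.distributed as dist


def process_request(task_data, runner, dist_manager, world_size):
    """Rank 0: broadcast the task, run the pipeline locally, sync, return the result."""
    if world_size > 1:
        dist_manager.broadcast_task_data(task_data)  # wake up ranks 1..N
    runner.set_inputs(task_data)
    result = runner.run_pipeline()
    if world_size > 1:
        dist.barrier()  # wait until every rank has finished this task
    return result  # only rank 0 hands a result back to the service


def worker_loop(runner, dist_manager, world_size):
    """Ranks 1..N: block on broadcasts, run the same pipeline, return nothing."""
    while True:
        task_data = dist_manager.broadcast_task_data()  # receive side of the broadcast
        if task_data is None:  # hypothetical shutdown signal
            break
        runner.set_inputs(task_data)
        runner.run_pipeline()
        if world_size > 1:
            dist.barrier()
```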
#!/usr/bin/env python
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from lightx2v.server.main import run_server
from .main import run_server
def main():
parser = argparse.ArgumentParser(description="Run LightX2V inference server")
parser = argparse.ArgumentParser(description="LightX2V Server")
# Model arguments
parser.add_argument("--model_path", type=str, required=True, help="Path to model")
parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
args = parser.parse_args()
# Parse any additional arguments that might be passed
args, unknown = parser.parse_known_args()
# Add any unknown arguments as attributes to args
# This allows flexibility for model-specific arguments
for i in range(0, len(unknown), 2):
if unknown[i].startswith("--"):
key = unknown[i][2:]
if i + 1 < len(unknown) and not unknown[i + 1].startswith("--"):
value = unknown[i + 1]
setattr(args, key, value)
# Run the server
run_server(args)
......
......@@ -314,7 +314,6 @@ class ApiServer:
return False
def _ensure_processing_thread_running(self):
"""Ensure the processing thread is running."""
if self.processing_thread is None or not self.processing_thread.is_alive():
self.stop_processing.clear()
self.processing_thread = threading.Thread(target=self._task_processing_loop, daemon=True)
......@@ -322,9 +321,11 @@ class ApiServer:
logger.info("Started task processing thread")
def _task_processing_loop(self):
"""Main loop that processes tasks from the queue one by one."""
logger.info("Task processing loop started")
asyncio.set_event_loop(asyncio.new_event_loop())
loop = asyncio.get_event_loop()
while not self.stop_processing.is_set():
task_id = task_manager.get_next_pending_task()
......@@ -335,12 +336,12 @@ class ApiServer:
task_info = task_manager.get_task(task_id)
if task_info and task_info.status == TaskStatus.PENDING:
logger.info(f"Processing task {task_id}")
self._process_single_task(task_info)
loop.run_until_complete(self._process_single_task(task_info))
loop.close()
logger.info("Task processing loop stopped")
def _process_single_task(self, task_info: Any):
"""Process a single task."""
async def _process_single_task(self, task_info: Any):
assert self.video_service is not None, "Video service is not initialized"
task_id = task_info.task_id
......@@ -360,7 +361,7 @@ class ApiServer:
task_manager.fail_task(task_id, "Task cancelled")
return
result = asyncio.run(self.video_service.generate_video_with_stop_event(message, task_info.stop_event))
result = await self.video_service.generate_video_with_stop_event(message, task_info.stop_event)
if result:
task_manager.complete_task(task_id, result.save_video_path)
......
......@@ -11,9 +11,6 @@ class ServerConfig:
port: int = 8000
max_queue_size: int = 10
master_addr: str = "127.0.0.1"
master_port_range: tuple = (29500, 29600)
task_timeout: int = 300
task_history_limit: int = 1000
......@@ -42,31 +39,13 @@ class ServerConfig:
except ValueError:
logger.warning(f"Invalid max queue size: {env_queue_size}")
if env_master_addr := os.environ.get("MASTER_ADDR"):
config.master_addr = env_master_addr
# MASTER_ADDR is now managed by torchrun, no need to set manually
if env_cache_dir := os.environ.get("LIGHTX2V_CACHE_DIR"):
config.cache_dir = env_cache_dir
return config
def find_free_master_port(self) -> str:
import socket
for port in range(self.master_port_range[0], self.master_port_range[1]):
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((self.master_addr, port))
logger.info(f"Found free port for master: {port}")
return str(port)
except OSError:
continue
raise RuntimeError(
f"No free port found for master in range {self.master_port_range[0]}-{self.master_port_range[1] - 1} "
f"on address {self.master_addr}. Please adjust 'master_port_range' or free an occupied port."
)
def validate(self) -> bool:
valid = True
......
......@@ -6,8 +6,6 @@ import torch
import torch.distributed as dist
from loguru import logger
from .gpu_manager import gpu_manager
class DistributedManager:
def __init__(self):
......@@ -18,29 +16,35 @@ class DistributedManager:
CHUNK_SIZE = 1024 * 1024
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
def init_process_group(self) -> bool:
"""Initialize process group using torchrun environment variables"""
try:
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
logger.info(f"Setup backend: {backend}")
self.device = gpu_manager.set_device_for_rank(rank, world_size)
# torchrun sets these environment variables automatically
self.rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
if self.world_size > 1:
# torchrun supplies rank and world_size through environment variables,
# so init_process_group can rely on the env:// init method
backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method="env://")
logger.info(f"Setup backend: {backend}")
# Set CUDA device for this rank
if torch.cuda.is_available():
torch.cuda.set_device(self.rank)
self.device = f"cuda:{self.rank}"
else:
self.device = "cpu"
else:
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.is_initialized = True
self.rank = rank
self.world_size = world_size
logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
logger.info(f"Rank {self.rank}/{self.world_size - 1} distributed environment initialized successfully")
return True
except Exception as e:
logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
logger.error(f"Rank {self.rank} distributed environment initialization failed: {str(e)}")
return False
def cleanup(self):
......@@ -143,30 +147,3 @@ class DistributedManager:
task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
task_data = pickle.loads(task_bytes)
return task_data
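The receive path above reads a pickled payload in fixed-size chunks. For orientation, the matching send side might look roughly like the following; this is a sketch only, not the signature used in `DistributedManager`.

```python
# Illustrative sender counterpart to the chunked receive logic above.
# CHUNK_SIZE matches the 1 MiB constant defined on the class; the framing
# details below are assumptions.
import pickle

import torch
import torch.distributed as dist

CHUNK_SIZE = 1024 * 1024


def send_task_data(task_data, device="cpu"):
    """Rank 0 serializes the task and broadcasts it chunk by chunk."""
    payload = pickle.dumps(task_data)
    # First broadcast the total length so receivers know how many bytes follow.
    length = torch.tensor([len(payload)], dtype=torch.long, device=device)
    dist.broadcast(length, src=0)
    # Then broadcast the payload in fixed-size chunks (the last chunk may be shorter).
    for start in range(0, len(payload), CHUNK_SIZE):
        chunk = payload[start:start + CHUNK_SIZE]
        buf = torch.tensor(list(chunk), dtype=torch.uint8, device=device)
        dist.broadcast(buf, src=0)
```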
class DistributedWorker:
def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.dist_manager = DistributedManager()
def init(self) -> bool:
return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
def cleanup(self):
self.dist_manager.cleanup()
def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
self.dist_manager.barrier()
if self.dist_manager.is_rank_zero():
result = {"task_id": task_id, "status": status, **kwargs}
result_queue.put(result)
logger.info(f"Task {task_id} {status}")
def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
return DistributedWorker(rank, world_size, master_addr, master_port)
import os
from typing import List, Optional, Tuple

import torch
from loguru import logger


class GPUManager:
    def __init__(self):
        self.available_gpus = self._detect_gpus()
        self.gpu_count = len(self.available_gpus)

    def _detect_gpus(self) -> List[int]:
        if not torch.cuda.is_available():
            logger.warning("No CUDA devices available, will use CPU")
            return []
        gpu_count = torch.cuda.device_count()
        cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if cuda_visible:
            try:
                visible_devices = [int(d.strip()) for d in cuda_visible.split(",")]
                logger.info(f"CUDA_VISIBLE_DEVICES set to: {visible_devices}")
                return list(range(len(visible_devices)))
            except ValueError:
                logger.warning(f"Invalid CUDA_VISIBLE_DEVICES: {cuda_visible}, using all devices")
        available_gpus = list(range(gpu_count))
        logger.info(f"Detected {gpu_count} GPU devices: {available_gpus}")
        return available_gpus

    def get_device_for_rank(self, rank: int, world_size: int) -> str:
        if not self.available_gpus:
            logger.info(f"Rank {rank}: Using CPU (no GPUs available)")
            return "cpu"
        if self.gpu_count == 1:
            device = f"cuda:{self.available_gpus[0]}"
            logger.info(f"Rank {rank}: Using single GPU {device}")
            return device
        if self.gpu_count >= world_size:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Assigned to dedicated GPU {device}")
            return device
        else:
            gpu_id = self.available_gpus[rank % self.gpu_count]
            device = f"cuda:{gpu_id}"
            logger.info(f"Rank {rank}: Sharing GPU {device} (world_size={world_size} > gpu_count={self.gpu_count})")
            return device

    def set_device_for_rank(self, rank: int, world_size: int) -> str:
        device = self.get_device_for_rank(rank, world_size)
        if device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            torch.cuda.set_device(gpu_id)
            logger.info(f"Rank {rank}: CUDA device set to {gpu_id}")
        return device

    def get_memory_info(self, device: Optional[str] = None) -> Tuple[int, int]:
        if not torch.cuda.is_available():
            return (0, 0)
        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
        else:
            gpu_id = torch.cuda.current_device()
        try:
            used = torch.cuda.memory_allocated(gpu_id)
            total = torch.cuda.get_device_properties(gpu_id).total_memory
            return (used, total)
        except Exception as e:
            logger.error(f"Failed to get memory info for device {gpu_id}: {e}")
            return (0, 0)

    def clear_cache(self, device: Optional[str] = None):
        if not torch.cuda.is_available():
            return
        if device and device.startswith("cuda:"):
            gpu_id = int(device.split(":")[1])
            with torch.cuda.device(gpu_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        else:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        logger.info(f"GPU cache cleared for device: {device or 'current'}")

    @staticmethod
    def get_optimal_world_size(requested_world_size: int) -> int:
        if not torch.cuda.is_available():
            logger.warning("No GPUs available, using single process")
            return 1
        gpu_count = torch.cuda.device_count()
        if requested_world_size <= 0:
            optimal_size = gpu_count
            logger.info(f"Auto-detected world_size: {optimal_size} (based on {gpu_count} GPUs)")
        elif requested_world_size > gpu_count:
            logger.warning(f"Requested world_size ({requested_world_size}) exceeds GPU count ({gpu_count}). Processes will share GPUs.")
            optimal_size = requested_world_size
        else:
            optimal_size = requested_world_size
        return optimal_size


gpu_manager = GPUManager()
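A short usage sketch for the module-level `gpu_manager` singleton, assuming torchrun-style environment variables (the import path is an assumption):

```python
# Hypothetical use of the gpu_manager singleton inside a worker process.
import os

from lightx2v.server.gpu_manager import gpu_manager  # import path is an assumption

rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

device = gpu_manager.set_device_for_rank(rank, world_size)  # e.g. "cuda:0" or "cpu"
used, total = gpu_manager.get_memory_info(device)  # returns (0, 0) on CPU
print(f"rank {rank} on {device}: {used / 1e9:.2f} / {total / 1e9:.2f} GB allocated")
gpu_manager.clear_cache(device)
```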
import os
import sys
from pathlib import Path
......@@ -10,9 +11,14 @@ from .service import DistributedInferenceService
def run_server(args):
"""Run server with torchrun support"""
inference_service = None
try:
logger.info("Starting LightX2V server...")
# Get rank from environment (set by torchrun)
rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")
if hasattr(args, "host") and args.host:
server_config.host = args.host
......@@ -22,28 +28,37 @@ def run_server(args):
if not server_config.validate():
raise RuntimeError("Invalid server configuration")
# Initialize inference service
inference_service = DistributedInferenceService()
if not inference_service.start_distributed_inference(args):
raise RuntimeError("Failed to start distributed inference service")
logger.info("Inference service started successfully")
logger.info(f"Rank {rank}: Inference service started successfully")
if rank == 0:
# Only rank 0 runs the FastAPI server
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
cache_dir = Path(server_config.cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True)
api_server = ApiServer(max_queue_size=server_config.max_queue_size)
api_server.initialize_services(cache_dir, inference_service)
api_server = ApiServer(max_queue_size=server_config.max_queue_size)
api_server.initialize_services(cache_dir, inference_service)
app = api_server.get_app()
app = api_server.get_app()
logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
else:
# Non-rank-0 processes run the worker loop
logger.info(f"Rank {rank}: Starting worker loop")
import asyncio
logger.info(f"Starting server on {server_config.host}:{server_config.port}")
uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
asyncio.run(inference_service.run_worker_loop())
except KeyboardInterrupt:
logger.info("Server interrupted by user")
logger.info(f"Server rank {rank} interrupted by user")
if inference_service:
inference_service.stop_distributed_inference()
except Exception as e:
logger.error(f"Server failed: {e}")
logger.error(f"Server rank {rank} failed: {e}")
if inference_service:
inference_service.stop_distributed_inference()
sys.exit(1)
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
lightx2v_path=/mnt/afs/users/lijiaqi2/deploy-comfyui-ljq-custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
model_path=/mnt/afs/users/lijiaqi2/wan_model/Wan2.1-R2V0909-Audio-14B-720P-fp8
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=0,1,2,3
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
# Start multiple servers
python -m lightx2v.api_multi_servers \
--num_gpus $num_gpus \
--start_port 8000 \
--model_cls wan2.1_distill \
torchrun --nproc_per_node 4 -m lightx2v.server \
--model_cls seko_talk \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json
--config_json ${lightx2v_path}/configs/seko_talk/xxx_dist.json \
--port 8000
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/Wan2.1-R2V0909-Audio-14B-720P-fp8
export CUDA_VISIBLE_DEVICES=0
......@@ -11,12 +11,11 @@ source ${lightx2v_path}/scripts/base/base.sh
# Start API server with distributed inference service
python -m lightx2v.api_server \
--model_cls wan2.1_distill \
python -m lightx2v.server \
--model_cls seko_talk \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/distill/wan_i2v_distill_4step_cfg.json \
--port 8000 \
--nproc_per_node 1
--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_05_offload_fp8_4090.json \
--port 8000
echo "Service stopped"