import os
import sys
from pathlib import Path

import uvicorn
from loguru import logger

from .api import ApiServer
from .config import server_config
from .service import DistributedInferenceService

def run_server(args):
    """Run the LightX2V server with torchrun support.

    Every torchrun-spawned process starts the distributed inference
    service; rank 0 additionally serves the FastAPI app via uvicorn,
    while all other ranks block in the async worker loop.

    Args:
        args: Parsed CLI arguments. Optional ``host``/``port`` attributes
            override the corresponding ``server_config`` values.

    Raises:
        SystemExit: With code 1 when startup or serving fails with any
            exception other than ``KeyboardInterrupt``.
    """
    # Read rank info BEFORE the try block: both exception handlers log
    # `rank`, so assigning it inside `try` would raise NameError if the
    # failure happened before the assignment (masking the real error).
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    inference_service = None
    try:
        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

        # CLI arguments take precedence over the static server config.
        if hasattr(args, "host") and args.host:
            server_config.host = args.host
        if hasattr(args, "port") and args.port:
            server_config.port = args.port

        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")

        # Initialize inference service; start_distributed_inference returns
        # a falsy value (rather than raising) on failure, so convert that
        # into an exception to reach the cleanup path below.
        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")

        logger.info(f"Rank {rank}: Inference service started successfully")

        if rank == 0:
            # Only rank 0 runs the FastAPI server.
            cache_dir = Path(server_config.cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)

            api_server = ApiServer(max_queue_size=server_config.max_queue_size)
            api_server.initialize_services(cache_dir, inference_service)

            app = api_server.get_app()

            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
        else:
            # Non-rank-0 processes run the worker loop until shutdown.
            logger.info(f"Rank {rank}: Starting worker loop")
            import asyncio

            asyncio.run(inference_service.run_worker_loop())

    except KeyboardInterrupt:
        # Graceful Ctrl-C: stop the service but exit with status 0.
        logger.info(f"Server rank {rank} interrupted by user")
        if inference_service:
            inference_service.stop_distributed_inference()
    except Exception as e:
        logger.error(f"Server rank {rank} failed: {e}")
        if inference_service:
            inference_service.stop_distributed_inference()
        sys.exit(1)