main.py

import os
import sys
from pathlib import Path

import uvicorn
from loguru import logger

from .api import ApiServer
from .config import server_config
from .services import DistributedInferenceService


def run_server(args):
    """Launch the LightX2V server: rank 0 serves the HTTP API, other ranks run inference workers."""
    inference_service = None
    # Resolve the distributed rank up front so the exception handlers below can
    # always reference it; LOCAL_RANK and WORLD_SIZE are typically set by torchrun.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    try:

        logger.info(f"Starting LightX2V server (Rank {rank}/{world_size})...")

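        # CLI arguments, when supplied, override the server configuration defaults.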
        if hasattr(args, "host") and args.host:
            server_config.host = args.host
        if hasattr(args, "port") and args.port:
            server_config.port = args.port

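        # Fail fast if the resulting server configuration is invalid.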
        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")

        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")
        logger.info(f"Rank {rank}: Inference service started successfully")

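        # Only rank 0 exposes the HTTP API; the remaining ranks act as inference workers.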
        if rank == 0:
            cache_dir = Path(server_config.cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)

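            # Build the FastAPI application and wire it to the running inference service.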
            api_server = ApiServer(max_queue_size=server_config.max_queue_size)
            api_server.initialize_services(cache_dir, inference_service)

            app = api_server.get_app()

            logger.info(f"Starting FastAPI server on {server_config.host}:{server_config.port}")
            uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
        else:
            logger.info(f"Rank {rank}: Starting worker loop")
            import asyncio

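            # Non-zero ranks block here, running the inference worker loop.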
            asyncio.run(inference_service.run_worker_loop())

    except KeyboardInterrupt:
        logger.info(f"Server rank {rank} interrupted by user")
        if inference_service:
            inference_service.stop_distributed_inference()
    except Exception as e:
        logger.error(f"Server rank {rank} failed: {e}")
        if inference_service:
            inference_service.stop_distributed_inference()
        sys.exit(1)