api_server.py 3.03 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import argparse
PengGao's avatar
PengGao committed
2
3
4
5
import sys
import signal
import atexit
from pathlib import Path
root's avatar
root committed
6
from loguru import logger
helloyongyang's avatar
helloyongyang committed
7
8
import uvicorn

PengGao's avatar
PengGao committed
9
10
11
from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService
from lightx2v.server.utils import ProcessManager
12
13


PengGao's avatar
PengGao committed
14
15
def create_signal_handler(inference_service: DistributedInferenceService):
    """Return a signal callback that shuts down ``inference_service`` and exits.

    The returned callable takes the ``(signum, frame)`` arguments that
    ``signal.signal`` passes to handlers. When invoked it:

    1. logs the received signal number,
    2. calls ``inference_service.stop_distributed_inference()`` if the
       service reports ``is_running``,
    3. logs (but does not re-raise) any ``Exception`` raised while stopping,
    4. always terminates the interpreter via ``sys.exit(0)``.
    """

    def _stop_service_if_running():
        # Nothing to tear down unless the service reports itself as running.
        if not inference_service.is_running:
            return
        inference_service.stop_distributed_inference()

    def signal_handler(sig_number, _stack_frame):
        logger.info(f"Received signal {sig_number}, gracefully shutting down...")
        try:
            _stop_service_if_running()
        except Exception as shutdown_error:
            logger.error(f"Error occurred while shutting down distributed inference service: {shutdown_error!s}")
        finally:
            # Exit unconditionally, even when the shutdown attempt failed.
            sys.exit(0)

    return signal_handler
28
29


PengGao's avatar
PengGao committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def main():
    """Parse CLI arguments, start distributed inference, and serve the API.

    Steps, in order:

    1. Parse the command-line options and log them.
    2. Create a ``DistributedInferenceService`` and an ``ApiServer``, and
       initialise the server with a ``.cache`` directory located two levels
       above this file.
    3. Install a shared shutdown handler for ``SIGTERM`` and ``SIGINT``.
    4. Start distributed inference; exit with status 1 if that fails.
    5. Run the application returned by ``api_server.get_app()`` under
       uvicorn on ``0.0.0.0:<port>`` with a single worker.
    6. Stop distributed inference when the server returns or fails.

    Exits with status 1 when the inference service cannot be started or when
    the server terminates with an unexpected exception; a
    ``KeyboardInterrupt`` is treated as a normal shutdown.
    """
    parser = argparse.ArgumentParser()
    # No default is given for --model_cls: the option is required, so a
    # default could never take effect.
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "hunyuan",
            "wan2.1_causvid",
            "wan2.1_skyreels_v2_df",
            "wan2.1_audio",
        ],
    )
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)

    parser.add_argument("--split", action="store_true")
    parser.add_argument("--lora_path", type=str, required=False, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")

    args = parser.parse_args()
    logger.info(f"args: {args}")

    cache_dir = Path(__file__).parent.parent / ".cache"
    inference_service = DistributedInferenceService()

    api_server = ApiServer()
    api_server.initialize_services(cache_dir, inference_service)

    # Use one handler for both termination signals so shutdown behaves the
    # same regardless of which one is delivered.
    signal_handler = create_signal_handler(inference_service)
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    logger.info("Starting distributed inference service...")
    success = inference_service.start_distributed_inference(args)
    if not success:
        logger.error("Failed to start distributed inference service, exiting program")
        sys.exit(1)

    # Registered only after a successful start, as a last-resort cleanup at
    # interpreter exit.
    # NOTE(review): the ``finally`` below also calls
    # stop_distributed_inference, so it may run twice; presumably it
    # tolerates repeated calls -- confirm against the service implementation.
    atexit.register(inference_service.stop_distributed_inference)

    exit_code = 0
    try:
        logger.info(f"Starting FastAPI server on port: {args.port}")
        uvicorn.run(
            api_server.get_app(),
            host="0.0.0.0",
            port=args.port,
            reload=False,
            workers=1,
        )
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down service...")
    except Exception as e:
        logger.error(f"Error occurred while running FastAPI server: {str(e)}")
        # Report the failure to the caller/supervisor instead of exiting 0,
        # matching the non-zero exit used when inference fails to start.
        exit_code = 1
    finally:
        inference_service.stop_distributed_inference()

    if exit_code:
        sys.exit(exit_code)
helloyongyang's avatar
helloyongyang committed
91

92

helloyongyang's avatar
helloyongyang committed
93
# Run the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()