import argparse
import sys
import signal
import atexit
from pathlib import Path

from loguru import logger

import uvicorn

from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService
from lightx2v.server.utils import ProcessManager


def create_signal_handler(inference_service: DistributedInferenceService):
    """Create unified signal handler function"""
16

PengGao's avatar
PengGao committed
17
18
    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, gracefully shutting down...")
19
        try:
PengGao's avatar
PengGao committed
20
21
22
23
            if inference_service.is_running:
                inference_service.stop_distributed_inference()
        except Exception as e:
            logger.error(f"Error occurred while shutting down distributed inference service: {str(e)}")
24
        finally:
PengGao's avatar
PengGao committed
25
            sys.exit(0)
26

PengGao's avatar
PengGao committed
27
    return signal_handler
28
29


PengGao's avatar
PengGao committed
30
31
32
33
34
35
36
37
38
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "hunyuan",
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
39
            "wan2.1_distill",
PengGao's avatar
PengGao committed
40
41
42
43
44
45
46
47
48
            "wan2.1_causvid",
            "wan2.1_skyreels_v2_df",
            "wan2.1_audio",
        ],
        default="hunyuan",
    )
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
49

PengGao's avatar
PengGao committed
50
51
52
53
54
    parser.add_argument("--split", action="store_true")
    parser.add_argument("--lora_path", type=str, required=False, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
55

PengGao's avatar
PengGao committed
56
57
    args = parser.parse_args()
    logger.info(f"args: {args}")
58

Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
59
    cache_dir = Path(__file__).parent.parent / "server_cache"
PengGao's avatar
PengGao committed
60
    inference_service = DistributedInferenceService()
61

PengGao's avatar
PengGao committed
62
63
    api_server = ApiServer()
    api_server.initialize_services(cache_dir, inference_service)
64

PengGao's avatar
PengGao committed
65
66
67
    signal_handler = create_signal_handler(inference_service)
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
68

PengGao's avatar
PengGao committed
69
70
71
72
73
    logger.info("Starting distributed inference service...")
    success = inference_service.start_distributed_inference(args)
    if not success:
        logger.error("Failed to start distributed inference service, exiting program")
        sys.exit(1)
74

PengGao's avatar
PengGao committed
75
    atexit.register(inference_service.stop_distributed_inference)
helloyongyang's avatar
helloyongyang committed
76

PengGao's avatar
PengGao committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
    try:
        logger.info(f"Starting FastAPI server on port: {args.port}")
        uvicorn.run(
            api_server.get_app(),
            host="0.0.0.0",
            port=args.port,
            reload=False,
            workers=1,
        )
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down service...")
    except Exception as e:
        logger.error(f"Error occurred while running FastAPI server: {str(e)}")
    finally:
        inference_service.stop_distributed_inference()
helloyongyang's avatar
helloyongyang committed
92

93

helloyongyang's avatar
helloyongyang committed
94
if __name__ == "__main__":
PengGao's avatar
PengGao committed
95
    main()