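# LightX2V API server entry point: parses CLI arguments, launches the
# distributed inference workers, and serves the FastAPI app with uvicorn.
#
# Example invocation (paths below are placeholders, adjust to your setup):
#   python api_server.py --model_cls wan2.1 --task t2v \
#       --model_path /path/to/model --config_json /path/to/config.json --port 8000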
import argparse
import atexit
import signal
import sys
from pathlib import Path

import uvicorn
from loguru import logger

from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService
from lightx2v.server.utils import ProcessManager


def create_signal_handler(inference_service: DistributedInferenceService):
    """Create unified signal handler function"""
17

    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}, gracefully shutting down...")
        try:
            if inference_service.is_running:
                inference_service.stop_distributed_inference()
        except Exception as e:
            logger.error(f"Error occurred while shutting down distributed inference service: {str(e)}")
        finally:
            sys.exit(0)

    return signal_handler


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=[
            "wan2.1",
            "hunyuan",
            "wan2.1_distill",
            "wan2.1_causvid",
            "wan2.1_skyreels_v2_df",
            "wan2.1_audio",
            "wan2.2_moe",
        ],
        default="wan2.1",
    )
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)

    parser.add_argument("--split", action="store_true")
    parser.add_argument("--lora_path", type=str, required=False, default=None)
    parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")

    args = parser.parse_args()
    logger.info(f"args: {args}")

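    # Cache directory for server-side artifacts, resolved relative to this file;
    # its exact use is defined in lightx2v.server and is treated here as an opaque path.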
    cache_dir = Path(__file__).parent.parent / "server_cache"
    inference_service = DistributedInferenceService()

    api_server = ApiServer()
    api_server.initialize_services(cache_dir, inference_service)

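    # Make SIGTERM/SIGINT tear down the inference workers before exiting.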
    signal_handler = create_signal_handler(inference_service)
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    logger.info("Starting distributed inference service...")
    success = inference_service.start_distributed_inference(args)
    if not success:
        logger.error("Failed to start distributed inference service, exiting program")
        sys.exit(1)

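    # Also stop the workers on normal interpreter shutdown.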
    atexit.register(inference_service.stop_distributed_inference)

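    # Serve the FastAPI app in this process with a single uvicorn worker;
    # the heavy lifting presumably happens in the separately started inference processes.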
    try:
        logger.info(f"Starting FastAPI server on port: {args.port}")
        uvicorn.run(
            api_server.get_app(),
            host="0.0.0.0",
            port=args.port,
            reload=False,
            workers=1,
        )
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down service...")
    except Exception as e:
        logger.error(f"Error occurred while running FastAPI server: {str(e)}")
    finally:
        inference_service.stop_distributed_inference()


if __name__ == "__main__":
    main()