import argparse
import asyncio
import json
import signal
import sys
import threading

import psutil
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from loguru import logger
from pydantic import BaseModel

from lightx2v.infer import init_runner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config


# =========================
# Signal & Process Control
# =========================


def kill_all_related_processes():
    """Forcefully kill every descendant of this process, then this process itself.

    Descendants are collected recursively with ``psutil`` and each is sent
    ``Process.kill()``. The current process is killed last. On success the
    process terminates, so this function does not return and code placed after
    the call is not reached.

    A failure to kill an individual process is logged as a warning rather than
    raised, so the remaining processes are still attempted.
    """
    current_process = psutil.Process()  # no pid -> the calling process
    for child in current_process.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # The child exited between enumeration and kill(); nothing left to do.
            continue
        except Exception as e:
            logger.warning(f"Failed to kill child process {child.pid}: {e}")

    try:
        current_process.kill()
    except Exception as e:
        logger.warning(f"Failed to kill main process: {e}")


def signal_handler(sig, frame):
    """Handle SIGINT by logging and tearing down this process and its children.

    ``sig`` and ``frame`` are the standard arguments supplied by
    ``signal.signal`` and are not used here.
    """
    shutdown_msg = "\nReceived Ctrl+C, shutting down all related processes..."
    logger.info(shutdown_msg)
    kill_all_related_processes()
    # Only reached if kill_all_related_processes() could not kill this process.
    sys.exit(0)


# =========================
# FastAPI Related Code
# =========================

# FastAPI application that the HTTP endpoints are registered on.
app = FastAPI()

# Shared pipeline runner; remains None until the __main__ block assigns
# init_runner(config) to it.
runner = None


class Message(BaseModel):
    """Request body for the video-generation endpoint.

    Besides normal attribute access, it offers a dict-like ``get`` so it can be
    passed to ``runner.set_inputs``.
    """

    prompt: str
    negative_prompt: str = ""  # empty string when not provided
    image_path: str = ""  # empty string when not provided
    save_video_path: str  # required; echoed back in the response

    def get(self, key, default=None):
        """Return attribute ``key`` if it exists, otherwise ``default``."""
        try:
            return getattr(self, key)
        except AttributeError:
            return default


# Serializes use of the single shared ``runner``. ``run_pipeline()`` takes no
# arguments, so it appears to consume the inputs that ``set_inputs`` stored on
# the runner. Two overlapping requests must not interleave these two calls.
_runner_lock = threading.Lock()


def _generate_locked(message):
    """Set ``message`` as the runner's inputs and run the pipeline, holding the lock."""
    with _runner_lock:
        runner.set_inputs(message)
        runner.run_pipeline()


@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
    """Generate a video for ``message`` and report where it was saved.

    Requests are processed one at a time against the shared runner. Waiting and
    generation happen in a worker thread, so the event loop stays responsive.

    Raises:
        HTTPException: 503 if the runner has not been initialized (the module was
            loaded without running its ``__main__`` block).
    """
    if runner is None:
        raise HTTPException(status_code=503, detail="Runner is not initialized")
    await asyncio.to_thread(_generate_locked, message)
    return {"response": "finished", "save_video_path": message.save_video_path}


# =========================
# Main Entry
# =========================

if __name__ == "__main__":
    # Installed before runner initialization so that Ctrl+C during startup also
    # tears down child processes.
    # NOTE(review): uvicorn.run installs its own signal handling while serving;
    # confirm whether this handler still fires once the server is up.
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser()
    # Required, so argparse never applies a default; none is declared.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        # Rebinds the module-level ``runner`` read by the request handler.
        runner = init_runner(config)

    # NOTE(review): relies on set_config carrying ``port`` over from args; confirm.
    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)