api_server.py 2.84 KB
Newer Older
1
2
3
import signal
import sys
import psutil
helloyongyang's avatar
helloyongyang committed
4
import argparse
5
from fastapi import FastAPI, Request
helloyongyang's avatar
helloyongyang committed
6
from pydantic import BaseModel
root's avatar
root committed
7
from loguru import logger
helloyongyang's avatar
helloyongyang committed
8
9
import uvicorn
import json
10
import asyncio
helloyongyang's avatar
helloyongyang committed
11
12
13
14
15
16

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner


17
18
19
20
21
22
23
24
25
26
27
28
29
# =========================
# Signal & Process Control
# =========================


def kill_all_related_processes():
    """Kill every descendant of the current process, then the current process itself.

    Children are collected recursively and killed first, so none of them are left
    orphaned when the parent goes away. A failure to kill one process is logged
    and does not stop the sweep. On success the final ``kill()`` terminates this
    interpreter, so callers should not expect this function to return normally.
    """
    current_process = psutil.Process()
    for child in current_process.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # The child exited on its own between listing and killing; that is
            # the outcome we wanted, so it is not a failure.
            continue
        except Exception as e:
            # Best-effort shutdown: report the problem but keep killing the rest.
            logger.warning(f"Failed to kill child process {child.pid}: {e}")

    try:
        current_process.kill()
    except Exception as e:
        logger.warning(f"Failed to kill main process: {e}")
35
36
37


def signal_handler(sig, frame):
    """SIGINT callback: tear down this process tree and exit.

    Installed with ``signal.signal(signal.SIGINT, signal_handler)`` in the entry
    block. ``sig`` and ``frame`` are the standard handler arguments and are unused.
    """
    shutdown_message = "\nReceived Ctrl+C, shutting down all related processes..."
    logger.info(shutdown_message)
    kill_all_related_processes()
    # Normally not reached, because killing the current process ends the
    # interpreter. That kill swallows its own errors, though, so exit explicitly
    # as a fallback.
    sys.exit(0)


# =========================
# FastAPI Related Code
# =========================

runner = None

app = FastAPI()


helloyongyang's avatar
helloyongyang committed
52
53
class Message(BaseModel):
    """Request body for ``POST /v1/local/video/generate``.

    The whole object is passed to ``runner.set_inputs``. ``get`` adds a
    ``dict.get``-style accessor, presumably so the runner can read it like a
    plain dict.
    """

    prompt: str
    # When True, and the runner has a prompt enhancer, the enhanced prompt is
    # returned in the response as ``enhanced_prompt``.
    use_prompt_enhancer: bool = False
    negative_prompt: str = ""
    # Presumably only used by image-to-video ("i2v") tasks — TODO confirm.
    image_path: str = ""
    # NOTE(review): meaning is defined by the runner, not visible in this file.
    num_fragments: int = 1
    # Required. Echoed back in the response as ``save_video_path``.
    save_video_path: str

    def get(self, key, default=None):
        """Return attribute ``key`` if it exists, otherwise ``default``."""
        try:
            return getattr(self, key)
        except AttributeError:
            return default


64
# Serializes access to the single shared ``runner``. It is created lazily inside
# the running event loop (see the handler below).
_generation_lock = None


@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
    """Run one video-generation job on the shared module-level ``runner``.

    The runner is a single stateful object. The request is stored on it via
    ``set_inputs``, the pipeline then runs in a worker thread, and the enhanced
    prompt is read back from ``runner.config`` afterwards. Requests are
    serialized with an ``asyncio.Lock``. Without it, a second request arriving
    during the ``await`` could overwrite the inputs or ``runner.config["prompt"]``
    of a job that is still running, or start a second pipeline on the same runner.

    Returns a dict with ``response`` ("finished") and the echoed
    ``save_video_path``. When the runner has a prompt enhancer and the request
    asked for it, the dict also contains ``enhanced_prompt``.
    """
    global _generation_lock
    if _generation_lock is None:
        # Created here rather than at import time. On Python < 3.10 an
        # asyncio.Lock binds to the event loop that is current when it is
        # constructed, which at import time is not the loop uvicorn serves on.
        # There is no await between the check and the assignment, so this is
        # race-free within the event loop.
        _generation_lock = asyncio.Lock()

    async with _generation_lock:
        runner.set_inputs(message)
        logger.info(f"message: {message}")
        # Run the blocking pipeline off the event loop so the server stays responsive.
        await asyncio.to_thread(runner.run_pipeline)

        response = {"response": "finished", "save_video_path": message.save_video_path}
        if runner.has_prompt_enhancer and message.use_prompt_enhancer:
            # Read while still holding the lock, so this is our request's prompt.
            response["enhanced_prompt"] = runner.config["prompt"]
    return response
helloyongyang's avatar
helloyongyang committed
74
75


76
77
78
79
# =========================
# Main Entry
# =========================

helloyongyang's avatar
helloyongyang committed
80
if __name__ == "__main__":
    # Ctrl+C kills this process and all of its children.
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser()
    # No default: with required=True argparse would never use one, and the old
    # default="hunyuan" only suggested otherwise.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--prompt_enhancer", default=None)
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        # Module-level assignment, so the request handler sees this runner.
        runner = init_runner(config)

    # Single worker: the runner lives in this process and cannot be shared
    # across worker processes.
    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)