api_server.py 5.92 KB
Newer Older
1
2
3
import signal
import sys
import psutil
helloyongyang's avatar
helloyongyang committed
4
import argparse
5
from fastapi import FastAPI
helloyongyang's avatar
helloyongyang committed
6
from pydantic import BaseModel
root's avatar
root committed
7
from loguru import logger
helloyongyang's avatar
helloyongyang committed
8
9
import uvicorn
import json
10
11
12
from typing import Optional
from datetime import datetime
import threading
helloyongyang's avatar
helloyongyang committed
13
14
15
16
17
18

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner


19
20
21
22
23
24
25
26
27
28
29
30
31
# =========================
# Signal & Process Control
# =========================


def kill_all_related_processes():
    """Best-effort kill of every descendant of this process, then of this process itself.

    Descendants are collected recursively and killed first. A failure to kill any one
    of them, or this process, is logged and otherwise ignored, so the function never
    raises because of a single process that could not be killed.
    """
    me = psutil.Process()
    for proc in me.children(recursive=True):
        try:
            proc.kill()
        except Exception as err:
            logger.info(f"Failed to kill child process {proc.pid}: {err}")

    # Killing ourselves last: if this succeeds, nothing after this call runs.
    try:
        me.kill()
    except Exception as err:
        logger.info(f"Failed to kill main process: {err}")


def signal_handler(sig, frame):
    """SIGINT handler: tear down this process tree, then exit with status 0.

    ``sig`` and ``frame`` are supplied by the ``signal`` module and are not used.
    """
    shutdown_notice = "\nReceived Ctrl+C, shutting down all related processes..."
    logger.info(shutdown_notice)
    kill_all_related_processes()
    # Reached only if kill_all_related_processes() failed to kill the current process.
    sys.exit(0)


# =========================
# FastAPI Related Code
# =========================

# ASGI application on which the HTTP endpoints below are registered.
app = FastAPI()

# Model runner shared by all requests; stays None until the __main__ block calls init_runner().
runner = None


class Message(BaseModel):
    """Request body for ``POST /v1/local/video/generate``.

    Fields without a default (``task_id``, ``prompt``, ``save_video_path``) are required.
    NOTE(review): apart from ``task_id`` / ``task_id_must_unique``, the fields are passed
    straight to ``runner.set_inputs``; their exact semantics live in the runner — confirm there.
    """

    # Caller-chosen identifier used by the status endpoints to look the task up.
    task_id: str
    # When True, the request is rejected if a result for ``task_id`` is already stored.
    task_id_must_unique: bool = False

    prompt: str
    use_prompt_enhancer: bool = False
    negative_prompt: str = ""
    image_path: str = ""
    num_fragments: int = 1
    save_video_path: str

    def get(self, key, default=None):
        """Dict-style accessor: return attribute ``key``, or ``default`` when it is absent.

        Presumably lets the runner read inputs as if this were a dict — verify in the runner.
        """
        try:
            return getattr(self, key)
        except AttributeError:
            return default


class TaskStatusMessage(BaseModel):
    """Request body for ``POST /v1/local/video/generate/task_status``: names one task."""

    # Identifier that was supplied in the original generate request.
    task_id: str


class ServiceStatus:
    """Thread-safe, process-wide bookkeeping for a service that runs one task at a time.

    All state lives on the class itself, and every accessor takes ``_lock``, because
    tasks are finished from a background thread while request handlers read the state.

    NOTE(review): ``_result_store`` is never pruned, so it grows for the lifetime of
    the process.
    """

    _lock = threading.Lock()
    # Running task as {"message": ..., "start_time": datetime}, or None when idle.
    _current_task = None
    # task_id -> record of a finished (successful or failed) task.
    _result_store = {}

    @classmethod
    def start_task(cls, message: "Message"):
        """Claim the single task slot for ``message`` and return its ``task_id``.

        Raises:
            RuntimeError: if a task is already running, or if
                ``message.task_id_must_unique`` is set and ``message.task_id``
                already has a stored result.
        """
        with cls._lock:
            if cls._current_task is not None:
                raise RuntimeError("Service busy")
            if message.task_id_must_unique and message.task_id in cls._result_store:
                raise RuntimeError(f"Task ID {message.task_id} already exists")
            cls._current_task = {"message": message, "start_time": datetime.now()}
            return message.task_id

    @classmethod
    def complete_task(cls, message: "Message"):
        """Store a successful result for ``message`` and free the task slot.

        Must follow ``start_task``: the running task's start time is copied into the record.
        """
        with cls._lock:
            cls._result_store[message.task_id] = {
                "success": True,
                "message": message,
                "start_time": cls._current_task["start_time"],
                "completion_time": datetime.now(),
            }
            cls._current_task = None

    @classmethod
    def record_failed_task(cls, message: "Message", error: Optional[str] = None):
        """Store a failed result for ``message`` with an optional ``error`` and free the slot.

        Must follow ``start_task``. ``completion_time`` is recorded as well, so failed
        records carry the same timing fields as successful ones.
        """
        with cls._lock:
            cls._result_store[message.task_id] = {
                "success": False,
                "message": message,
                "start_time": cls._current_task["start_time"],
                "error": error,
                "completion_time": datetime.now(),
            }
            cls._current_task = None

    @classmethod
    def get_status_task_id(cls, task_id: str):
        """Return the status of ``task_id``.

        Returns ``{"task_status": "processing"}`` while it runs. Once it has finished,
        returns ``{"task_status": "completed", **record}``. Failed tasks are also reported
        as "completed"; check the record's ``success`` flag. Returns
        ``{"task_status": "not_found"}`` otherwise.
        """
        with cls._lock:
            if cls._current_task and cls._current_task["message"].task_id == task_id:
                return {"task_status": "processing"}
            if task_id in cls._result_store:
                return {"task_status": "completed", **cls._result_store[task_id]}
            return {"task_status": "not_found"}

    @classmethod
    def get_status_service(cls):
        """Return ``{"service_status": "busy", "task_id": ...}`` while a task runs, else idle."""
        with cls._lock:
            if cls._current_task:
                return {"service_status": "busy", "task_id": cls._current_task["message"].task_id}
            return {"service_status": "idle"}

    @classmethod
    def get_all_tasks(cls):
        """Return a snapshot (shallow copy) of all finished task records, keyed by ``task_id``.

        A copy is returned because the caller (e.g. response serialization) iterates the
        result after the lock is released, while a worker thread may insert a new record.
        Iterating the live dict could then fail with "dictionary changed size during
        iteration". A shallow copy suffices because stored records are replaced, never
        mutated in place.
        """
        with cls._lock:
            return dict(cls._result_store)


def local_video_generate(message: Message):
    """Run one generation task synchronously and record its outcome in ``ServiceStatus``.

    Started on a background thread by the generate endpoint, after
    ``ServiceStatus.start_task`` has claimed the task slot. Any exception is logged
    with its traceback and recorded as a failed task, so the slot is released either way.
    """
    try:
        # ``runner`` is the module-level runner assigned in the __main__ block; it is only
        # read here, so no ``global`` declaration is needed.
        runner.set_inputs(message)
        logger.info(f"message: {message}")
        runner.run_pipeline()
        ServiceStatus.complete_task(message)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error silently dropped.
        logger.exception(f"task_id {message.task_id} failed: {str(e)}")
        ServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
    """Accept a generation request and run it on a background thread.

    Returns ``{"task_id": ..., "task_status": "processing"}`` once the task has been
    handed to a worker thread. Returns ``{"error": ...}`` if the service is busy, if
    the ``task_id`` must be unique but already exists, or if the worker thread could
    not be started.
    """
    try:
        task_id = ServiceStatus.start_task(message)
    except RuntimeError as e:
        return {"error": str(e)}

    try:
        # Use background threads to perform long-running tasks
        threading.Thread(target=local_video_generate, args=(message,), daemon=True).start()
    except RuntimeError as e:
        # start_task() already claimed the single task slot. Release it by recording the
        # failure; otherwise the service would report "busy" forever.
        ServiceStatus.record_failed_task(message, error=str(e))
        return {"error": str(e)}

    return {"task_id": task_id, "task_status": "processing"}


@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
    """Report whether the single task slot is busy (and with which task) or idle."""
    status = ServiceStatus.get_status_service()
    return status


@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
    """Return every finished task record (successful or failed), keyed by ``task_id``."""
    all_tasks = ServiceStatus.get_all_tasks()
    return all_tasks


@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    """Look up one task by id: "processing", "completed" (with its record), or "not_found"."""
    requested_id = message.task_id
    return ServiceStatus.get_status_task_id(requested_id)


# TODO: Implement task deletion: stop the specified task and clean up its associated state.
# @app.delete("/v1/local/video/generate/task_status")
# async def delete_task(message: TaskStatusMessage):


# =========================
# Main Entry
# =========================

if __name__ == "__main__":
    # Installed before model initialisation so that Ctrl+C during startup also tears
    # down child processes.
    # NOTE(review): uvicorn.run() installs its own signal handlers once serving starts;
    # confirm whether this handler still fires after that point.
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser()
    # Required, so no default: argparse never applies a default to a required option.
    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"])
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--prompt_enhancer", default=None)

    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        # Assigns the module-level ``runner`` that local_video_generate reads.
        runner = init_runner(config)

    # One worker, no reload: the runner and ServiceStatus state live in this process only.
    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)