api_server.py 7.33 KB
Newer Older
1
2
3
import signal
import sys
import psutil
helloyongyang's avatar
helloyongyang committed
4
import argparse
5
from fastapi import FastAPI
helloyongyang's avatar
helloyongyang committed
6
from pydantic import BaseModel
root's avatar
root committed
7
from loguru import logger
helloyongyang's avatar
helloyongyang committed
8
9
import uvicorn
import json
10
11
12
from typing import Optional
from datetime import datetime
import threading
13
14
15
import ctypes
import gc
import torch
helloyongyang's avatar
helloyongyang committed
16
17
18
19
20
21

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner


22
23
24
25
26
27
28
29
30
31
32
33
34
# =========================
# Signal & Process Control
# =========================


def kill_all_related_processes():
    """Kill every descendant of the current process, then the process itself.

    Descendants are collected recursively through psutil. Each kill is
    attempted on its own: a failure is logged and the remaining processes
    are still processed.
    """
    me = psutil.Process()

    # Descendants are handled before the current process.
    for descendant in me.children(recursive=True):
        try:
            descendant.kill()
        except Exception as exc:
            logger.info(f"Failed to kill child process {descendant.pid}: {exc}")

    try:
        me.kill()
    except Exception as exc:
        logger.info(f"Failed to kill main process: {exc}")
40
41
42


def signal_handler(sig, frame):
    """Signal callback: kill the whole process tree and exit.

    Args:
        sig: Signal number handed over by the ``signal`` module (unused).
        frame: Stack frame that was interrupted (unused).
    """
    logger.info("\nReceived Ctrl+C, shutting down all related processes...")

    kill_all_related_processes()

    # Only reached if the kill above did not already terminate this process.
    sys.exit(0)


# =========================
# FastAPI Related Code
# =========================

# Pipeline runner; assigned by the __main__ block before the server starts.
runner = None
# Background thread of the most recently started generation task, if any.
thread = None

# Application object that the route decorators below register handlers on.
app = FastAPI()


helloyongyang's avatar
helloyongyang committed
58
class Message(BaseModel):
    """Request body accepted by the video generation endpoint."""

    # Client-chosen identifier; results are stored and looked up under it.
    task_id: str
    # When true, a request is rejected if a result with this task_id exists.
    task_id_must_unique: bool = False

    prompt: str
    # NOTE(review): presumably toggles the prompt enhancer - confirm in the runner.
    use_prompt_enhancer: bool = False
    negative_prompt: str = ""
    # NOTE(review): presumably the input image for the i2v task - confirm in the runner.
    image_path: str = ""
    num_fragments: int = 1
    # Echoed back to the client in the generate response.
    save_video_path: str

    def get(self, key, default=None):
        """Dict-style lookup: return attribute ``key``, or ``default`` if it is absent."""
        value = getattr(self, key, default)
        return value


73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
class TaskStatusMessage(BaseModel):
    """Request body accepted by the task status endpoint."""

    # Identifier of the task whose status is being queried.
    task_id: str


class ServiceStatus:
    """Process-wide bookkeeping for the single generation slot.

    All state lives on the class and is guarded by ``_lock``. At most one
    task is "current" at a time; finished, failed and stopped tasks are kept
    in ``_result_store`` keyed by task_id.
    """

    _lock = threading.Lock()
    # {"message": ..., "start_time": datetime} while a task runs, else None.
    _current_task = None
    # task_id -> result record written when a task finishes, fails or is stopped.
    _result_store = {}

    @classmethod
    def _store_result_locked(cls, message, success, **extra):
        """Write the result record for ``message`` and free the slot.

        The caller must already hold ``_lock``. ``extra`` entries are
        appended to the record after the common keys.
        """
        # Tolerate being called with no active task (start_time is then None)
        # instead of failing with a TypeError on ``None[...]``.
        started = cls._current_task["start_time"] if cls._current_task else None
        cls._result_store[message.task_id] = {"success": success, "message": message, "start_time": started, **extra}
        cls._current_task = None

    @classmethod
    def start_task(cls, message: "Message"):
        """Claim the slot for ``message`` and return its task_id.

        Raises:
            RuntimeError: if another task is running, or if the message
                requires a unique task_id and a result with that id exists.
        """
        with cls._lock:
            if cls._current_task is not None:
                raise RuntimeError("Service busy")
            if message.task_id_must_unique and message.task_id in cls._result_store:
                raise RuntimeError(f"Task ID {message.task_id} already exists")
            cls._current_task = {"message": message, "start_time": datetime.now()}
            return message.task_id

    @classmethod
    def complete_task(cls, message: "Message"):
        """Record a successful task and free the slot."""
        with cls._lock:
            cls._store_result_locked(message, True, completion_time=datetime.now())

    @classmethod
    def record_failed_task(cls, message: "Message", error: Optional[str] = None):
        """Record a failed task with an error message and free the slot."""
        with cls._lock:
            cls._store_result_locked(message, False, error=error)

    @classmethod
    def clean_stopped_task(cls):
        """Record the current task, if any, as stopped by the user."""
        with cls._lock:
            if cls._current_task:
                cls._store_result_locked(cls._current_task["message"], False, error="Task stopped by user")

    @classmethod
    def get_status_task_id(cls, task_id: str):
        """Return the status of ``task_id``.

        The result is ``processing`` for the current task, ``completed``
        (merged with the stored record, whose ``success`` flag tells success
        from failure) for a recorded one, and ``not_found`` otherwise.
        """
        with cls._lock:
            if cls._current_task and cls._current_task["message"].task_id == task_id:
                return {"task_status": "processing"}
            if task_id in cls._result_store:
                return {"task_status": "completed", **cls._result_store[task_id]}
            return {"task_status": "not_found"}

    @classmethod
    def get_status_service(cls):
        """Return ``busy`` plus the running task_id, or ``idle``."""
        with cls._lock:
            if cls._current_task:
                return {"service_status": "busy", "task_id": cls._current_task["message"].task_id}
            return {"service_status": "idle"}

    @classmethod
    def get_all_tasks(cls):
        """Return a snapshot of every recorded task result, keyed by task_id."""
        with cls._lock:
            # Hand out a copy: the caller reads it after the lock is released,
            # while another thread may still be adding records to the store.
            return dict(cls._result_store)


def local_video_generate(message: Message):
    """Run one generation task on the shared runner and record its outcome.

    Any ``Exception`` raised along the way is logged and stored as the
    task's failure reason instead of propagating.
    """
    global runner
    try:
        runner.set_inputs(message)
        logger.info(f"message: {message}")
        runner.run_pipeline()
        ServiceStatus.complete_task(message)
    except Exception as exc:
        failure_text = str(exc)
        logger.error(f"task_id {message.task_id} failed: {failure_text}")
        ServiceStatus.record_failed_task(message, error=failure_text)


148
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
    """Start a video generation task on a background thread.

    Returns immediately with ``task_id``, ``task_status`` and
    ``save_video_path`` on success, or ``{"error": ...}`` when the task
    cannot be started (service busy, duplicate task_id, or the worker
    thread could not be created).
    """
    global thread

    try:
        task_id = ServiceStatus.start_task(message)
    except RuntimeError as e:
        return {"error": str(e)}

    # Use a background thread to perform the long-running task
    worker = threading.Thread(target=local_video_generate, args=(message,), daemon=True)
    try:
        worker.start()
    except RuntimeError as e:
        # start_task() has already claimed the slot; release it, otherwise
        # the service would keep reporting "busy" with nothing running.
        ServiceStatus.record_failed_task(message, error=str(e))
        return {"error": str(e)}

    # Publish the thread only once it is actually running.
    thread = worker
    return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}


@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
    """Report whether the service is currently busy or idle."""
    service_state = ServiceStatus.get_status_service()
    return service_state


@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
    """Return the recorded task results as provided by ``ServiceStatus.get_all_tasks``."""
    recorded = ServiceStatus.get_all_tasks()
    return recorded


@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    """Look up the status of the task named in the request body."""
    wanted_id = message.task_id
    return ServiceStatus.get_status_task_id(wanted_id)


176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def _async_raise(tid, exctype):
    """Force thread tid to raise exception exctype"""
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("Invalid thread ID")
    elif res > 1:
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
        raise SystemError("PyThreadState_SetAsyncExc failed")


@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
    """Abort the running generation thread, if there is one.

    Returns a dict with ``stop_status`` set to ``success``, ``error`` or
    ``do_nothing``, plus a human-readable ``reason``.
    """
    global thread

    # Nothing to stop: no thread was ever started, or it already finished.
    if not thread or not thread.is_alive():
        return {"stop_status": "do_nothing", "reason": "No running task found."}

    try:
        # Inject SystemExit into the worker, then wait for it to terminate.
        _async_raise(thread.ident, SystemExit)
        thread.join()

        # Clean up the thread reference
        thread = None
        ServiceStatus.clean_stopped_task()

        # Drop unreachable Python objects, then release cached CUDA memory.
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as exc:
        return {"stop_status": "error", "reason": str(exc)}

    return {"stop_status": "success", "reason": "Task stopped successfully."}
helloyongyang's avatar
helloyongyang committed
204
205


206
207
208
209
# =========================
# Main Entry
# =========================

helloyongyang's avatar
helloyongyang committed
210
if __name__ == "__main__":
    # Ctrl+C goes through signal_handler, which kills the whole process tree.
    signal.signal(signal.SIGINT, signal_handler)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cls",
        type=str,
        required=True,
        choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"],
        default="hunyuan",
    )
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--prompt_enhancer", default=None)

    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Build the config and the runner once, before the server accepts requests.
    with ProfilingContext("Init Server Cost"):
        config = set_config(args)
        config_dump = json.dumps(config, ensure_ascii=False, indent=4)
        logger.info(f"config:\n{config_dump}")
        runner = init_runner(config)

    # Single worker, no reload: the runner and task state live in this process.
    uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)