Commit fe13f4db authored by helloyongyang's avatar helloyongyang
Browse files

feat(server): Support stopping the running task

parent ea8da6fb
...@@ -10,6 +10,9 @@ import json ...@@ -10,6 +10,9 @@ import json
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
import threading import threading
import ctypes
import gc
import torch
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config from lightx2v.utils.set_config import set_config
...@@ -47,6 +50,7 @@ def signal_handler(sig, frame): ...@@ -47,6 +50,7 @@ def signal_handler(sig, frame):
# ========================= # =========================
runner = None runner = None
thread = None
app = FastAPI() app = FastAPI()
...@@ -98,6 +102,15 @@ class ServiceStatus: ...@@ -98,6 +102,15 @@ class ServiceStatus:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error} cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None cls._current_task = None
@classmethod
def clean_stopped_task(cls):
with cls._lock:
if cls._current_task:
message = cls._current_task["message"]
error = "Task stopped by user"
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None
@classmethod @classmethod
def get_status_task_id(cls, task_id: str): def get_status_task_id(cls, task_id: str):
with cls._lock: with cls._lock:
...@@ -137,7 +150,9 @@ async def v1_local_video_generate(message: Message): ...@@ -137,7 +150,9 @@ async def v1_local_video_generate(message: Message):
try: try:
task_id = ServiceStatus.start_task(message) task_id = ServiceStatus.start_task(message)
# Use background threads to perform long-running tasks # Use background threads to perform long-running tasks
threading.Thread(target=local_video_generate, args=(message,), daemon=True).start() global thread
thread = threading.Thread(target=local_video_generate, args=(message,), daemon=True)
thread.start()
return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path} return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
except RuntimeError as e: except RuntimeError as e:
return {"error": str(e)} return {"error": str(e)}
...@@ -158,9 +173,34 @@ async def get_task_status(message: TaskStatusMessage): ...@@ -158,9 +173,34 @@ async def get_task_status(message: TaskStatusMessage):
return ServiceStatus.get_status_task_id(message.task_id) return ServiceStatus.get_status_task_id(message.task_id)
# TODO: Implement delete task. Stop the specified task and clean many things. def _async_raise(tid, exctype):
# @app.delete("/v1/local/video/generate/task_status") """Force thread tid to raise exception exctype"""
# async def delete_task(message: TaskStatusMessage): res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
if res == 0:
raise ValueError("Invalid thread ID")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
raise SystemError("PyThreadState_SetAsyncExc failed")
@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
global thread
if thread and thread.is_alive():
try:
_async_raise(thread.ident, SystemExit)
thread.join()
# Clean up the thread reference
thread = None
ServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
return {"stop_status": "success", "reason": "Task stopped successfully."}
except Exception as e:
return {"stop_status": "error", "reason": str(e)}
else:
return {"stop_status": "do_nothing", "reason": "No running task found."}
# ========================= # =========================
......
import requests
from loguru import logger
response = requests.get("http://localhost:8000/v1/local/video/generate/stop_running_task")
logger.info(response.json())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment