Commit fe13f4db authored by helloyongyang

feat(server): Support stopping the running task

parent ea8da6fb
......@@ -10,6 +10,9 @@ import json
 from typing import Optional
 from datetime import datetime
 import threading
+import ctypes
+import gc
+import torch
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.set_config import set_config
......@@ -47,6 +50,7 @@ def signal_handler(sig, frame):
 # =========================
 runner = None
+thread = None
 app = FastAPI()
......@@ -98,6 +102,15 @@ class ServiceStatus:
                 cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
                 cls._current_task = None
 
+    @classmethod
+    def clean_stopped_task(cls):
+        with cls._lock:
+            if cls._current_task:
+                message = cls._current_task["message"]
+                error = "Task stopped by user"
+                cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
+                cls._current_task = None
+
     @classmethod
     def get_status_task_id(cls, task_id: str):
         with cls._lock:
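
For reference, once clean_stopped_task has run, the stopped task's entry in ServiceStatus._result_store has the shape below (illustrative; message and start_time stand for the values held in the cleared _current_task):

_result_store[message.task_id] = {
    "success": False,                 # a stopped task is recorded as failed
    "message": message,               # the original request Message
    "start_time": start_time,         # carried over from _current_task
    "error": "Task stopped by user",
}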
......@@ -137,7 +150,9 @@ async def v1_local_video_generate(message: Message):
     try:
         task_id = ServiceStatus.start_task(message)
         # Use background threads to perform long-running tasks
-        threading.Thread(target=local_video_generate, args=(message,), daemon=True).start()
+        global thread
+        thread = threading.Thread(target=local_video_generate, args=(message,), daemon=True)
+        thread.start()
         return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
     except RuntimeError as e:
         return {"error": str(e)}
......@@ -158,9 +173,34 @@ async def get_task_status(message: TaskStatusMessage):
     return ServiceStatus.get_status_task_id(message.task_id)
 
 
 # TODO: Implement delete task. Stop the specified task and clean many things.
 # @app.delete("/v1/local/video/generate/task_status")
 # async def delete_task(message: TaskStatusMessage):
 
 
+def _async_raise(tid, exctype):
+    """Force thread tid to raise exception exctype"""
+    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
+    if res == 0:
+        raise ValueError("Invalid thread ID")
+    elif res > 1:
+        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
+        raise SystemError("PyThreadState_SetAsyncExc failed")
+
+
+@app.get("/v1/local/video/generate/stop_running_task")
+async def stop_running_task():
+    global thread
+    if thread and thread.is_alive():
+        try:
+            _async_raise(thread.ident, SystemExit)
+            thread.join()
+            # Clean up the thread reference
+            thread = None
+            ServiceStatus.clean_stopped_task()
+            gc.collect()
+            torch.cuda.empty_cache()
+            return {"stop_status": "success", "reason": "Task stopped successfully."}
+        except Exception as e:
+            return {"stop_status": "error", "reason": str(e)}
+    else:
+        return {"stop_status": "do_nothing", "reason": "No running task found."}
+
+
 # =========================
......
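
For reference, the stop mechanism above relies on ctypes.pythonapi.PyThreadState_SetAsyncExc, which schedules an exception (here SystemExit) to be raised inside the target thread; the exception is only delivered at the next Python bytecode boundary, so a worker blocked inside a long native call may not exit immediately. The minimal, self-contained sketch below illustrates the same technique on a toy worker loop; the names stoppable_loop and worker are illustrative and not part of the commit.

import ctypes
import threading
import time


def _async_raise(tid, exctype):
    """Schedule exception exctype to be raised in the thread whose ident is tid."""
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("Invalid thread ID")
    elif res > 1:
        # More than one thread state was affected: undo the call and report failure.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
        raise SystemError("PyThreadState_SetAsyncExc failed")


def stoppable_loop():
    try:
        while True:
            # The asynchronous SystemExit is delivered between bytecode
            # instructions, i.e. while this Python-level loop keeps iterating.
            time.sleep(0.1)
    except SystemExit:
        print("worker received SystemExit and is shutting down")


worker = threading.Thread(target=stoppable_loop, daemon=True)
worker.start()
time.sleep(0.5)
_async_raise(worker.ident, SystemExit)
worker.join(timeout=2)
print("worker alive:", worker.is_alive())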
import requests
from loguru import logger
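
# Ask the server to stop whichever generation task is currently running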
response = requests.get("http://localhost:8000/v1/local/video/generate/stop_running_task")
logger.info(response.json())
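
Since the handler reports stop_status values of "success", "error", or "do_nothing", the client can branch on that field. A minimal continuation of the snippet above (the branching itself is illustrative; the field names come from the handler):

result = response.json()
if result["stop_status"] == "success":
    logger.info(f"Task stopped: {result['reason']}")
elif result["stop_status"] == "do_nothing":
    logger.info(f"Nothing to stop: {result['reason']}")
else:  # "error"
    logger.warning(f"Stop failed: {result['reason']}")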