Commit af02604e authored by helloyongyang's avatar helloyongyang
Browse files

feat(server): Support async server

parent 5b56dc56
......@@ -2,12 +2,14 @@ import signal
import sys
import psutil
import argparse
from fastapi import FastAPI, Request
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
import asyncio
from typing import Optional
from datetime import datetime
import threading
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
......@@ -50,6 +52,9 @@ app = FastAPI()
class Message(BaseModel):
task_id: str
task_id_must_unique: bool = False
prompt: str
use_prompt_enhancer: bool = False
negative_prompt: str = ""
......@@ -61,16 +66,101 @@ class Message(BaseModel):
return getattr(self, key, default)
class TaskStatusMessage(BaseModel):
task_id: str
class ServiceStatus:
_lock = threading.Lock()
_current_task = None
_result_store = {}
@classmethod
def start_task(cls, message: Message):
with cls._lock:
if cls._current_task is not None:
raise RuntimeError("Service busy")
if message.task_id_must_unique and message.task_id in cls._result_store:
raise RuntimeError(f"Task ID {message.task_id} already exists")
cls._current_task = {"message": message, "start_time": datetime.now()}
return message.task_id
@classmethod
def complete_task(cls, message: Message):
with cls._lock:
cls._result_store[message.task_id] = {"success": True, "message": message, "start_time": cls._current_task["start_time"], "completion_time": datetime.now()}
cls._current_task = None
@classmethod
def record_failed_task(cls, message: Message, error: Optional[str] = None):
"""Record a failed task with an error message."""
with cls._lock:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error}
cls._current_task = None
@classmethod
def get_status_task_id(cls, task_id: str):
with cls._lock:
if cls._current_task and cls._current_task["message"].task_id == task_id:
return {"task_status": "processing"}
if task_id in cls._result_store:
return {"task_status": "completed", **cls._result_store[task_id]}
return {"task_status": "not_found"}
@classmethod
def get_status_service(cls):
with cls._lock:
if cls._current_task:
return {"service_status": "busy", "task_id": cls._current_task["message"].task_id}
return {"service_status": "idle"}
@classmethod
def get_all_tasks(cls):
with cls._lock:
return cls._result_store
def local_video_generate(message: Message):
try:
global runner
runner.set_inputs(message)
logger.info(f"message: {message}")
runner.run_pipeline()
ServiceStatus.complete_task(message)
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
global runner
runner.set_inputs(message)
logger.info(f"message: {message}")
await asyncio.to_thread(runner.run_pipeline)
response = {"response": "finished", "save_video_path": message.save_video_path}
if message.use_prompt_enhancer:
response["prompt_enhanced"] = runner.config["prompt_enhanced"]
return response
try:
task_id = ServiceStatus.start_task(message)
# Use background threads to perform long-running tasks
threading.Thread(target=local_video_generate, args=(message,), daemon=True).start()
return {"task_id": task_id, "task_status": "processing"}
except RuntimeError as e:
return {"error": str(e)}
@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
return ServiceStatus.get_status_service()
@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
return ServiceStatus.get_all_tasks()
@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return ServiceStatus.get_status_task_id(message.task_id)
# TODO: Implement delete task. Stop the specified task and clean many things.
# @app.delete("/v1/local/video/generate/task_status")
# async def delete_task(message: TaskStatusMessage):
# =========================
......
import random
import string
import time
from datetime import datetime
def generate_task_id():
"""
Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.
Features:
1. Does not modify the global random state.
2. Each X is an uppercase letter or digit (0-9).
3. Combines time factors to ensure high randomness.
"""
# Save the current random state (does not affect external randomness)
original_state = random.getstate()
try:
# Define character set (uppercase letters + digits)
characters = string.ascii_uppercase + string.digits
# Create an independent random instance
local_random = random.Random(time.perf_counter_ns())
# Generate 5 groups of 4-character random strings
groups = []
for _ in range(5):
# Mix new time factor for each group
time_mix = int(datetime.now().timestamp())
local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
groups.append("".join(local_random.choices(characters, k=4)))
return "-".join(groups)
finally:
# Restore the original random state
random.setstate(original_state)
if __name__ == "__main__":
# Set global random seed
random.seed(42)
# Test that external randomness is not affected
print("External random number 1:", random.random()) # Always the same
print("Task ID 1:", generate_task_id()) # Different each time
print("External random number 1:", random.random()) # Always the same
print("Task ID 1:", generate_task_id()) # Different each time
import requests
from loguru import logger
response = requests.get("http://localhost:8000/v1/local/video/generate/service_status")
logger.info(response.json())
response = requests.get("http://localhost:8000/v1/local/video/generate/get_all_tasks")
logger.info(response.json())
response = requests.post("http://localhost:8000/v1/local/video/generate/task_status", json={"task_id": "test_task_001"})
logger.info(response.json())
import requests
from loguru import logger
import random
import string
import time
from datetime import datetime
url = "http://localhost:8000/v1/local/video/generate"
# same as lightx2v/utils/generate_task_id.py
# from lightx2v.utils.generate_task_id import generate_task_id
def generate_task_id():
"""
Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.
Features:
1. Does not modify the global random state.
2. Each X is an uppercase letter or digit (0-9).
3. Combines time factors to ensure high randomness.
For example: N1PQ-PRM5-N1BN-Z3S1-BGBJ
"""
# Save the current random state (does not affect external randomness)
original_state = random.getstate()
message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v.mp4", # It is best to set it to an absolute path.
}
try:
# Define character set (uppercase letters + digits)
characters = string.ascii_uppercase + string.digits
logger.info(f"message: {message}")
# Create an independent random instance
local_random = random.Random(time.perf_counter_ns())
response = requests.post(url, json=message)
# Generate 5 groups of 4-character random strings
groups = []
for _ in range(5):
# Mix new time factor for each group
time_mix = int(datetime.now().timestamp())
local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
logger.info(f"response: {response.json()}")
groups.append("".join(local_random.choices(characters, k=4)))
return "-".join(groups)
finally:
# Restore the original random state
random.setstate(original_state)
if __name__ == "__main__":
url = "http://localhost:8000/v1/local/video/generate"
message = {
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4", # It is best to set it to an absolute path.
}
logger.info(f"message: {message}")
response = requests.post(url, json=message)
logger.info(f"response: {response.json()}")
......@@ -5,6 +5,7 @@ from loguru import logger
url = "http://localhost:8000/v1/local/video/generate"
message = {
"task_id": "test_task_001",
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment