Commit c98d486d authored by helloyongyang's avatar helloyongyang
Browse files

Support Ctrl-C when server is running.

parent 704bf91e
import signal
import sys
import psutil
import argparse
from fastapi import FastAPI
from fastapi import FastAPI, Request
from pydantic import BaseModel
import uvicorn
import json
import asyncio
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner
# =========================
# Signal & Process Control
# =========================
def kill_all_related_processes():
"""Kill the current process and all its child processes"""
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
try:
child.kill()
except Exception as e:
print(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
print(f"Failed to kill main process: {e}")
def signal_handler(sig, frame):
print("\nReceived Ctrl+C, shutting down all related processes...")
kill_all_related_processes()
sys.exit(0)
# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()
class Message(BaseModel):
prompt: str
negative_prompt: str = ""
......@@ -19,13 +58,20 @@ class Message(BaseModel):
return getattr(self, key, default)
async def main(message):
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message, request: Request):
global runner
runner.set_inputs(message)
runner.run_pipeline()
return {"response": "finished"}
await asyncio.to_thread(runner.run_pipeline)
return {"response": "finished", "save_video_path": message.save_video_path}
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
......@@ -40,11 +86,4 @@ if __name__ == "__main__":
print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
app = FastAPI()
@app.post("/v1/local/video/generate")
async def generate_video(message: Message):
response = await main(message)
return response
uvicorn.run(app, host="0.0.0.0", port=config.port)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
......@@ -10,6 +10,8 @@ message = {
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path.
}
print(f"message: {message}")
response = requests.post(url, json=message)
print(f"response: {response.json()}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment