Commit c98d486d authored by helloyongyang's avatar helloyongyang
Browse files

Support Ctrl-C when server is running.

parent 704bf91e
import signal
import sys
import psutil
import argparse import argparse
from fastapi import FastAPI from fastapi import FastAPI, Request
from pydantic import BaseModel from pydantic import BaseModel
import uvicorn import uvicorn
import json import json
import asyncio
from lightx2v.utils.profiler import ProfilingContext from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner from lightx2v.infer import init_runner
# =========================
# Signal & Process Control
# =========================
def kill_all_related_processes():
"""Kill the current process and all its child processes"""
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
try:
child.kill()
except Exception as e:
print(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
print(f"Failed to kill main process: {e}")
def signal_handler(sig, frame):
print("\nReceived Ctrl+C, shutting down all related processes...")
kill_all_related_processes()
sys.exit(0)
# =========================
# FastAPI Related Code
# =========================
runner = None
app = FastAPI()
class Message(BaseModel): class Message(BaseModel):
prompt: str prompt: str
negative_prompt: str = "" negative_prompt: str = ""
...@@ -19,13 +58,20 @@ class Message(BaseModel): ...@@ -19,13 +58,20 @@ class Message(BaseModel):
return getattr(self, key, default) return getattr(self, key, default)
async def main(message): @app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message, request: Request):
global runner
runner.set_inputs(message) runner.set_inputs(message)
runner.run_pipeline() await asyncio.to_thread(runner.run_pipeline)
return {"response": "finished"} return {"response": "finished", "save_video_path": message.save_video_path}
# =========================
# Main Entry
# =========================
if __name__ == "__main__": if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan") parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v") parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
...@@ -40,11 +86,4 @@ if __name__ == "__main__": ...@@ -40,11 +86,4 @@ if __name__ == "__main__":
print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}") print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config) runner = init_runner(config)
app = FastAPI() uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
@app.post("/v1/local/video/generate")
async def generate_video(message: Message):
response = await main(message)
return response
uvicorn.run(app, host="0.0.0.0", port=config.port)
...@@ -10,6 +10,8 @@ message = { ...@@ -10,6 +10,8 @@ message = {
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path. "save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path.
} }
print(f"message: {message}")
response = requests.post(url, json=message) response = requests.post(url, json=message)
print(f"response: {response.json()}") print(f"response: {response.json()}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment