Commit 429dcc45 authored by Zhuguanyu Wu, committed by GitHub

support prompt enhancer with vllm (#53)

* support prompt enhancer server

* bugs fixed

* finished prompt enhancer service
parent 48426398
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
from loguru import logger
import uvicorn
import json
from typing import Optional
from vllm import LLM, SamplingParams

from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager

# =========================
# FastAPI Related Code
# =========================

runner = None
app = FastAPI()

sys_prompt = """
Transform the short prompt into a detailed video-generation caption using this structure:
Opening shot type (long/medium/close-up/extreme close-up/full shot)
Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
Scene composition (background, environment, spatial relationships)
Lighting/atmosphere (natural/artificial, time of day, mood)
Camera motion (zooms, pans, static/handheld shots) if applicable.
Pattern Summary from Examples:
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
One case:
Short prompt: a person is playing football
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
Now expand this short prompt: [{}]. Please only output the final long prompt in English.
"""


class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    prompt: str

    def get(self, key, default=None):
        return getattr(self, key, default)


class PromptEnhancerServiceStatus(BaseServiceStatus):
    pass


class PromptEnhancerRunner:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = self.get_model()
        self.sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=8192,
        )

    def get_model(self):
        model = LLM(model=self.model_path, trust_remote_code=True, dtype="bfloat16", gpu_memory_utilization=0.95, max_model_len=16384)
        return model

    def _run_prompt_enhancer(self, prompt):
        prompt = prompt.strip()
        prompt = sys_prompt.format(prompt)
        messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
        outputs = self.model.chat(
            messages=messages,
            sampling_params=self.sampling_params,
        )
        enhanced_prompt = outputs[0].outputs[0].text
        return enhanced_prompt.strip()


def run_prompt_enhancer(message: Message):
    try:
        global runner
        enhanced_prompt = runner._run_prompt_enhancer(message.prompt)
        assert enhanced_prompt is not None
        PromptEnhancerServiceStatus.complete_task(message)
        return enhanced_prompt
    except Exception as e:
        logger.error(f"task_id {message.task_id} failed: {str(e)}")
        PromptEnhancerServiceStatus.record_failed_task(message, error=str(e))


@app.post("/v1/local/prompt_enhancer/generate")
def v1_local_prompt_enhancer_generate(message: Message):
    try:
        task_id = PromptEnhancerServiceStatus.start_task(message)
        enhanced_prompt = run_prompt_enhancer(message)
        return {"task_id": task_id, "task_status": "completed", "output": enhanced_prompt, "kwargs": None}
    except RuntimeError as e:
        return {"error": str(e)}


@app.get("/v1/local/prompt_enhancer/generate/service_status")
async def get_service_status():
    return PromptEnhancerServiceStatus.get_status_service()


@app.get("/v1/local/prompt_enhancer/generate/get_all_tasks")
async def get_all_tasks():
    return PromptEnhancerServiceStatus.get_all_tasks()


@app.post("/v1/local/prompt_enhancer/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
    return PromptEnhancerServiceStatus.get_status_task_id(message.task_id)


# =========================
# Main Entry
# =========================
if __name__ == "__main__":
    ProcessManager.register_signal_handler()
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--port", type=int, default=9001)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    with ProfilingContext("Init Server Cost"):
        runner = PromptEnhancerRunner(args.model_path)

    uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False, workers=1)
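For reference, a minimal client-side sketch of exercising the service above once it is running. The routes and payload fields come from the endpoints defined above; the localhost:9001 address and the uuid-based task_id are assumptions (in-tree callers generate task ids via generate_task_id()):

import uuid

import requests

BASE = "http://localhost:9001"  # assumption: service started with the default --port 9001

# Check that the service is reachable before submitting work.
status = requests.get(f"{BASE}/v1/local/prompt_enhancer/generate/service_status", timeout=2).json()
print("service status:", status)

# Submit a short prompt; the enhanced long prompt comes back in "output".
payload = {"task_id": str(uuid.uuid4()), "prompt": "a person is playing football"}
resp = requests.post(f"{BASE}/v1/local/prompt_enhancer/generate", json=payload).json()
print("enhanced prompt:", resp["output"])

# Look up the recorded status of that task afterwards.
task_status = requests.post(f"{BASE}/v1/local/prompt_enhancer/generate/task_status", json={"task_id": payload["task_id"]}).json()
print("task status:", task_status)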
@@ -35,6 +35,7 @@ class Message(BaseModel):
     text: str
     img: Optional[bytes] = None
+    n_prompt: Optional[str] = None

     def get(self, key, default=None):
         return getattr(self, key, default)
@@ -71,12 +72,11 @@ class TextEncoderRunner:
             raise ValueError(f"Unsupported model class: {self.config.model_cls}")
         return text_encoders

-    def _run_text_encoder(self, text, img):
+    def _run_text_encoder(self, text, img, n_prompt):
         if "wan2.1" in self.config.model_cls:
             text_encoder_output = {}
-            n_prompt = self.config.get("negative_prompt", "")
-            context = self.text_encoders[0].infer([text], self.config)
-            context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""], self.config)
+            context = self.text_encoders[0].infer([text])
+            context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""])
             text_encoder_output["context"] = context
             text_encoder_output["context_null"] = context_null
         elif self.config.model_cls in ["hunyuan"]:
@@ -97,7 +97,7 @@ class TextEncoderRunner:
 def run_text_encoder(message: Message):
     try:
         global runner
-        text_encoder_output = runner._run_text_encoder(message.text, message.img)
+        text_encoder_output = runner._run_text_encoder(message.text, message.img, message.n_prompt)
         TextEncoderServiceStatus.complete_task(message)
         return text_encoder_output
     except Exception as e:
@@ -105,7 +105,7 @@ def run_text_encoder(message: Message):
         TextEncoderServiceStatus.record_failed_task(message, error=str(e))

-@app.post("/v1/local/text_encoder/generate")
+@app.post("/v1/local/text_encoders/generate")
 def v1_local_text_encoder_generate(message: Message):
     try:
         task_id = TextEncoderServiceStatus.start_task(message)
@@ -117,17 +117,17 @@ def v1_local_text_encoder_generate(message: Message):
         return {"error": str(e)}

-@app.get("/v1/local/text_encoder/generate/service_status")
+@app.get("/v1/local/text_encoders/generate/service_status")
 async def get_service_status():
     return TextEncoderServiceStatus.get_status_service()

-@app.get("/v1/local/text_encoder/generate/get_all_tasks")
+@app.get("/v1/local/text_encoders/generate/get_all_tasks")
 async def get_all_tasks():
     return TextEncoderServiceStatus.get_all_tasks()

-@app.post("/v1/local/text_encoder/generate/task_status")
+@app.post("/v1/local/text_encoders/generate/task_status")
 async def get_task_status(message: TaskStatusMessage):
     return TextEncoderServiceStatus.get_status_task_id(message.task_id)
@@ -187,6 +187,11 @@ def v1_local_vae_model_decoder_generate(message: Message):
         return {"error": str(e)}

+@app.get("/v1/local/vae_model/generate/service_status")
+async def get_service_status():
+    return VAEServiceStatus.get_status_service()
+
 @app.get("/v1/local/vae_model/encoder/generate/service_status")
 async def get_service_status():
     return VAEServiceStatus.get_status_service()
 import asyncio
 import gc
 import aiohttp
+import requests
+from requests.exceptions import RequestException
 import torch
 import torch.distributed as dist
 import torchvision.transforms.functional as TF
@@ -16,28 +18,65 @@ from loguru import logger

 class DefaultRunner:
     def __init__(self, config):
         self.config = config
-        # TODO: implement prompt enhancer
         self.has_prompt_enhancer = False
-        # if self.config.prompt_enhancer is not None and self.config.task == "t2v":
-        #     self.config["use_prompt_enhancer"] = True
-        #     self.has_prompt_enhancer = True
+        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
+            self.has_prompt_enhancer = True
+            if not self.check_sub_servers("prompt_enhancer"):
+                self.has_prompt_enhancer = False
+                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
         if self.config["mode"] == "split_server":
             self.model = self.load_transformer()
             self.text_encoders, self.vae_model, self.image_encoder = None, None, None
             self.tensor_transporter = TensorTransporter()
             self.image_transporter = ImageTransporter()
+            if not self.check_sub_servers("text_encoders"):
+                raise ValueError("No text encoder server available")
+            if "wan2.1" in self.config["model_cls"] and not self.check_sub_servers("image_encoder"):
+                raise ValueError("No image encoder server available")
+            if not self.check_sub_servers("vae_model"):
+                raise ValueError("No vae model server available")
         else:
             self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()

+    def check_sub_servers(self, task_type):
+        urls = self.config.get("sub_servers", {}).get(task_type, [])
+        available_servers = []
+        for url in urls:
+            try:
+                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
+                response = requests.get(status_url, timeout=2)
+                if response.status_code == 200:
+                    available_servers.append(url)
+                else:
+                    logger.warning(f"Service {url} returned status code {response.status_code}")
+            except RequestException as e:
+                logger.warning(f"Failed to connect to {url}: {str(e)}")
+                continue
+        logger.info(f"{task_type} available servers: {available_servers}")
+        self.config["sub_servers"][task_type] = available_servers
+        return len(available_servers) > 0
+
     def set_inputs(self, inputs):
         self.config["prompt"] = inputs.get("prompt", "")
-        if self.has_prompt_enhancer and self.config["mode"] != "infer":
+        if self.has_prompt_enhancer:
             self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from client side.
         self.config["negative_prompt"] = inputs.get("negative_prompt", "")
         self.config["image_path"] = inputs.get("image_path", "")
         self.config["save_video_path"] = inputs.get("save_video_path", "")

-    async def post_encoders(self, prompt, img=None, i2v=False):
+    def post_prompt_enhancer(self):
+        while True:
+            for url in self.config["sub_servers"]["prompt_enhancer"]:
+                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
+                if response["service_status"] == "idle":
+                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
+                    self.config["prompt_enhanced"] = response.json()["output"]
+                    logger.info(f"Enhanced prompt: {self.config['prompt_enhanced']}")
+                    return
+
+    async def post_encoders(self, prompt, img=None, n_prompt=None, i2v=False):
         tasks = []
         img_byte = self.image_transporter.prepare_image(img) if img is not None else None
         if i2v:
@@ -54,11 +93,16 @@ class DefaultRunner:
             )
             tasks.append(
                 asyncio.create_task(
-                    self.post_task(task_type="text_encoder", urls=self.config["sub_servers"]["text_encoders"], message={"task_id": generate_task_id(), "text": prompt, "img": img_byte}, device="cuda")
+                    self.post_task(
+                        task_type="text_encoders",
+                        urls=self.config["sub_servers"]["text_encoders"],
+                        message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
+                        device="cuda",
+                    )
                 )
             )
         results = await asyncio.gather(*tasks)
-        # clip_encoder, vae_encoder, text_encoder
+        # clip_encoder, vae_encoder, text_encoders
         if not i2v:
             return None, None, results[0]
         if "wan2.1" in self.config["model_cls"]:
@@ -69,11 +113,12 @@ class DefaultRunner:
     async def run_input_encoder(self):
         image_encoder_output = None
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        n_prompt = self.config.get("negative_prompt", "")
         i2v = self.config["task"] == "i2v"
         img = Image.open(self.config["image_path"]).convert("RGB") if i2v else None
         with ProfilingContext("Run Encoders"):
             if self.config["mode"] == "split_server":
-                clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders(prompt, img, i2v)
+                clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders(prompt, img, n_prompt, i2v)
                 if i2v:
                     if self.config["model_cls"] in ["hunyuan"]:
                         image_encoder_output = {"img": img, "img_latents": vae_encode_out}
@@ -157,7 +202,7 @@ class DefaultRunner:
     async def run_pipeline(self):
         if self.config["use_prompt_enhancer"]:
-            self.config["prompt_enhanced"] = self.prompt_enhancer(self.config["prompt"])
+            self.post_prompt_enhancer()
         self.init_scheduler()
         await self.run_input_encoder()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
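Read together, __init__, check_sub_servers, and post_prompt_enhancer expect the runner config to carry a sub_servers mapping from service type to candidate URLs. A minimal sketch of that fragment; the key set is inferred from the lookups in the diff above, and all URLs/ports are placeholder assumptions:

# Hypothetical config fragment; keys mirror the lookups in DefaultRunner.__init__
# and check_sub_servers, URLs and ports are placeholders.
config = {
    "task": "t2v",
    "mode": "split_server",
    "model_cls": "wan2.1",
    "sub_servers": {
        "prompt_enhancer": ["http://localhost:9001"],
        "text_encoders": ["http://localhost:9002"],
        "image_encoder": ["http://localhost:9003"],
        "vae_model": ["http://localhost:9004"],
    },
}

check_sub_servers probes each URL's /v1/local/{task_type}/generate/service_status route and keeps only the servers that answer 200, so unreachable entries are pruned from the config at startup.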
@@ -110,6 +110,8 @@ class TensorTransporter:
     def to_device(self, data, device):
         if isinstance(data, dict):
             return {key: self.to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, list):
+            return [self.to_device(item, device) for item in data]
+        elif isinstance(data, torch.Tensor):
             return data.to(device)
         else:
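A standalone sketch of the traversal this diff implements, showing the effect of the new list branch (the function is re-implemented here only so the snippet is self-contained; the nested payload is an illustrative assumption, e.g. an encoder output holding a list of tensors):

import torch

def to_device(data, device):
    # Recurse through dicts and lists, move tensors, pass everything else through.
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    elif isinstance(data, list):
        return [to_device(item, device) for item in data]
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data

out = to_device({"context": [torch.zeros(2, 8), torch.ones(2, 8)], "meta": "t2v"}, "cpu")
print(type(out["context"]), out["context"][0].device)  # <class 'list'> cpu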
#!/bin/bash
# Set lightx2v_path and model_path first.
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.common.apis.prompt_enhancer \
    --model_path $model_path \
    --port 9001
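With lightx2v_path and model_path filled in, the script launches the vLLM-backed enhancer service on port 9001; the client sketch following the service source above can then be used as a quick smoke test against it.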