Commit ca696d83 authored by helloyongyang's avatar helloyongyang
Browse files

Fix enhancer bugs

parent dbfa688b
...@@ -68,8 +68,8 @@ async def v1_local_video_generate(message: Message): ...@@ -68,8 +68,8 @@ async def v1_local_video_generate(message: Message):
logger.info(f"message: {message}") logger.info(f"message: {message}")
await asyncio.to_thread(runner.run_pipeline) await asyncio.to_thread(runner.run_pipeline)
response = {"response": "finished", "save_video_path": message.save_video_path} response = {"response": "finished", "save_video_path": message.save_video_path}
if runner.has_prompt_enhancer and message.use_prompt_enhancer: if message.use_prompt_enhancer:
response["enhanced_prompt"] = runner.config["prompt"] response["prompt_enhanced"] = runner.config["prompt_enhanced"]
return response return response
...@@ -80,11 +80,12 @@ async def v1_local_video_generate(message: Message): ...@@ -80,11 +80,12 @@ async def v1_local_video_generate(message: Message):
if __name__ == "__main__": if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid"], default="hunyuan") parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v") parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True) parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt_enhancer", default=None) parser.add_argument("--prompt_enhancer", default=None)
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args() args = parser.parse_args()
logger.info(f"args: {args}") logger.info(f"args: {args}")
......
...@@ -39,9 +39,9 @@ if __name__ == "__main__": ...@@ -39,9 +39,9 @@ if __name__ == "__main__":
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v") parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True) parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--enable_cfg", type=bool, default=False)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--prompt_enhancer", type=str, default=None) parser.add_argument("--prompt_enhancer", type=str, default=None)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--negative_prompt", type=str, default="") parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task") parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file") parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
......
...@@ -11,29 +11,22 @@ from loguru import logger ...@@ -11,29 +11,22 @@ from loguru import logger
class DefaultRunner: class DefaultRunner:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.config["user_prompt"] = self.config["prompt"] if self.config.prompt_enhancer is not None and self.config.task == "t2v":
self.has_prompt_enhancer = self.config.prompt_enhancer is not None and self.config.task == "t2v"
self.config["use_prompt_enhancer"] = self.has_prompt_enhancer
if self.has_prompt_enhancer:
self.load_prompt_enhancer() self.load_prompt_enhancer()
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model() self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
@ProfilingContext("Load prompt enhancer") @ProfilingContext("Load prompt enhancer")
def load_prompt_enhancer(self): def load_prompt_enhancer(self):
gpu_count = torch.cuda.device_count() gpu_count = torch.cuda.device_count()
if gpu_count == 1: if gpu_count == 1:
logger.info("Only one GPU, use prompt enhancer cpu offload") logger.info("Only one GPU, use prompt enhancer cpu offload")
raise NotImplementedError("prompt enhancer cpu offload is not supported.") raise NotImplementedError("prompt enhancer cpu offload is not supported.")
self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1") self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1")
self.config["use_prompt_enhancer"] = True # Set use_prompt_enhancer to True now. (Default is False)
def set_inputs(self, inputs): def set_inputs(self, inputs):
self.config["user_prompt"] = inputs.get("prompt", "")
self.config["prompt"] = inputs.get("prompt", "") self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from client side.
self.config["negative_prompt"] = inputs.get("negative_prompt", "") self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "") self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "") self.config["save_video_path"] = inputs.get("save_video_path", "")
...@@ -44,7 +37,8 @@ class DefaultRunner: ...@@ -44,7 +37,8 @@ class DefaultRunner:
with ProfilingContext("Run Img Encoder"): with ProfilingContext("Run Img Encoder"):
image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model) image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
with ProfilingContext("Run Text Encoder"): with ProfilingContext("Run Text Encoder"):
text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output) prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
self.set_target_shape() self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output} self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
...@@ -93,8 +87,8 @@ class DefaultRunner: ...@@ -93,8 +87,8 @@ class DefaultRunner:
save_videos_grid(images, self.config.save_video_path, fps=24) save_videos_grid(images, self.config.save_video_path, fps=24)
def run_pipeline(self): def run_pipeline(self):
if self.has_prompt_enhancer and self.config["use_prompt_enhancer"]: if self.config["use_prompt_enhancer"]:
self.config["prompt"] = self.prompt_enhancer(self.config["user_prompt"]) self.config["prompt_enhanced"] = self.prompt_enhancer(self.config["prompt"])
self.init_scheduler() self.init_scheduler()
self.run_input_encoder() self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.model.scheduler.prepare(self.inputs["image_encoder_output"])
......
import argparse import argparse
import torch
from loguru import logger from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
...@@ -38,6 +39,7 @@ class PromptEnhancer: ...@@ -38,6 +39,7 @@ class PromptEnhancer:
self.model = self.model.to(device) self.model = self.model.to(device)
@ProfilingContext("Run prompt enhancer") @ProfilingContext("Run prompt enhancer")
@torch.no_grad()
def __call__(self, prompt): def __call__(self, prompt):
prompt = prompt.strip() prompt = prompt.strip()
prompt = sys_prompt.format(prompt) prompt = sys_prompt.format(prompt)
...@@ -46,11 +48,20 @@ class PromptEnhancer: ...@@ -46,11 +48,20 @@ class PromptEnhancer:
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device) model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate( generated_ids = self.model.generate(
**model_inputs, **model_inputs,
max_new_tokens=2048, max_new_tokens=8192,
) )
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)] output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.info(f"Enhanced prompt: {rewritten_prompt}") think_id = self.tokenizer.encode("</think>")
if len(think_id) == 1:
index = len(output_ids) - output_ids[::-1].index(think_id[0])
else:
index = 0
thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
logger.info(f"[Enhanced] thinking content: {thinking_content}")
rewritten_prompt = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
logger.info(f"[Enhanced] rewritten prompt: {rewritten_prompt}")
return rewritten_prompt return rewritten_prompt
......
...@@ -19,6 +19,7 @@ def get_default_config(): ...@@ -19,6 +19,7 @@ def get_default_config():
"lora_path": None, "lora_path": None,
"strength_model": 1.0, "strength_model": 1.0,
"mm_config": {}, "mm_config": {},
"use_prompt_enhancer": False,
} }
return default_config return default_config
......
...@@ -8,7 +8,7 @@ message = { ...@@ -8,7 +8,7 @@ message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "", "image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path. "save_video_path": "./output_lightx2v_wan_t2v.mp4", # It is best to set it to an absolute path.
} }
logger.info(f"message: {message}") logger.info(f"message: {message}")
......
...@@ -6,11 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate" ...@@ -6,11 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate"
message = { message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"use_prompt_enhancer": True,
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "", "image_path": "",
"num_fragments": 1, "save_video_path": "./output_lightx2v_wan_t2v_enhanced.mp4", # It is best to set it to an absolute path.
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path. "use_prompt_enhancer": True,
} }
logger.info(f"message: {message}") logger.info(f"message: {message}")
......
...@@ -6,7 +6,7 @@ model_path= ...@@ -6,7 +6,7 @@ model_path=
# check section # check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=2 cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable." echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices} export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi fi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment