"vscode:/vscode.git/clone" did not exist on "7dcf63e69cb580bc90213a29936f987df632b595"
Commit ca696d83 authored by helloyongyang's avatar helloyongyang
Browse files

Fix enhancer bugs

parent dbfa688b
......@@ -68,8 +68,8 @@ async def v1_local_video_generate(message: Message):
logger.info(f"message: {message}")
await asyncio.to_thread(runner.run_pipeline)
response = {"response": "finished", "save_video_path": message.save_video_path}
if runner.has_prompt_enhancer and message.use_prompt_enhancer:
response["enhanced_prompt"] = runner.config["prompt"]
if message.use_prompt_enhancer:
response["prompt_enhanced"] = runner.config["prompt_enhanced"]
return response
......@@ -80,11 +80,12 @@ async def v1_local_video_generate(message: Message):
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid"], default="hunyuan")
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt_enhancer", default=None)
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
logger.info(f"args: {args}")
......
......@@ -39,9 +39,9 @@ if __name__ == "__main__":
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--enable_cfg", type=bool, default=False)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--prompt_enhancer", type=str, default=None)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
......
......@@ -11,29 +11,22 @@ from loguru import logger
class DefaultRunner:
def __init__(self, config):
self.config = config
self.config["user_prompt"] = self.config["prompt"]
self.has_prompt_enhancer = self.config.prompt_enhancer is not None and self.config.task == "t2v"
self.config["use_prompt_enhancer"] = self.has_prompt_enhancer
if self.has_prompt_enhancer:
if self.config.prompt_enhancer is not None and self.config.task == "t2v":
self.load_prompt_enhancer()
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
@ProfilingContext("Load prompt enhancer")
def load_prompt_enhancer(self):
gpu_count = torch.cuda.device_count()
if gpu_count == 1:
logger.info("Only one GPU, use prompt enhancer cpu offload")
raise NotImplementedError("prompt enhancer cpu offload is not supported.")
self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1")
self.config["use_prompt_enhancer"] = True # Set use_prompt_enhancer to True now. (Default is False)
def set_inputs(self, inputs):
self.config["user_prompt"] = inputs.get("prompt", "")
self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False) # Reset use_prompt_enhancer from client side.
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "")
......@@ -44,7 +37,8 @@ class DefaultRunner:
with ProfilingContext("Run Img Encoder"):
image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
with ProfilingContext("Run Text Encoder"):
text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output)
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
......@@ -93,8 +87,8 @@ class DefaultRunner:
save_videos_grid(images, self.config.save_video_path, fps=24)
def run_pipeline(self):
if self.has_prompt_enhancer and self.config["use_prompt_enhancer"]:
self.config["prompt"] = self.prompt_enhancer(self.config["user_prompt"])
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.prompt_enhancer(self.config["prompt"])
self.init_scheduler()
self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
......
import argparse
import torch
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
......@@ -38,6 +39,7 @@ class PromptEnhancer:
self.model = self.model.to(device)
@ProfilingContext("Run prompt enhancer")
@torch.no_grad()
def __call__(self, prompt):
prompt = prompt.strip()
prompt = sys_prompt.format(prompt)
......@@ -46,11 +48,20 @@ class PromptEnhancer:
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate(
**model_inputs,
max_new_tokens=2048,
max_new_tokens=8192,
)
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.info(f"Enhanced prompt: {rewritten_prompt}")
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
think_id = self.tokenizer.encode("</think>")
if len(think_id) == 1:
index = len(output_ids) - output_ids[::-1].index(think_id[0])
else:
index = 0
thinking_content = self.tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
logger.info(f"[Enhanced] thinking content: {thinking_content}")
rewritten_prompt = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
logger.info(f"[Enhanced] rewritten prompt: {rewritten_prompt}")
return rewritten_prompt
......
......@@ -19,6 +19,7 @@ def get_default_config():
"lora_path": None,
"strength_model": 1.0,
"mm_config": {},
"use_prompt_enhancer": False,
}
return default_config
......
......@@ -8,7 +8,7 @@ message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path.
"save_video_path": "./output_lightx2v_wan_t2v.mp4", # It is best to set it to an absolute path.
}
logger.info(f"message: {message}")
......
......@@ -6,11 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate"
message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"use_prompt_enhancer": True,
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"num_fragments": 1,
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path.
"save_video_path": "./output_lightx2v_wan_t2v_enhanced.mp4", # It is best to set it to an absolute path.
"use_prompt_enhancer": True,
}
logger.info(f"message: {message}")
......
......@@ -6,7 +6,7 @@ model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=2
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment