Commit d725c154 authored by Zhuguanyu Wu's avatar Zhuguanyu Wu Committed by GitHub

Add prompt enhancer (#29)

* [bugfix] fixed bugs for CPU offload.

* [rename] rename causal_model -> causvid_model

* [feature] add prompt enhancer

* [feature] add prompt enhancer

* [rename] rename causal_model -> causvid_model
parent ae96fdbf
......@@ -7,6 +7,8 @@
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": true,
"mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
"weight_auto_quant": true
......@@ -16,6 +18,5 @@
"num_frame_per_block": 3,
"num_blocks": 7,
"frame_seq_length": 1560,
"denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74],
"cpu_offload": true
"denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74]
}
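For reference, a minimal sketch of how a runner might consume this config. The `EasyDict` wrapper is an assumption chosen to match the attribute-style access (`config.cpu_offload`) seen in the runner code below; the repo may use its own config type:

```python
import json

from easydict import EasyDict  # assumed wrapper; any attr-style dict works

# Load the JSON above; keys become attributes as well as dict entries.
with open("configs/wan_t2v_causvid.json") as f:
    config = EasyDict(json.load(f))

assert len(config.denoising_step_list) == 9  # few-step distilled schedule
if config.cpu_offload:
    print("Transformer weights will be streamed between CPU and GPU")
```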
......@@ -51,8 +51,10 @@ app = FastAPI()
class Message(BaseModel):
prompt: str
use_prompt_enhancer: bool = False
negative_prompt: str = ""
image_path: str = ""
num_fragments: int = 1
save_video_path: str
def get(self, key, default=None):
......@@ -63,8 +65,12 @@ class Message(BaseModel):
async def v1_local_video_generate(message: Message):
global runner
runner.set_inputs(message)
logger.info(f"message: {message}")
await asyncio.to_thread(runner.run_pipeline)
return {"response": "finished", "save_video_path": message.save_video_path}
response = {"response": "finished", "save_video_path": message.save_video_path}
if runner.has_prompt_enhancer and message.use_prompt_enhancer:
response["enhanced_prompt"] = runner.config["prompt"]
return response
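For illustration, a hypothetical client call against a locally running server. When the server was launched with `--prompt_enhancer` and the request sets `use_prompt_enhancer`, the response now carries the rewritten prompt:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/local/video/generate",
    json={
        "prompt": "a person is playing football",
        "use_prompt_enhancer": True,
        "save_video_path": "/tmp/out.mp4",  # illustrative path
    },
)
print(resp.json())
# e.g. {"response": "finished", "save_video_path": "/tmp/out.mp4",
#       "enhanced_prompt": "Medium shot of a young athlete ..."}
```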
# =========================
......@@ -74,10 +80,11 @@ async def v1_local_video_generate(message: Message):
if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan")
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt_enhancer", default=None)
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
logger.info(f"args: {args}")
......
......@@ -29,7 +29,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_causal", "wan2.1_df"]:
elif model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_df"]:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
......
......@@ -45,17 +45,24 @@ class WeightModule:
def to_cpu(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "cpu"):
self._parameters[name] = param.cpu()
setattr(self, name, self._parameters[name])
if param is not None:
if hasattr(param, "cpu"):
self._parameters[name] = param.cpu()
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu()
def to_cuda(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "cuda"):
self._parameters[name] = param.cuda()
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.cuda()
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda()
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda"):
......@@ -63,21 +70,28 @@ class WeightModule:
def to_cpu_sync(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "to"):
self._parameters[name] = param.to("cpu", non_blocking=True)
setattr(self, name, self._parameters[name])
if param is not None:
if hasattr(param, "cpu"):
self._parameters[name] = param.to("cpu", non_blocking=True)
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu_sync"):
module.to_cpu_sync()
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu(non_blocking=True)
def to_cuda_sync(self):
for name, param in self._parameters.items():
if param is not None and hasattr(param, "cuda"):
self._parameters[name] = param.cuda(non_blocking=True)
if param is not None:
if hasattr(param, "cuda"):
self._parameters[name] = param.cuda(non_blocking=True)
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda_sync"):
module.to_cuda_sync()
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda(non_blocking=True)
class WeightModuleList(WeightModule):
......
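The change above turns a flat tensor move into capability dispatch: raw tensors go through `.cpu()`/`.cuda()`, while nested weight objects (such as `Conv2dWeight` below) expose `to_cpu`/`to_cuda` and handle their own tensors. A minimal self-contained sketch of the pattern with simplified names; note that `Tensor.cpu()` accepts no `non_blocking` argument, so the async path must route through `Tensor.to`:

```python
import torch

class WeightModule:
    def __init__(self):
        self._parameters = {}  # name -> Tensor or nested weight object
        self._modules = {}     # name -> child WeightModule

    def to_cpu(self, non_blocking=False):
        for name, param in self._parameters.items():
            if param is None:
                continue
            if torch.is_tensor(param):
                # Tensor.cpu() takes no non_blocking kwarg; use .to() instead.
                self._parameters[name] = param.to("cpu", non_blocking=non_blocking)
                setattr(self, name, self._parameters[name])
            elif hasattr(param, "to_cpu"):
                param.to_cpu(non_blocking=non_blocking)  # recurse into the wrapper
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu(non_blocking=non_blocking)
```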
......@@ -39,12 +39,12 @@ class Conv2dWeight(Conv2dWeightTemplate):
input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
return input_tensor
def to_cpu(self):
self.weight = self.weight.cpu()
def to_cpu(self, non_blocking=False):
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cpu()
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def to_cuda(self):
self.weight = self.weight.cuda()
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cuda()
self.bias = self.bias.cuda(non_blocking=non_blocking)
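One caveat worth flagging for the `non_blocking` paths: the copy only overlaps with compute when the CPU side is pinned (page-locked); with ordinary pageable tensors PyTorch falls back to a synchronous copy. A small illustration:

```python
import torch

w = torch.randn(1024, 1024, device="cuda")
# Device-to-host copies are asynchronous only into pinned host memory.
host_buf = torch.empty(w.shape, dtype=w.dtype, device="cpu", pin_memory=True)
host_buf.copy_(w, non_blocking=True)  # enqueued on the current CUDA stream
torch.cuda.synchronize()              # wait before reading host_buf on the CPU
```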
......@@ -11,7 +11,7 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner
......@@ -35,12 +35,13 @@ def init_runner(config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--enable_cfg", type=bool, default=False)
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--prompt_enhancer", type=str, default=None)
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
......
......@@ -10,8 +10,8 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.causal.transformer_infer import (
WanTransformerInferCausal,
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausVid,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
......@@ -19,7 +19,7 @@ import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
class WanCausalModel(WanModel):
class WanCausVidModel(WanModel):
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
......@@ -30,7 +30,7 @@ class WanCausalModel(WanModel):
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanTransformerInferCausal
self.transformer_infer_class = WanTransformerInferCausVid
def _load_ckpt(self):
use_bfloat16 = self.config.get("use_bfloat16", True)
......
import torch
import math
from ..utils import compute_freqs, compute_freqs_causal, compute_freqs_dist, apply_rotary_emb
from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
from lightx2v.attentions import attention
from lightx2v.common.offload.manager import WeightStreamManager
from lightx2v.utils.envs import *
from ..transformer_infer import WanTransformerInfer
class WanTransformerInferCausal(WanTransformerInfer):
class WanTransformerInferCausVid(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.num_frames = config["num_frames"]
......@@ -52,25 +52,55 @@ class WanTransformerInferCausal(WanTransformerInfer):
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
for block_idx in range(self.blocks_num):
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0]
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(self.weights_stream_mgr.active_weights[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
x = self.infer_block(
self.weights_stream_mgr.active_weights[0],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end,
)
if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights)
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
self.weights_stream_mgr.swap_weights()
return x
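The offload path is a classic double buffer: while block `i` runs on the compute stream, block `i + 1`'s weights are copied in on a separate stream, then `swap_weights` rotates the slots. A hedged sketch of what `WeightStreamManager` plausibly does (the real class lives in `lightx2v.common.offload.manager`; names and eviction policy here are assumptions):

```python
import torch

class TwoSlotStreamer:
    """Hypothetical double-buffered weight prefetcher."""

    def __init__(self):
        self.active_weights = [None, None]         # [current, prefetched]
        self.compute_stream = torch.cuda.Stream()  # blocks execute here
        self.load_stream = torch.cuda.Stream()     # H2D copies overlap compute

    def prefetch_weights(self, idx, blocks):
        with torch.cuda.stream(self.load_stream):
            blocks[idx].to_cuda()
            self.active_weights[1] = blocks[idx]

    def swap_weights(self):
        # Make sure the prefetch finished before the next block computes.
        torch.cuda.current_stream().wait_stream(self.load_stream)
        finished, self.active_weights[0] = self.active_weights[0], self.active_weights[1]
        self.active_weights[1] = None
        if finished is not None:
            finished.to_cpu()  # evict the block that just ran (assumed policy)
```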
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
for block_idx in range(self.blocks_num):
x = self.infer_block(weights.blocks_weights[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
x = self.infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end,
)
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
embed0 = (weights.modulation + embed0).chunk(6, dim=1)
if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0) # 1, 6, seq, dim
embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
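The new branch handles a per-token time embedding: when `embed0` has a sequence axis, the `(1, 6, 1, dim)` modulation broadcasts against it before the six-way chunk. A quick shape check under those assumed dimensions:

```python
import torch

dim, seq = 8, 4
modulation = torch.zeros(1, 6, 1, dim)           # weights.modulation.tensor.unsqueeze(2)
embed0 = torch.zeros(6, seq, dim).unsqueeze(0)   # (1, 6, seq, dim)
chunks = (modulation + embed0).chunk(6, dim=1)   # six tensors of (1, 1, seq, dim)
chunks = [c.squeeze(1) for c in chunks]          # each (1, seq, dim)
assert all(c.shape == (1, seq, dim) for c in chunks)
```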
......@@ -81,10 +111,10 @@ class WanTransformerInferCausal(WanTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
freqs_i = compute_freqs_causal(q.size(2) // 2, grid_sizes, freqs, start_frame=kv_start // math.prod(grid_sizes[0][1:]).item())
freqs_i = compute_freqs_causvid(q.size(2) // 2, grid_sizes, freqs, start_frame=kv_start // math.prod(grid_sizes[0][1:]).item())
else:
# TODO: Implement parallel attention for causal inference
raise NotImplementedError("Parallel attention is not implemented for causal inference")
# TODO: Implement parallel attention for causvid inference
raise NotImplementedError("Parallel attention is not implemented for causvid inference")
q = apply_rotary_emb(q, freqs_i)
k = apply_rotary_emb(k, freqs_i)
......@@ -107,8 +137,8 @@ class WanTransformerInferCausal(WanTransformerInfer):
model_cls=self.config["model_cls"],
)
else:
# TODO: Implement parallel attention for causal inference
raise NotImplementedError("Parallel attention is not implemented for causal inference")
# TODO: Implement parallel attention for causvid inference
raise NotImplementedError("Parallel attention is not implemented for causvid inference")
y = weights.self_attn_o.apply(attn_out)
......@@ -116,9 +146,9 @@ class WanTransformerInferCausal(WanTransformerInfer):
norm3_out = weights.norm3.apply(x)
# TODO: Implement I2V inference for causal model
# TODO: Implement I2V inference for causvid model
if self.task == "i2v":
raise NotImplementedError("I2V inference for causal model is not implemented")
raise NotImplementedError("I2V inference for causvid model is not implemented")
n, d = self.num_heads, self.head_dim
q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
......@@ -138,9 +168,9 @@ class WanTransformerInferCausal(WanTransformerInfer):
attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
)
# TODO: Implement I2V inference for causal model
# TODO: Implement I2V inference for causvid model
if self.task == "i2v":
raise NotImplementedError("I2V inference for causal model is not implemented")
raise NotImplementedError("I2V inference for causvid model is not implemented")
attn_out = weights.cross_attn_o.apply(attn_out)
......
......@@ -20,7 +20,7 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_causal(c, grid_sizes, freqs, start_frame=0):
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
seq_len = f * h * w
......
......@@ -3,6 +3,7 @@ import torch
import torch.distributed as dist
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.prompt_enhancer import PromptEnhancer
from lightx2v.utils.envs import *
from loguru import logger
......@@ -10,10 +11,29 @@ from loguru import logger
class DefaultRunner:
def __init__(self, config):
self.config = config
self.config["user_prompt"] = self.config["prompt"]
self.has_prompt_enhancer = self.config.prompt_enhancer is not None and self.config.task == "t2v"
self.config["use_prompt_enhancer"] = self.has_prompt_enhancer
if self.has_prompt_enhancer:
self.load_prompt_enhancer()
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
@ProfilingContext("Load prompt enhancer")
def load_prompt_enhancer(self):
gpu_count = torch.cuda.device_count()
if gpu_count == 1:
logger.info("Only one GPU, use prompt enhancer cpu offload")
raise NotImplementedError("prompt enhancer cpu offload is not supported.")
self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1")
def set_inputs(self, inputs):
self.config["user_prompt"] = inputs.get("prompt", "")
self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "")
......@@ -55,10 +75,9 @@ class DefaultRunner:
self.model.scheduler.step_post()
def end_run(self):
if self.config.cpu_offload:
self.model.scheduler.clear()
del self.inputs, self.model.scheduler, self.model, self.text_encoders
torch.cuda.empty_cache()
self.model.scheduler.clear()
del self.inputs, self.model.scheduler
torch.cuda.empty_cache()
@ProfilingContext("Run VAE")
def run_vae(self, latents, generator):
......@@ -68,12 +87,14 @@ class DefaultRunner:
@ProfilingContext("Save video")
def save_video(self, images):
if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
if self.config.model_cls in ["wan2.1", "wan2.1_causal", "wan2.1_skyreels_v2_df"]:
if self.config.model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_skyreels_v2_df"]:
cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else:
save_videos_grid(images, self.config.save_video_path, fps=24)
def run_pipeline(self):
if self.has_prompt_enhancer and self.config["use_prompt_enhancer"]:
self.config["prompt"] = self.prompt_enhancer(self.config["user_prompt"])
self.init_scheduler()
self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
......@@ -81,3 +102,6 @@ class DefaultRunner:
self.end_run()
images = self.run_vae(latents, generator)
self.save_video(images)
del latents, generator, images
gc.collect()
torch.cuda.empty_cache()
import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
......@@ -7,19 +8,19 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.causal.scheduler import WanCausalScheduler
from lightx2v.models.schedulers.wan.causvid.scheduler import WanCausVidScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.causal_model import WanCausalModel
from lightx2v.models.networks.wan.causvid_model import WanCausVidModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from loguru import logger
import torch.distributed as dist
@RUNNER_REGISTER("wan2.1_causal")
class WanCausalRunner(WanRunner):
@RUNNER_REGISTER("wan2.1_causvid")
class WanCausVidRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
self.denoising_step_list = self.model.config.denoising_step_list
......@@ -49,7 +50,7 @@ class WanCausalRunner(WanRunner):
shard_fn=None,
)
text_encoders = [text_encoder]
model = WanCausalModel(self.config.model_path, self.config, init_device)
model = WanCausVidModel(self.config.model_path, self.config, init_device)
if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model)
......@@ -68,8 +69,13 @@ class WanCausalRunner(WanRunner):
return model, text_encoders, vae_model, image_encoder
def set_inputs(self, inputs):
super().set_inputs(inputs)
self.config["num_fragments"] = inputs.get("num_fragments", 1)
self.num_fragments = self.config["num_fragments"]
def init_scheduler(self):
scheduler = WanCausalScheduler(self.config)
scheduler = WanCausVidScheduler(self.config)
self.model.set_scheduler(scheduler)
def set_target_shape(self):
......@@ -96,7 +102,7 @@ class WanCausalRunner(WanRunner):
start_block_idx = 0
for fragment_idx in range(self.num_fragments):
logger.info(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
logger.info(f"========> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
kv_start = 0
kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
......@@ -116,8 +122,8 @@ class WanCausalRunner(WanRunner):
infer_blocks = self.infer_blocks - (fragment_idx > 0)
for block_idx in range(infer_blocks):
logger.info(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
logger.info(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
logger.info(f"=====> block_idx: {block_idx + 1} / {infer_blocks}")
logger.info(f"=====> kv_start: {kv_start}, kv_end: {kv_end}")
self.model.scheduler.reset()
for step_index in range(self.model.scheduler.infer_steps):
......@@ -139,3 +145,9 @@ class WanCausalRunner(WanRunner):
start_block_idx += 1
return output_latents, self.model.scheduler.generator
def end_run(self):
self.model.scheduler.clear()
del self.inputs, self.model.scheduler, self.model.transformer_infer.kv_cache, self.model.transformer_infer.crossattn_cache
gc.collect()
torch.cuda.empty_cache()
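As a sanity check of the KV window sizes with the shipped `wan_t2v_causvid.json` (`num_frame_per_block: 3`, `frame_seq_length: 1560`): each causal block covers 4680 KV tokens, and every fragment restarts at `kv_start = 0`, as in the loop above:

```python
# Values from configs/wan_t2v_causvid.json
num_frame_per_block, frame_seq_length = 3, 1560

kv_start = 0  # reset at the start of each fragment
kv_end = kv_start + num_frame_per_block * frame_seq_length
assert kv_end == 4680  # KV-cache tokens covered by the first causal block
```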
......@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class WanCausalScheduler(WanScheduler):
class WanCausVidScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.denoising_step_list = config.denoising_step_list
......
import argparse
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
sys_prompt = """
Transform the short prompt into a detailed video-generation caption using this structure:
Opening shot type (long/medium/close-up/extreme close-up/full shot)
Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
Scene composition (background, environment, spatial relationships)
Lighting/atmosphere (natural/artificial, time of day, mood)
Camera motion (zooms, pans, static/handheld shots) if applicable.
Pattern Summary from Examples:
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
One case:
Short prompt: a person is playing football
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
Now expand this short prompt: [{}]. Please only output the final long prompt in English.
"""
class PromptEnhancer:
def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct", device_map="cuda:0"):
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map=device_map,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def to_device(self, device):
self.model = self.model.to(device)
@ProfilingContext("Run prompt enhancer")
def __call__(self, prompt):
prompt = prompt.strip()
prompt = sys_prompt.format(prompt)
messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate(
**model_inputs,
max_new_tokens=2048,
)
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.info(f"Enhanced prompt: {rewritten_prompt}")
return rewritten_prompt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
args = parser.parse_args()
prompt_enhancer = PromptEnhancer()
enhanced_prompt = prompt_enhancer(args.prompt)
logger.info(f"Original prompt: {args.prompt}")
logger.info(f"Enhanced prompt: {enhanced_prompt}")
......@@ -6,8 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate"
message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"use_prompt_enhancer": True,
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"num_fragments": 1,
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path.
}
......
......@@ -29,7 +29,7 @@ export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1_causal \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causal.json \
......
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
prompt_enhancer_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0,1
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--prompt_enhancer ${prompt_enhancer_path} \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
prompt_enhancer_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0,1
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.api_server \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causvid.json \
--prompt_enhancer ${prompt_enhancer_path} \
--port 8000