Commit d725c154 authored by Zhuguanyu Wu, committed by GitHub

Add prompt enhancer (#29)

* [bugs fixed] fixed bugs for cpu offload.

* [rename] rename causal_model -> causvid_model

* [feature] add prompt enhancer

* [feature] add prompt enhancer

* [rename] rename causal_model -> causvid_model
parent ae96fdbf
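
The prompt enhancer added in this commit wraps an instruction-tuned chat model (Qwen/Qwen2.5-32B-Instruct by default) and rewrites a short user prompt into a detailed video-generation caption before text encoding. A minimal standalone sketch of driving the new class directly, assuming a local Qwen2.5-Instruct checkpoint (the path below is illustrative, not from this commit):

from lightx2v.utils.prompt_enhancer import PromptEnhancer

# Hypothetical local checkpoint directory; any Qwen2.5-Instruct style model should work.
enhancer = PromptEnhancer(model_name="/models/Qwen2.5-32B-Instruct", device_map="cuda:1")
long_prompt = enhancer("a person is playing football")
print(long_prompt)

In the runner the same object is created from the --prompt_enhancer argument and placed on a second GPU (device_map="cuda:1"), so the enhancer does not compete with the diffusion model for memory.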
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
"seed": 42, "seed": 42,
"sample_guide_scale": 6, "sample_guide_scale": 6,
"sample_shift": 8, "sample_shift": 8,
"enable_cfg": false,
"cpu_offload": true,
"mm_config": { "mm_config": {
"mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl", "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
"weight_auto_quant": true "weight_auto_quant": true
...@@ -16,6 +18,5 @@ ...@@ -16,6 +18,5 @@
"num_frame_per_block": 3, "num_frame_per_block": 3,
"num_blocks": 7, "num_blocks": 7,
"frame_seq_length": 1560, "frame_seq_length": 1560,
"denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74], "denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74]
"cpu_offload": true
} }
...@@ -51,8 +51,10 @@ app = FastAPI() ...@@ -51,8 +51,10 @@ app = FastAPI()
class Message(BaseModel): class Message(BaseModel):
prompt: str prompt: str
use_prompt_enhancer: bool = False
negative_prompt: str = "" negative_prompt: str = ""
image_path: str = "" image_path: str = ""
num_fragments: int = 1
save_video_path: str save_video_path: str
def get(self, key, default=None): def get(self, key, default=None):
...@@ -63,8 +65,12 @@ class Message(BaseModel): ...@@ -63,8 +65,12 @@ class Message(BaseModel):
async def v1_local_video_generate(message: Message): async def v1_local_video_generate(message: Message):
global runner global runner
runner.set_inputs(message) runner.set_inputs(message)
logger.info(f"message: {message}")
await asyncio.to_thread(runner.run_pipeline) await asyncio.to_thread(runner.run_pipeline)
return {"response": "finished", "save_video_path": message.save_video_path} response = {"response": "finished", "save_video_path": message.save_video_path}
if runner.has_prompt_enhancer and message.use_prompt_enhancer:
response["enhanced_prompt"] = runner.config["prompt"]
return response
# ========================= # =========================
...@@ -74,10 +80,11 @@ async def v1_local_video_generate(message: Message): ...@@ -74,10 +80,11 @@ async def v1_local_video_generate(message: Message):
if __name__ == "__main__": if __name__ == "__main__":
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan") parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v") parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True) parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt_enhancer", default=None)
parser.add_argument("--port", type=int, default=8000) parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args() args = parser.parse_args()
logger.info(f"args: {args}") logger.info(f"args: {args}")
......
...@@ -29,7 +29,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None ...@@ -29,7 +29,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None
) )
x = torch.cat((x1, x2), dim=1) x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1) x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_causal", "wan2.1_df"]: elif model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_df"]:
x = sageattn( x = sageattn(
q.unsqueeze(0), q.unsqueeze(0),
k.unsqueeze(0), k.unsqueeze(0),
......
...@@ -45,17 +45,24 @@ class WeightModule: ...@@ -45,17 +45,24 @@ class WeightModule:
def to_cpu(self): def to_cpu(self):
for name, param in self._parameters.items(): for name, param in self._parameters.items():
if param is not None and hasattr(param, "cpu"): if param is not None:
self._parameters[name] = param.cpu() if hasattr(param, "cpu"):
setattr(self, name, self._parameters[name]) self._parameters[name] = param.cpu()
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu()
setattr(self, name, self._parameters[name])
for module in self._modules.values(): for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu"): if module is not None and hasattr(module, "to_cpu"):
module.to_cpu() module.to_cpu()
def to_cuda(self): def to_cuda(self):
for name, param in self._parameters.items(): for name, param in self._parameters.items():
if param is not None and hasattr(param, "cuda"): if param is not None:
self._parameters[name] = param.cuda() if hasattr(param, "cuda"):
self._parameters[name] = param.cuda()
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda()
setattr(self, name, self._parameters[name]) setattr(self, name, self._parameters[name])
for module in self._modules.values(): for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda"): if module is not None and hasattr(module, "to_cuda"):
...@@ -63,21 +70,28 @@ class WeightModule: ...@@ -63,21 +70,28 @@ class WeightModule:
def to_cpu_sync(self): def to_cpu_sync(self):
for name, param in self._parameters.items(): for name, param in self._parameters.items():
if param is not None and hasattr(param, "to"): if param is not None:
self._parameters[name] = param.to("cpu", non_blocking=True) if hasattr(param, "cpu"):
setattr(self, name, self._parameters[name]) self._parameters[name] = param.cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
elif hasattr(param, "to_cpu"):
self._parameters[name].to_cpu(non_blocking=True)
setattr(self, name, self._parameters[name])
for module in self._modules.values(): for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu_sync"): if module is not None and hasattr(module, "to_cpu"):
module.to_cpu_sync() module.to_cpu(non_blocking=True)
def to_cuda_sync(self): def to_cuda_sync(self):
for name, param in self._parameters.items(): for name, param in self._parameters.items():
if param is not None and hasattr(param, "cuda"): if param is not None:
self._parameters[name] = param.cuda(non_blocking=True) if hasattr(param, "cuda"):
self._parameters[name] = param.cuda(non_blocking=True)
elif hasattr(param, "to_cuda"):
self._parameters[name].to_cuda(non_blocking=True)
setattr(self, name, self._parameters[name]) setattr(self, name, self._parameters[name])
for module in self._modules.values(): for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda_sync"): if module is not None and hasattr(module, "to_cuda"):
module.to_cuda_sync() module.to_cuda(non_blocking=True)
class WeightModuleList(WeightModule): class WeightModuleList(WeightModule):
......
...@@ -39,12 +39,12 @@ class Conv2dWeight(Conv2dWeightTemplate): ...@@ -39,12 +39,12 @@ class Conv2dWeight(Conv2dWeightTemplate):
input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
return input_tensor return input_tensor
def to_cpu(self): def to_cpu(self, non_blocking=False):
self.weight = self.weight.cpu() self.weight = self.weight.cpu(non_blocking=non_blocking)
if self.bias is not None: if self.bias is not None:
self.bias = self.bias.cpu() self.bias = self.bias.cpu(non_blocking=non_blocking)
def to_cuda(self): def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda() self.weight = self.weight.cuda(non_blocking=non_blocking)
if self.bias is not None: if self.bias is not None:
self.bias = self.bias.cuda() self.bias = self.bias.cuda(non_blocking=non_blocking)
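
Taken together with the WeightModule changes above, offload now dispatches on the parameter type: raw tensors are moved with .cpu()/.cuda(), while nested weight wrappers such as Conv2dWeight are moved through their own to_cpu()/to_cuda(), which forward a non_blocking flag to the underlying tensor moves. A condensed, illustrative sketch of that dispatch (OffloadModule is a stand-in for WeightModule, simplified from the diff rather than copied verbatim):

class OffloadModule:
    def __init__(self):
        self._parameters = {}  # name -> torch.Tensor or nested weight wrapper
        self._modules = {}     # name -> child module supporting offload

    def to_cuda(self, non_blocking=False):
        for name, param in self._parameters.items():
            if param is None:
                continue
            if hasattr(param, "cuda"):          # plain tensor / nn.Parameter
                self._parameters[name] = param.cuda(non_blocking=non_blocking)
                setattr(self, name, self._parameters[name])
            elif hasattr(param, "to_cuda"):     # nested wrapper, e.g. Conv2dWeight
                param.to_cuda(non_blocking=non_blocking)
        for module in self._modules.values():   # recurse into children
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda(non_blocking=non_blocking)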
...@@ -11,7 +11,7 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER ...@@ -11,7 +11,7 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner from lightx2v.models.runners.graph_runner import GraphRunner
...@@ -35,12 +35,13 @@ def init_runner(config): ...@@ -35,12 +35,13 @@ def init_runner(config):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal", "wan2.1_skyreels_v2_df"], default="hunyuan") parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v") parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True) parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--enable_cfg", type=bool, default=False) parser.add_argument("--enable_cfg", type=bool, default=False)
parser.add_argument("--prompt", type=str, required=True) parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--prompt_enhancer", type=str, default=None)
parser.add_argument("--negative_prompt", type=str, default="") parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task") parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file") parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
......
...@@ -10,8 +10,8 @@ from lightx2v.models.networks.wan.weights.transformer_weights import ( ...@@ -10,8 +10,8 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
) )
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.causal.transformer_infer import ( from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
WanTransformerInferCausal, WanTransformerInferCausVid,
) )
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open from safetensors import safe_open
...@@ -19,7 +19,7 @@ import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap ...@@ -19,7 +19,7 @@ import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
class WanCausalModel(WanModel): class WanCausVidModel(WanModel):
pre_weight_class = WanPreWeights pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights transformer_weight_class = WanTransformerWeights
...@@ -30,7 +30,7 @@ class WanCausalModel(WanModel): ...@@ -30,7 +30,7 @@ class WanCausalModel(WanModel):
def _init_infer_class(self): def _init_infer_class(self):
self.pre_infer_class = WanPreInfer self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanTransformerInferCausal self.transformer_infer_class = WanTransformerInferCausVid
def _load_ckpt(self): def _load_ckpt(self):
use_bfloat16 = self.config.get("use_bfloat16", True) use_bfloat16 = self.config.get("use_bfloat16", True)
......
import torch import torch
import math import math
from ..utils import compute_freqs, compute_freqs_causal, compute_freqs_dist, apply_rotary_emb from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
from lightx2v.attentions import attention from lightx2v.attentions import attention
from lightx2v.common.offload.manager import WeightStreamManager from lightx2v.common.offload.manager import WeightStreamManager
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from ..transformer_infer import WanTransformerInfer from ..transformer_infer import WanTransformerInfer
class WanTransformerInferCausal(WanTransformerInfer): class WanTransformerInferCausVid(WanTransformerInfer):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.num_frames = config["num_frames"] self.num_frames = config["num_frames"]
...@@ -52,25 +52,55 @@ class WanTransformerInferCausal(WanTransformerInfer): ...@@ -52,25 +52,55 @@ class WanTransformerInferCausal(WanTransformerInfer):
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end): def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
for block_idx in range(self.blocks_num): for block_idx in range(self.blocks_num):
if block_idx == 0: if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0] self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
self.weights_stream_mgr.active_weights[0].to_cuda() self.weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.weights_stream_mgr.compute_stream): with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
x = self.infer_block(self.weights_stream_mgr.active_weights[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end) x = self.infer_block(
self.weights_stream_mgr.active_weights[0],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end,
)
if block_idx < self.blocks_num - 1: if block_idx < self.blocks_num - 1:
self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights) self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
self.weights_stream_mgr.swap_weights() self.weights_stream_mgr.swap_weights()
return x return x
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end): def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
for block_idx in range(self.blocks_num): for block_idx in range(self.blocks_num):
x = self.infer_block(weights.blocks_weights[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end) x = self.infer_block(
weights.blocks[block_idx],
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end,
)
return x return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end): def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
embed0 = (weights.modulation + embed0).chunk(6, dim=1) if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0)  # add a leading dim so it broadcasts against modulation
embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0) norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
...@@ -81,10 +111,10 @@ class WanTransformerInferCausal(WanTransformerInfer): ...@@ -81,10 +111,10 @@ class WanTransformerInferCausal(WanTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d) v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention: if not self.parallel_attention:
freqs_i = compute_freqs_causal(q.size(2) // 2, grid_sizes, freqs, start_frame=kv_start // math.prod(grid_sizes[0][1:]).item()) freqs_i = compute_freqs_causvid(q.size(2) // 2, grid_sizes, freqs, start_frame=kv_start // math.prod(grid_sizes[0][1:]).item())
else: else:
# TODO: Implement parallel attention for causal inference # TODO: Implement parallel attention for causvid inference
raise NotImplementedError("Parallel attention is not implemented for causal inference") raise NotImplementedError("Parallel attention is not implemented for causvid inference")
q = apply_rotary_emb(q, freqs_i) q = apply_rotary_emb(q, freqs_i)
k = apply_rotary_emb(k, freqs_i) k = apply_rotary_emb(k, freqs_i)
...@@ -107,8 +137,8 @@ class WanTransformerInferCausal(WanTransformerInfer): ...@@ -107,8 +137,8 @@ class WanTransformerInferCausal(WanTransformerInfer):
model_cls=self.config["model_cls"], model_cls=self.config["model_cls"],
) )
else: else:
# TODO: Implement parallel attention for causal inference # TODO: Implement parallel attention for causvid inference
raise NotImplementedError("Parallel attention is not implemented for causal inference") raise NotImplementedError("Parallel attention is not implemented for causvid inference")
y = weights.self_attn_o.apply(attn_out) y = weights.self_attn_o.apply(attn_out)
...@@ -116,9 +146,9 @@ class WanTransformerInferCausal(WanTransformerInfer): ...@@ -116,9 +146,9 @@ class WanTransformerInferCausal(WanTransformerInfer):
norm3_out = weights.norm3.apply(x) norm3_out = weights.norm3.apply(x)
# TODO: Implement I2V inference for causal model # TODO: Implement I2V inference for causvid model
if self.task == "i2v": if self.task == "i2v":
raise NotImplementedError("I2V inference for causal model is not implemented") raise NotImplementedError("I2V inference for causvid model is not implemented")
n, d = self.num_heads, self.head_dim n, d = self.num_heads, self.head_dim
q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d) q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
...@@ -138,9 +168,9 @@ class WanTransformerInferCausal(WanTransformerInfer): ...@@ -138,9 +168,9 @@ class WanTransformerInferCausal(WanTransformerInfer):
attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"] attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
) )
# TODO: Implement I2V inference for causal model # TODO: Implement I2V inference for causvid model
if self.task == "i2v": if self.task == "i2v":
raise NotImplementedError("I2V inference for causal model is not implemented") raise NotImplementedError("I2V inference for causvid model is not implemented")
attn_out = weights.cross_attn_o.apply(attn_out) attn_out = weights.cross_attn_o.apply(attn_out)
......
...@@ -20,7 +20,7 @@ def compute_freqs(c, grid_sizes, freqs): ...@@ -20,7 +20,7 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i return freqs_i
def compute_freqs_causal(c, grid_sizes, freqs, start_frame=0): def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0].tolist()
seq_len = f * h * w seq_len = f * h * w
......
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.prompt_enhancer import PromptEnhancer
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from loguru import logger from loguru import logger
...@@ -10,10 +11,29 @@ from loguru import logger ...@@ -10,10 +11,29 @@ from loguru import logger
class DefaultRunner: class DefaultRunner:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.config["user_prompt"] = self.config["prompt"]
self.has_prompt_enhancer = self.config.prompt_enhancer is not None and self.config.task == "t2v"
self.config["use_prompt_enhancer"] = self.has_prompt_enhancer
if self.has_prompt_enhancer:
self.load_prompt_enhancer()
self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model() self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()
@ProfilingContext("Load prompt enhancer")
def load_prompt_enhancer(self):
gpu_count = torch.cuda.device_count()
if gpu_count == 1:
logger.info("Only one GPU, use prompt enhancer cpu offload")
raise NotImplementedError("prompt enhancer cpu offload is not supported.")
self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1")
def set_inputs(self, inputs): def set_inputs(self, inputs):
self.config["user_prompt"] = inputs.get("prompt", "")
self.config["prompt"] = inputs.get("prompt", "") self.config["prompt"] = inputs.get("prompt", "")
self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)
self.config["negative_prompt"] = inputs.get("negative_prompt", "") self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "") self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "") self.config["save_video_path"] = inputs.get("save_video_path", "")
...@@ -55,10 +75,9 @@ class DefaultRunner: ...@@ -55,10 +75,9 @@ class DefaultRunner:
self.model.scheduler.step_post() self.model.scheduler.step_post()
def end_run(self): def end_run(self):
if self.config.cpu_offload: self.model.scheduler.clear()
self.model.scheduler.clear() del self.inputs, self.model.scheduler
del self.inputs, self.model.scheduler, self.model, self.text_encoders torch.cuda.empty_cache()
torch.cuda.empty_cache()
@ProfilingContext("Run VAE") @ProfilingContext("Run VAE")
def run_vae(self, latents, generator): def run_vae(self, latents, generator):
...@@ -68,12 +87,14 @@ class DefaultRunner: ...@@ -68,12 +87,14 @@ class DefaultRunner:
@ProfilingContext("Save video") @ProfilingContext("Save video")
def save_video(self, images): def save_video(self, images):
if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0): if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
if self.config.model_cls in ["wan2.1", "wan2.1_causal", "wan2.1_skyreels_v2_df"]: if self.config.model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_skyreels_v2_df"]:
cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1)) cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
else: else:
save_videos_grid(images, self.config.save_video_path, fps=24) save_videos_grid(images, self.config.save_video_path, fps=24)
def run_pipeline(self): def run_pipeline(self):
if self.has_prompt_enhancer and self.config["use_prompt_enhancer"]:
self.config["prompt"] = self.prompt_enhancer(self.config["user_prompt"])
self.init_scheduler() self.init_scheduler()
self.run_input_encoder() self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.model.scheduler.prepare(self.inputs["image_encoder_output"])
...@@ -81,3 +102,6 @@ class DefaultRunner: ...@@ -81,3 +102,6 @@ class DefaultRunner:
self.end_run() self.end_run()
images = self.run_vae(latents, generator) images = self.run_vae(latents, generator)
self.save_video(images) self.save_video(images)
del latents, generator, images
gc.collect()
torch.cuda.empty_cache()
import os import os
import gc
import numpy as np import numpy as np
import torch import torch
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
...@@ -7,19 +8,19 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER ...@@ -7,19 +8,19 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.causal.scheduler import WanCausalScheduler from lightx2v.models.schedulers.wan.causvid.scheduler import WanCausVidScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.causal_model import WanCausalModel from lightx2v.models.networks.wan.causvid_model import WanCausVidModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from loguru import logger from loguru import logger
import torch.distributed as dist import torch.distributed as dist
@RUNNER_REGISTER("wan2.1_causal") @RUNNER_REGISTER("wan2.1_causvid")
class WanCausalRunner(WanRunner): class WanCausVidRunner(WanRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.denoising_step_list = self.model.config.denoising_step_list self.denoising_step_list = self.model.config.denoising_step_list
...@@ -49,7 +50,7 @@ class WanCausalRunner(WanRunner): ...@@ -49,7 +50,7 @@ class WanCausalRunner(WanRunner):
shard_fn=None, shard_fn=None,
) )
text_encoders = [text_encoder] text_encoders = [text_encoder]
model = WanCausalModel(self.config.model_path, self.config, init_device) model = WanCausVidModel(self.config.model_path, self.config, init_device)
if self.config.lora_path: if self.config.lora_path:
lora_wrapper = WanLoraWrapper(model) lora_wrapper = WanLoraWrapper(model)
...@@ -68,8 +69,13 @@ class WanCausalRunner(WanRunner): ...@@ -68,8 +69,13 @@ class WanCausalRunner(WanRunner):
return model, text_encoders, vae_model, image_encoder return model, text_encoders, vae_model, image_encoder
def set_inputs(self, inputs):
super().set_inputs(inputs)
self.config["num_fragments"] = inputs.get("num_fragments", 1)
self.num_fragments = self.config["num_fragments"]
def init_scheduler(self): def init_scheduler(self):
scheduler = WanCausalScheduler(self.config) scheduler = WanCausVidScheduler(self.config)
self.model.set_scheduler(scheduler) self.model.set_scheduler(scheduler)
def set_target_shape(self): def set_target_shape(self):
...@@ -96,7 +102,7 @@ class WanCausalRunner(WanRunner): ...@@ -96,7 +102,7 @@ class WanCausalRunner(WanRunner):
start_block_idx = 0 start_block_idx = 0
for fragment_idx in range(self.num_fragments): for fragment_idx in range(self.num_fragments):
logger.info(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}") logger.info(f"========> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
kv_start = 0 kv_start = 0
kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
...@@ -116,8 +122,8 @@ class WanCausalRunner(WanRunner): ...@@ -116,8 +122,8 @@ class WanCausalRunner(WanRunner):
infer_blocks = self.infer_blocks - (fragment_idx > 0) infer_blocks = self.infer_blocks - (fragment_idx > 0)
for block_idx in range(infer_blocks): for block_idx in range(infer_blocks):
logger.info(f"=======> block_idx: {block_idx + 1} / {infer_blocks}") logger.info(f"=====> block_idx: {block_idx + 1} / {infer_blocks}")
logger.info(f"=======> kv_start: {kv_start}, kv_end: {kv_end}") logger.info(f"=====> kv_start: {kv_start}, kv_end: {kv_end}")
self.model.scheduler.reset() self.model.scheduler.reset()
for step_index in range(self.model.scheduler.infer_steps): for step_index in range(self.model.scheduler.infer_steps):
...@@ -139,3 +145,9 @@ class WanCausalRunner(WanRunner): ...@@ -139,3 +145,9 @@ class WanCausalRunner(WanRunner):
start_block_idx += 1 start_block_idx += 1
return output_latents, self.model.scheduler.generator return output_latents, self.model.scheduler.generator
def end_run(self):
self.model.scheduler.clear()
del self.inputs, self.model.scheduler, self.model.transformer_infer.kv_cache, self.model.transformer_infer.crossattn_cache
gc.collect()
torch.cuda.empty_cache()
...@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union ...@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.wan.scheduler import WanScheduler from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class WanCausalScheduler(WanScheduler): class WanCausVidScheduler(WanScheduler):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.denoising_step_list = config.denoising_step_list self.denoising_step_list = config.denoising_step_list
......
import argparse
from loguru import logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
sys_prompt = """
Transform the short prompt into a detailed video-generation caption using this structure:
​​Opening shot type​​ (long/medium/close-up/extreme close-up/full shot)
​​Primary subject(s)​​ with vivid attributes (colors, textures, actions, interactions)
​​Dynamic elements​​ (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
​​Scene composition​​ (background, environment, spatial relationships)
​​Lighting/atmosphere​​ (natural/artificial, time of day, mood)
​​Camera motion​​ (zooms, pans, static/handheld shots) if applicable.
Pattern Summary from Examples:
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
​One case:
Short prompt: a person is playing football
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
​​Now expand this short prompt:​​ [{}]. Please only output the final long prompt in English.
"""
class PromptEnhancer:
def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct", device_map="cuda:0"):
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map=device_map,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def to_device(self, device):
self.model = self.model.to(device)
@ProfilingContext("Run prompt enhancer")
def __call__(self, prompt):
prompt = prompt.strip()
prompt = sys_prompt.format(prompt)
messages = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate(
**model_inputs,
max_new_tokens=2048,
)
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
logger.info(f"Enhanced prompt: {rewritten_prompt}")
return rewritten_prompt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
args = parser.parse_args()
prompt_enhancer = PromptEnhancer()
enhanced_prompt = prompt_enhancer(args.prompt)
logger.info(f"Original prompt: {args.prompt}")
logger.info(f"Enhanced prompt: {enhanced_prompt}")
...@@ -6,8 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate" ...@@ -6,8 +6,10 @@ url = "http://localhost:8000/v1/local/video/generate"
message = { message = {
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"use_prompt_enhancer": True,
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "", "image_path": "",
"num_fragments": 1,
"save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path. "save_video_path": "./output_lightx2v_wan_t2v_ap4.mp4", # It is best to set it to an absolute path.
} }
......
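
When use_prompt_enhancer is set and the server was started with --prompt_enhancer, the endpoint also returns the rewritten prompt. A minimal client sketch along the lines of the example above (same URL; the prompt and save path here are illustrative):

import requests

message = {
    "prompt": "a person is playing football",
    "use_prompt_enhancer": True,  # ask the server to rewrite the prompt before encoding
    "save_video_path": "/tmp/output_lightx2v.mp4",
}
resp = requests.post("http://localhost:8000/v1/local/video/generate", json=message).json()
print(resp["save_video_path"])
print(resp.get("enhanced_prompt", "<enhancer disabled on the server>"))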
...@@ -29,7 +29,7 @@ export ENABLE_PROFILING_DEBUG=true ...@@ -29,7 +29,7 @@ export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \ python -m lightx2v.infer \
--model_cls wan2.1_causal \ --model_cls wan2.1_causvid \
--task t2v \ --task t2v \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causal.json \ --config_json ${lightx2v_path}/configs/wan_t2v_causal.json \
......
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
prompt_enhancer_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0,1
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--prompt_enhancer ${prompt_enhancer_path} \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
prompt_enhancer_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0,1
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.api_server \
--model_cls wan2.1_causvid \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causvid.json \
--prompt_enhancer ${prompt_enhancer_path} \
--port 8000