Commit b453c16e authored by Zhuguanyu Wu, committed by GitHub

support autoregressive inference for wan2.1-t2v (#23)



* [feature]: add support for CausVid inference

* [feature]: support long video generation for autoregressive wan2.1-t2v

* [feature]: support long video generation for autoregressive wan2.1-t2v

* Update wan_t2v_causal.json

* Update run_wan_t2v_causal.sh

* update readme for Wan2.1-T2V-CausVid

---------
Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 7c3da5c0
...@@ -18,6 +18,8 @@
[Wan2.1-I2V](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P)
[Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
## Fast Start Up With Conda
```shell
......
{
    "infer_steps": 9,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    },
    "num_fragments": 3,
    "num_frames": 21,
    "num_frame_per_block": 3,
    "num_blocks": 7,
    "frame_seq_length": 1560,
    "denoising_step_list": [999, 934, 862, 756, 603, 410, 250, 140, 74],
    "cpu_offload": true
}
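For reference, the numbers in this config are internally consistent under the usual Wan2.1 latent geometry (spatial VAE stride 8, 2x2 patchify, temporal VAE stride 4); those strides are an assumption here, not stated in this diff. A quick sketch of the arithmetic:

```python
# Hedged sketch: how the wan_t2v_causal.json numbers relate, assuming the usual
# Wan2.1 strides (spatial VAE stride 8, 2x2 patchify, temporal VAE stride 4).
target_height, target_width = 480, 832
lat_h, lat_w = target_height // 8, target_width // 8    # 60 x 104 latent pixels
frame_seq_length = (lat_h // 2) * (lat_w // 2)           # 30 * 52 = 1560 tokens per latent frame
num_frames, num_blocks, num_frame_per_block = 21, 7, 3   # 7 blocks * 3 frames = 21 latent frames
target_video_length = (num_frames - 1) * 4 + 1           # 81 decoded video frames per fragment
assert num_blocks * num_frame_per_block == num_frames
assert frame_seq_length == 1560 and target_video_length == 81
```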
...@@ -11,6 +11,7 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.common.ops import *
...@@ -18,7 +19,7 @@ from lightx2v.common.ops import *
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan")
    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)
......
...@@ -29,7 +29,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None
        )
        x = torch.cat((x1, x2), dim=1)
        x = x.view(max_seqlen_q, -1)
-    elif model_cls == "wan2.1":
+    elif model_cls in ["wan2.1", "wan2.1_causal"]:
        x = sageattn(
            q.unsqueeze(0),
            k.unsqueeze(0),
......
import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.causal.transformer_infer import (
    WanTransformerInferCausal,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap


class WanCausalModel(WanModel):
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanTransformerInferCausal

    def _load_ckpt(self):
        use_bfloat16 = self.config.get("use_bfloat16", True)
        weight_dict = torch.load(os.path.join(self.model_path, "causal_model.pt"), map_location="cpu", weights_only=True)
        dtype = torch.bfloat16 if use_bfloat16 else None
        for key, value in weight_dict.items():
            weight_dict[key] = value.to(device=self.device, dtype=dtype)
        return weight_dict

    @torch.no_grad()
    def infer(self, inputs, kv_start, kv_end):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()
import torch
import math
from ..utils import compute_freqs, compute_freqs_causal, compute_freqs_dist, apply_rotary_emb
from lightx2v.attentions import attention
from lightx2v.common.offload.manager import WeightStreamManager
from lightx2v.utils.envs import *
from ..transformer_infer import WanTransformerInfer


class WanTransformerInferCausal(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.num_frames = config["num_frames"]
        self.num_frame_per_block = config["num_frame_per_block"]
        self.frame_seq_length = config["frame_seq_length"]
        self.text_len = config["text_len"]
        self.kv_size = self.num_frames * self.frame_seq_length
        self.kv_cache = None
        self.crossattn_cache = None

    def _init_kv_cache(self, dtype, device):
        kv_cache = []
        for _ in range(self.blocks_num):
            kv_cache.append(
                {
                    "k": torch.zeros([self.kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device),
                    "v": torch.zeros([self.kv_size, self.num_heads, self.head_dim], dtype=dtype, device=device),
                }
            )
        self.kv_cache = kv_cache

    def _init_crossattn_cache(self, dtype, device):
        crossattn_cache = []
        for _ in range(self.blocks_num):
            crossattn_cache.append(
                {
                    "k": torch.zeros([self.text_len, self.num_heads, self.head_dim], dtype=dtype, device=device),
                    "v": torch.zeros([self.text_len, self.num_heads, self.head_dim], dtype=dtype, device=device),
                    "is_init": False,
                }
            )
        self.crossattn_cache = crossattn_cache

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end)

    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
        for block_idx in range(self.blocks_num):
            if block_idx == 0:
                self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0]
                self.weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                x = self.infer_block(self.weights_stream_mgr.active_weights[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
            if block_idx < self.blocks_num - 1:
                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights)
            self.weights_stream_mgr.swap_weights()
        return x

    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
        for block_idx in range(self.blocks_num):
            x = self.infer_block(weights.blocks_weights[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
        return x

    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
        embed0 = (weights.modulation + embed0).chunk(6, dim=1)
        norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
        s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
        q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
        k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
        v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
        if not self.parallel_attention:
            freqs_i = compute_freqs_causal(q.size(2) // 2, grid_sizes, freqs, start_frame=kv_start // math.prod(grid_sizes[0][1:]).item())
        else:
            # TODO: Implement parallel attention for causal inference
            raise NotImplementedError("Parallel attention is not implemented for causal inference")
        q = apply_rotary_emb(q, freqs_i)
        k = apply_rotary_emb(k, freqs_i)
        self.kv_cache[block_idx]["k"][kv_start:kv_end] = k
        self.kv_cache[block_idx]["v"][kv_start:kv_end] = v
        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q=q, k=self.kv_cache[block_idx]["k"][:kv_end], k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
        if not self.parallel_attention:
            attn_out = attention(
                attention_type=self.attention_type,
                q=q,
                k=self.kv_cache[block_idx]["k"][:kv_end],
                v=self.kv_cache[block_idx]["v"][:kv_end],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_k,
                max_seqlen_q=lq,
                max_seqlen_kv=lk,
                model_cls=self.config["model_cls"],
            )
        else:
            # TODO: Implement parallel attention for causal inference
            raise NotImplementedError("Parallel attention is not implemented for causal inference")
        y = weights.self_attn_o.apply(attn_out)
        x = x + y * embed0[2].squeeze(0)
        norm3_out = weights.norm3.apply(x)
        # TODO: Implement I2V inference for causal model
        if self.task == "i2v":
            raise NotImplementedError("I2V inference for causal model is not implemented")
        n, d = self.num_heads, self.head_dim
        q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
        if not self.crossattn_cache[block_idx]["is_init"]:
            k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
            v = weights.cross_attn_v.apply(context).view(-1, n, d)
            self.crossattn_cache[block_idx]["k"] = k
            self.crossattn_cache[block_idx]["v"] = v
            self.crossattn_cache[block_idx]["is_init"] = True
        else:
            k = self.crossattn_cache[block_idx]["k"]
            v = self.crossattn_cache[block_idx]["v"]
        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
        attn_out = attention(
            attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
        )
        # TODO: Implement I2V inference for causal model
        if self.task == "i2v":
            raise NotImplementedError("I2V inference for causal model is not implemented")
        attn_out = weights.cross_attn_o.apply(attn_out)
        x = x + attn_out
        norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
        y = torch.nn.functional.gelu(y, approximate="tanh")
        y = weights.ffn_2.apply(y)
        x = x + y * embed0[5].squeeze(0)
        return x
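The self-attention cache above spans the full fragment (`kv_size = num_frames * frame_seq_length` tokens per transformer block); each call writes its new K/V into the `[kv_start:kv_end)` slice and lets queries attend to everything up to `kv_end`. A minimal sketch of how that window advances, using the values from `wan_t2v_causal.json` (illustrative only, not a drop-in for the class above):

```python
# Hedged sketch of the sliding KV window used by WanTransformerInferCausal.
frame_seq_length, num_frame_per_block, num_frames = 1560, 3, 21

kv_size = num_frames * frame_seq_length        # 32760 cached tokens per transformer block
step = num_frame_per_block * frame_seq_length  # 4680 new tokens per autoregressive block

kv_start, kv_end = 0, step
for block_idx in range(num_frames // num_frame_per_block):
    # K/V for the current 3-frame block are written to cache[kv_start:kv_end];
    # its queries attend to cache[:kv_end], i.e. every frame generated so far.
    print(f"block {block_idx}: cache[{kv_start}:{kv_end}] of {kv_size}")
    kv_start += step
    kv_end += step
```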
...@@ -20,6 +20,22 @@ def compute_freqs(c, grid_sizes, freqs):
    return freqs_i


def compute_freqs_causal(c, grid_sizes, freqs, start_frame=0):
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    return freqs_i


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
......
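`compute_freqs_causal` differs from `compute_freqs` only in indexing the temporal RoPE table from `start_frame` instead of 0, so frames written later in the KV cache keep globally consistent positions. In `infer_block` above, `start_frame` is derived as `kv_start // (h * w)`; a tiny worked example with the geometry implied by the shipped config (values are illustrative):

```python
# Hedged sketch: deriving start_frame the way infer_block does, for a 480x832 run.
f, h, w = 3, 30, 52                  # grid_sizes[0] of one 3-frame block
frame_seq_length = h * w             # 1560, matching the config
kv_start = 2 * f * frame_seq_length  # cache offset when generating the third block
start_frame = kv_start // (h * w)    # 6 -> temporal RoPE rows 6, 7, 8 are selected
assert start_frame == 6
```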
...@@ -61,7 +61,7 @@ class DefaultRunner:
    @ProfilingContext("Save video")
    def save_video(self, images):
        if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
-            if self.config.model_cls == "wan2.1":
+            if self.config.model_cls in ["wan2.1", "wan2.1_causal"]:
                cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
            else:
                save_videos_grid(images, self.config.save_video_path, fps=24)
......
import os
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.schedulers.wan.causal.scheduler import WanCausalScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.causal_model import WanCausalModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
import torch.distributed as dist


@RUNNER_REGISTER("wan2.1_causal")
class WanCausalRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
        self.denoising_step_list = self.model.config.denoising_step_list
        self.num_frame_per_block = self.model.config.num_frame_per_block
        self.num_frames = self.model.config.num_frames
        self.frame_seq_length = self.model.config.frame_seq_length
        self.infer_blocks = self.model.config.num_blocks
        self.num_fragments = self.model.config.num_fragments

    @ProfilingContext("Load models")
    def load_model(self):
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        image_encoder = None
        if self.config.cpu_offload:
            init_device = torch.device("cpu")
        else:
            init_device = torch.device("cuda")
        text_encoder = T5EncoderModel(
            text_len=self.config["text_len"],
            dtype=torch.bfloat16,
            device=init_device,
            checkpoint_path=os.path.join(self.config.model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(self.config.model_path, "google/umt5-xxl"),
            shard_fn=None,
        )
        text_encoders = [text_encoder]
        model = WanCausalModel(self.config.model_path, self.config, init_device)
        if self.config.lora_path:
            lora_wrapper = WanLoraWrapper(model)
            lora_name = lora_wrapper.load_lora(self.config.lora_path)
            lora_wrapper.apply_lora(lora_name, self.config.strength_model)
            print(f"Loaded LoRA: {lora_name}")
        vae_model = WanVAE(vae_pth=os.path.join(self.config.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=self.config.parallel_vae)
        if self.config.task == "i2v":
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=init_device,
                checkpoint_path=os.path.join(self.config.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                tokenizer_path=os.path.join(self.config.model_path, "xlm-roberta-large"),
            )
        return model, text_encoders, vae_model, image_encoder

    def init_scheduler(self):
        scheduler = WanCausalScheduler(self.config)
        self.model.set_scheduler(scheduler)

    def set_target_shape(self):
        if self.config.task == "i2v":
            self.config.target_shape = (16, 3, self.config.lat_h, self.config.lat_w)
        elif self.config.task == "t2v":
            self.config.target_shape = (
                16,
                self.config.num_frame_per_block,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def run(self):
        self.model.transformer_infer._init_kv_cache(dtype=torch.bfloat16, device="cuda")
        self.model.transformer_infer._init_crossattn_cache(dtype=torch.bfloat16, device="cuda")
        output_latents = torch.zeros(
            (self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
            device="cuda",
            dtype=torch.bfloat16,
        )
        start_block_idx = 0
        for fragment_idx in range(self.num_fragments):
            print(f"=======> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
            kv_start = 0
            kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
            if fragment_idx > 0:
                print("recompute the kv_cache ...")
                with ProfilingContext4Debug("step_pre"):
                    self.model.scheduler.latents = self.model.scheduler.last_sample
                    self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
                with ProfilingContext4Debug("infer"):
                    self.model.infer(self.inputs, kv_start, kv_end)
                kv_start += self.num_frame_per_block * self.frame_seq_length
                kv_end += self.num_frame_per_block * self.frame_seq_length
            infer_blocks = self.infer_blocks - (fragment_idx > 0)
            for block_idx in range(infer_blocks):
                print(f"=======> block_idx: {block_idx + 1} / {infer_blocks}")
                print(f"=======> kv_start: {kv_start}, kv_end: {kv_end}")
                self.model.scheduler.reset()
                for step_index in range(self.model.scheduler.infer_steps):
                    print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
                    with ProfilingContext4Debug("step_pre"):
                        self.model.scheduler.step_pre(step_index=step_index)
                    with ProfilingContext4Debug("infer"):
                        self.model.infer(self.inputs, kv_start, kv_end)
                    with ProfilingContext4Debug("step_post"):
                        self.model.scheduler.step_post()
                kv_start += self.num_frame_per_block * self.frame_seq_length
                kv_end += self.num_frame_per_block * self.frame_seq_length
                output_latents[:, start_block_idx * self.num_frame_per_block : (start_block_idx + 1) * self.num_frame_per_block] = self.model.scheduler.latents
                start_block_idx += 1
        return output_latents, self.model.scheduler.generator
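Each extra fragment in `run()` re-seeds its first block from `last_sample` and then generates `num_blocks - 1` new blocks, so the returned latent tensor has `num_frames + (num_fragments - 1) * (num_frames - num_frame_per_block)` frames along time. A worked check with the shipped config (the decoded-frame count assumes the usual VAE temporal stride of 4, which this diff does not state):

```python
# Hedged sketch: output length of WanCausalRunner.run() with wan_t2v_causal.json values.
num_frames, num_frame_per_block, num_blocks, num_fragments = 21, 3, 7, 3

latent_frames = num_frames + (num_fragments - 1) * (num_frames - num_frame_per_block)  # 21 + 2*18 = 57
blocks_generated = num_blocks + (num_fragments - 1) * (num_blocks - 1)                 # 7 + 2*6 = 19
assert latent_frames == blocks_generated * num_frame_per_block

decoded_frames = (latent_frames - 1) * 4 + 1  # 225 video frames, assuming VAE temporal stride 4
```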
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.wan.scheduler import WanScheduler


class WanCausalScheduler(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
        self.denoising_step_list = config.denoising_step_list

    def prepare(self, image_encoder_output):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.set_denoising_timesteps(device=self.device)

    def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
        self.timesteps = torch.tensor(self.denoising_step_list, device=device, dtype=torch.int64)
        self.sigmas = torch.cat([self.timesteps / self.num_train_timesteps, torch.tensor([0.0], device=device)])
        self.sigmas = self.sigmas.to("cpu")
        self.infer_steps = len(self.timesteps)
        self.model_outputs = [
            None,
        ] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None

    def reset(self):
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
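`set_denoising_timesteps` replaces the flow-matching schedule built in `prepare` with the fixed `denoising_step_list`, so `infer_steps` equals the list length (matching `"infer_steps": 9` in the config) and the sigmas are simply the listed timesteps divided by `num_train_timesteps`, with a trailing zero. A quick sketch, assuming `num_train_timesteps = 1000` as in the base Wan2.1 scheduler (not stated in this diff):

```python
import torch

# Hedged sketch of the schedule produced by set_denoising_timesteps.
denoising_step_list = [999, 934, 862, 756, 603, 410, 250, 140, 74]
num_train_timesteps = 1000  # assumption: value inherited from the base Wan2.1 scheduler

timesteps = torch.tensor(denoising_step_list, dtype=torch.int64)
sigmas = torch.cat([timesteps / num_train_timesteps, torch.tensor([0.0])])
# sigmas: 0.999, 0.934, ..., 0.074, 0.000 -> 9 denoising steps per autoregressive block
assert len(timesteps) == 9
```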
#!/bin/bash
# set the paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Change it in this script or set the env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v \
--model_cls wan2.1_causal \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_t2v_causal.json \
--prompt "Two anthropomorphic cats fight intensely on a spotlighted stage; the left cat wearing blue boxing gear with matching gloves, the right cat in bright red boxing attire and gloves." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ./output_lightx2v_wan_t2v_causal.mp4