Unverified Commit 87625ec2 authored by Yang Yong (雍洋), committed by GitHub

Update Sparse Attention (#454)

parent 8ac762da
......@@ -4,6 +4,10 @@
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "nbhd_attn",
"nbhd_attn_setting": {
"coefficient": [1.0, 0.5, 0.25, 0.25],
"min_width": 2.0
},
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 5,
......
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "nbhd_attn",
"nbhd_attn_setting": {
"coefficient": [1.0, 0.5, 0.25, 0.25],
"min_width": 2.0
},
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 5,
"sample_shift": 3,
"enable_cfg": true,
"cpu_offload": false
}
......@@ -3,7 +3,7 @@
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"self_attn_1_type": "radial_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 5,
......
{
"infer_steps": 4,
"target_fps": 16,
"video_duration": 360,
"audio_sr": 16000,
"target_video_length": 81,
"resize_mode": "adaptive",
"self_attn_1_type": "nbhd_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 1.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"use_31_block": false
}
......@@ -11,7 +11,7 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
def generate_nbhd_mask(a, block_num, attnmap_frame_num, device="cpu"):
def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5, 0.056], min_width=1.0, device="cpu"):
"""
a : number of blocks per frame
block_num : number of blocks per row/column of the attention map
......@@ -20,28 +20,28 @@ def generate_nbhd_mask(a, block_num, attnmap_frame_num, device="cpu"):
i_indices = torch.arange(block_num, device=device).unsqueeze(1) # [block_num, 1]
j_indices = torch.arange(block_num, device=device).unsqueeze(0) # [1, block_num]
# 1. attention sink frame: j <= a
assert len(coefficient) <= attnmap_frame_num, f"coefficient length {len(coefficient)} should be <= attnmap_frame_num {attnmap_frame_num}"
width_list = [max(min_width, coefficient[i] * a) for i in range(len(coefficient))] + [min_width] * (attnmap_frame_num - len(coefficient))
logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}")
# attention sink frame: j <= a
mask_sink = j_indices <= a
# 2. self-attention within the frame
n = i_indices // a
mask_self = (j_indices >= n * a) & (j_indices < (n + 1) * a)
mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
for interval in range(0, attnmap_frame_num):
n = i_indices // a
mask_sparse_base_1 = (j_indices >= (n + interval) * a) & (j_indices <= (n + interval + 1) * a)
n = j_indices // a
mask_sparse_base_2 = (i_indices >= (n + interval) * a) & (i_indices <= (n + interval + 1) * a)
# 3. cross-frame attention
mask_cross = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
for n in range(1, attnmap_frame_num):
if n == 1:
width = 1 / 2 * a
elif n >= 2:
width = 1 / 8 * a
width = width_list[interval]
mask_1 = (i_indices - j_indices + (n * a + width) >= 0) & (i_indices - j_indices + (n * a - width) < 0)
mask_2 = (i_indices - j_indices - (n * a - width) > 0) & (i_indices - j_indices - (n * a + width) <= 0)
mask_1 = mask_sparse_base_1 & (i_indices - j_indices + (interval * a + width) >= 0) & (i_indices - j_indices + (interval * a - width) <= 0)
mask_2 = mask_sparse_base_2 & (i_indices - j_indices - (interval * a - width) >= 0) & (i_indices - j_indices - (interval * a + width) <= 0)
mask_cross = mask_cross | mask_1 | mask_2
mask_sparse = mask_sparse | mask_1 | mask_2
# merge all masks
mask = mask_sink | mask_self | mask_cross
mask = mask_sink | mask_sparse
return mask
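For reference, a minimal sketch (not part of this change) of how a config's coefficient/min_width values turn into per-offset window widths; the value of a (blocks per frame) is an assumed example and depends on resolution, attnmap_frame_num, and block_size:
# Illustrative only; mirrors the width_list computation above with the 480p config values.
a = 12.1875                                # assumed: seqlen / attnmap_frame_num / block_size, e.g. 32760 / 21 / 128
attnmap_frame_num = 21                     # assumed latent frame count
coefficient = [1.0, 0.5, 0.25, 0.25]
min_width = 2.0
width_list = [max(min_width, c * a) for c in coefficient] + [min_width] * (attnmap_frame_num - len(coefficient))
# -> [12.1875, 6.09375, 3.046875, 3.046875, 2.0, ..., 2.0]: wide bands for nearby frame
#    offsets, shrinking to min_width blocks for distant frame pairs.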
......@@ -71,6 +71,8 @@ class NbhdAttnWeight(AttnWeightTemplate):
q_ranges = None
k_ranges = None
attn_type_map = None
coefficient = [1.0, 0.5, 0.056]
min_width = 1.0
def __init__(self):
self.config = {}
......@@ -80,8 +82,8 @@ class NbhdAttnWeight(AttnWeightTemplate):
if seqlen == cls.seqlen:
return
block_num = (seqlen + cls.block_size - 1) // cls.block_size
block_num_per_frame = (seqlen // cls.attnmap_frame_num + cls.block_size - 1) // cls.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, device="cpu")
block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
q_ranges = q_ranges.to(torch.int32).to("cuda")
......
import torch
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
from loguru import logger
try:
import flashinfer
from packaging import version
flashinfer_version = version.parse(flashinfer.__version__)
has_o_dtype = flashinfer_version >= version.parse("0.2.6.post1")
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
flashinfer = None
###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
magi_ffa_func = None
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
mask_map=None,
sparsity_type="radial",
block_size=128,
decay_factor=1,
model_cls="wan",
):
assert len(q.shape) == 3
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
x = radial_attn(
q,
k,
v,
mask_map=mask_map,
sparsity_type=sparsity_type,
block_size=block_size,
model_cls=model_cls[:3], # Use first 3 characters to match "wan", "wan2", etc.
decay_factor=decay_factor,
)
x = x.view(max_seqlen_q, -1)
return x
def radial_attn(
query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"
):
orig_seqlen, num_head, hidden_dim = query.shape
query = pad_qkv(query, block_size=block_size)
key = pad_qkv(key, block_size=block_size)
value = pad_qkv(value, block_size=block_size)
mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None
seqlen = query.shape[0]
workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
workspace_buffer,
backend="fa2",
)
indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query)
kwargs = dict(
indptr=indptr,
indices=indices,
M=seqlen,
N=seqlen,
R=block_size,
C=block_size,
num_qo_heads=num_head,
num_kv_heads=num_head,
head_dim=hidden_dim,
q_data_type=query.dtype,
kv_data_type=key.dtype,
use_fp16_qk_reduction=True,
)
if has_o_dtype:
kwargs["o_data_type"] = query.dtype
bsr_wrapper.plan(**kwargs)
o = bsr_wrapper.run(query, key, value)
return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
# `query` only supplies the device on which `indptr` is allocated
# indptr (torch.Tensor) - the block index pointer of the block-sparse matrix along the row dimension,
# shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
# The first element is always 0 and the last element equals the total number of non-zero blocks;
# each remaining element is the cumulative count of non-zero blocks up to that row.
# the mask passed in is already a block-sparse mask
indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
indptr[0] = 0
row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row]
indptr[1:] = torch.cumsum(row_counts, dim=0)
return indptr
def get_indices_from_mask(mask, query):
# indices (torch.Tensor) - the block indices of the block-sparse matrix along the column dimension,
# shape `(nnz,)`, where `nnz` is the number of non-zero blocks.
# Each element of `indices` must be less than `NB`, the number of blocks in the column dimension.
nonzero_indices = torch.nonzero(mask)
indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
return indices
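As a concrete illustration of the CSR-style layout described in the comments above (not part of the change):
# Illustrative only: a 3x3 block mask and the indptr/indices it yields.
mask = torch.tensor([[1, 0, 1],
                     [0, 1, 0],
                     [1, 1, 1]], dtype=torch.bool)
# row_counts = [2, 1, 3]  ->  indptr = [0, 2, 3, 6]   (cumulative non-zero blocks per row)
# indices = [0, 2, 1, 0, 1, 2]                        (column index of each non-zero block, row by row)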
from .template import AttnWeightTemplate
def shrinkMaskStrict(mask, block_size=128):
......@@ -136,39 +26,6 @@ def shrinkMaskStrict(mask, block_size=128):
return block_mask
def pad_qkv(input_tensor, block_size=128):
"""
Pad the input tensor so that its sequence length is a multiple of the block size.
input shape: (seqlen, num_heads, hidden_dim)
"""
seqlen, num_heads, hidden_dim = input_tensor.shape
# Calculate the necessary padding
padding_length = (block_size - (seqlen % block_size)) % block_size
# Create a padded tensor with zeros
padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
# Copy the original tensor into the padded tensor
padded_tensor[:seqlen, :, :] = input_tensor
return padded_tensor
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
assert sparse_type in ["radial"]
dist = abs(i - j)
......@@ -191,28 +48,45 @@ def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor
return threshold
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
def gen_log_mask_shrinked(device, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
"""
A more memory-friendly version: we generate the attention mask for each frame pair one at a time,
shrink it, and store it into the final result.
"""
final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
final_log_mask = torch.zeros(((s + block_size - 1) // block_size, (s + block_size - 1) // block_size), device=device, dtype=torch.bool)
token_per_frame = video_token_num // num_frame
video_text_border = video_token_num // block_size
col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
col_indices = torch.arange(0, token_per_frame, device=device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=device).view(-1, 1)
final_log_mask[video_text_border:] = True
final_log_mask[:, video_text_border:] = True
for i in range(num_frame):
for j in range(num_frame):
local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
if j == 0: # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
local_mask = torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
if j == 0 and model_type == "wan": # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
else:
window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
local_mask = torch.abs(col_indices - row_indices) <= window_width
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device)
local_mask = torch.logical_and(local_mask, split_mask)
remainder_row = (i * token_per_frame) % block_size
......@@ -220,7 +94,7 @@ def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128,
# get the padded size
all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=device, dtype=torch.bool)
padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
# shrink the mask
block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
......@@ -230,22 +104,82 @@ def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128,
block_row_end = block_row_start + block_mask.shape[0]
block_col_end = block_col_start + block_mask.shape[1]
final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
return final_log_mask
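A minimal usage sketch (values assumed, requires a CUDA device) mirroring how RadialAttnWeight.prepare_mask below drives this function:
# Illustrative only: 21 latent frames of 1,560 tokens each -> seqlen = 32760 (assumed).
mask = gen_log_mask_shrinked(
    device="cuda",
    s=32760,
    video_token_num=32760,
    num_frame=21,
    block_size=128,
    sparse_type="radial",
    decay_factor=0.2,
    model_type="wan",
)
# mask is a [ceil(s / 128), ceil(s / 128)] boolean block mask; True marks blocks that are computed.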
class MaskMap:
def __init__(self, video_token_num=79200, num_frame=22):
self.video_token_num = video_token_num
self.num_frame = num_frame
self.log_mask = None
def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
if self.log_mask is None:
self.log_mask = gen_log_mask_shrinked(
query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size
)
block_bound = self.video_token_num // block_size
log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
return log_mask
def generate_qk_ranges(mask, block_size, seqlen):
indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2]
i_indices = indices[:, 0] # [N]
j_indices = indices[:, 1] # [N]
q_start = i_indices * block_size # [N]
q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N]
k_start = j_indices * block_size # [N]
k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N]
q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2]
k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2]
return q_ranges, k_ranges
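A small worked example of the block-mask-to-ranges conversion (values assumed for illustration):
# Illustrative only: a 2x2 block mask with block_size=128 and seqlen=200.
mask = torch.tensor([[True, False],
                     [True, True]])
q_ranges, k_ranges = generate_qk_ranges(mask, 128, 200)
# q_ranges -> [[0, 128], [128, 200], [128, 200]]
# k_ranges -> [[0, 128], [0,  128], [128, 200]]
# one (query-block, key-block) span per non-zero entry, end offsets clamped to seqlen.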
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
q_ranges = None
k_ranges = None
attn_type_map = None
def __init__(self):
self.config = {}
@classmethod
def prepare_mask(cls, seqlen):
if seqlen == cls.seqlen:
return
mask = gen_log_mask_shrinked(
device="cuda", s=seqlen, video_token_num=seqlen, num_frame=cls.attnmap_frame_num, block_size=cls.block_size, sparse_type="radial", decay_factor=0.2, model_type="wan"
)
q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
q_ranges = q_ranges.to(torch.int32).to("cuda")
k_ranges = k_ranges.to(torch.int32).to("cuda")
cls.seqlen = seqlen
cls.q_ranges = q_ranges
cls.k_ranges = k_ranges
cls.attn_type_map = attn_type_map
logger.info(f"RadialAttnWeight Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
self.prepare_mask(seqlen=q.shape[0])
out = magi_ffa_func(
q,
k,
v,
q_ranges=self.q_ranges,
k_ranges=self.k_ranges,
attn_type_map=self.attn_type_map,
auto_range_merge=True,
)[0]
return out.reshape(out.shape[0], -1)
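A hedged usage sketch of the registered class (shapes and the frame count are assumed; needs CUDA and magi_attention installed):
# Illustrative only.
RadialAttnWeight.attnmap_frame_num = 21      # normally set from config["attnmap_frame_num"]
attn = RadialAttnWeight()
q = torch.randn(32760, 12, 128, device="cuda", dtype=torch.bfloat16)  # [seqlen, head_num, head_dim], assumed
k = torch.randn_like(q)
v = torch.randn_like(q)
out = attn.apply(q, k, v)                    # -> [seqlen, head_num * head_dim]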
......@@ -192,8 +192,15 @@ class WanSelfAttention(WeightModule):
context_length=self.config.get("svg_context_length", 0),
sparsity=self.config.get("svg_sparsity", 0.25),
)
if self.config["self_attn_1_type"] in ["svg_attn", "nbhd_attn"]:
if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn"]:
attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
# nbhd_attn setting
if self.config["self_attn_1_type"] == "nbhd_attn":
if "nbhd_attn_setting" in self.config:
if "coefficient" in self.config["nbhd_attn_setting"]:
attention_weights_cls.coefficient = self.config["nbhd_attn_setting"]["coefficient"]
if "min_width" in self.config["nbhd_attn_setting"]:
attention_weights_cls.min_width = self.config["nbhd_attn_setting"]["min_width"]
self.add_module("self_attn_1", attention_weights_cls())
if self.config["seq_parallel"]:
......
......@@ -72,6 +72,8 @@ def set_config(args):
config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1
config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
if config["model_cls"] == "seko_talk":
config["attnmap_frame_num"] += 1
return config
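A worked example of the frame-count arithmetic above, with assumed Wan defaults (vae_stride[0] = 4, patch_size[0] = 1):
# Illustrative only.
target_video_length = 81
attnmap_frame_num = ((target_video_length - 1) // 4 + 1) // 1   # = 21 latent frames
# seko_talk then uses attnmap_frame_num + 1 = 22, matching seko_talk_22_nbhd_attn.json.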
......
#!/bin/bash
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/SekoTalk-Distill
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls seko_talk \
--task s2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/seko_talk/seko_talk_22_nbhd_attn.json \
--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/attentions/wan_i2v_nbhd_480p.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_nbhd_attn_480p.mp4
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/attentions/wan_i2v_nbhd_720p.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_nbhd_attn_720p.mp4