"docs/api/index.mdx" did not exist on "f5e3939220e9cd3d7a636708bc9df031ebfd4854"
Unverified Commit 6a658f42 authored by Watebear's avatar Watebear Committed by GitHub
Browse files

[Feat] support self-forcing wan2.1 dmd (#342)

{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"sf_type": "dmd",
"local_attn_size": -1,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 1560,
"num_output_frames": 21,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 937.5000, 833.3333, 625.0000]
}
}
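A quick sanity check of how the sf_config values above fit together (a standalone sketch, not part of the repository; the 8x/4x VAE compression and the 2x2 patch embedding factor are assumptions based on the Wan 2.1 1.3B setup):

# Sketch only: relationships between the config values above.
target_height, target_width, target_video_length = 480, 832, 81
# Wan's VAE is assumed to compress 8x spatially and 4x temporally (first frame kept),
# and the DiT patch embedding adds a further 2x2 spatial reduction.
latent_frames = (target_video_length - 1) // 4 + 1               # 21 == num_output_frames
frame_seq_length = (target_height // 16) * (target_width // 16)  # 30 * 52 == 1560
num_segments = latent_frames // 3                                # 7 blocks of num_frame_per_block = 3
full_kv_cache = latent_frames * frame_seq_length                 # 32760, the default KV cache size
print(latent_frames, frame_seq_length, num_segments, full_kv_cache)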
......@@ -136,3 +136,15 @@ class RMSWeightFP32(RMSWeight):
hidden_states = hidden_states.to(input_dtype)
return hidden_states
@RMS_WEIGHT_REGISTER("self_forcing")
class RMSWeightSF(RMSWeight):
def __init__(self, weight_name, lazy_load=False, lazy_load_file=None, eps=1e-6):
super().__init__(weight_name, lazy_load, lazy_load_file, eps)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, x):
return self._norm(x.float()).type_as(x) * self.weight
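For reference, a minimal standalone sketch of what the "self_forcing" RMS weight computes (plain PyTorch, outside the registry): normalize in float32, cast back, scale by the weight; with a unit weight every row of the output has RMS ≈ 1.

import torch

# Standalone sketch of the RMSWeightSF math above (not the registry API).
def rms_apply(x, weight, eps=1e-6):
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return normed.type_as(x) * weight

x = torch.randn(4, 128, dtype=torch.bfloat16)
w = torch.ones(128, dtype=torch.bfloat16)
print(rms_apply(x, w).float().pow(2).mean(dim=-1).sqrt())  # each entry ≈ 1.0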
......@@ -13,6 +13,7 @@ from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAu
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner # noqa: F401
from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401
from lightx2v.utils.envs import *
......@@ -43,6 +44,7 @@ def main():
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"wan2.1_vace",
"wan2.1_sf",
"cogvideox",
"seko_talk",
"wan2.2_moe",
......@@ -58,6 +60,7 @@ def main():
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--sf_model_path", type=str, required=False)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
......
import torch
from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
class WanSFPreInfer(WanPreInfer):
def __init__(self, config):
super().__init__(config)
d = config["dim"] // config["num_heads"]
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).cuda()
def time_embedding(self, weights, embed):
embed = weights.time_embedding_0.apply(embed)
embed = torch.nn.functional.silu(embed)
embed = weights.time_embedding_2.apply(embed)
return embed
def time_projection(self, weights, embed):
embed0 = torch.nn.functional.silu(embed)
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
return embed0
@torch.no_grad()
def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents_input
t = self.scheduler.timestep_input
if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
# embeddings
x = weights.patch_embedding.apply(x.unsqueeze(0))
grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
x = x.flatten(2).transpose(1, 2).contiguous()
seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)
embed = self.time_embedding(weights, embed_tmp)
embed0 = self.time_projection(weights, embed)
# text embeddings
if self.sensitive_layer_dtype != self.infer_dtype: # False
out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
else:
out = weights.text_embedding_0.apply(context.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.clean_cuda_cache:
del out
torch.cuda.empty_cache()
grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context,
)
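The frequency table built in WanSFPreInfer and the split inside causal_rope_apply (next file) must agree on dimensions. A shape check under the head size implied by the KV cache below (12 heads, head_dim d = 128); illustration only:

import torch

# Shape sketch for the RoPE tables above, assuming head_dim d = 128.
def rope_params(max_seq_len, dim, theta=10000):
    freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

d = 128
freqs = torch.cat(
    [
        rope_params(1024, d - 4 * (d // 6)),  # temporal axis: 44 real dims -> 22 complex
        rope_params(1024, 2 * (d // 6)),      # height axis:   42 real dims -> 21 complex
        rope_params(1024, 2 * (d // 6)),      # width axis:    42 real dims -> 21 complex
    ],
    dim=1,
)
print(freqs.shape)  # torch.Size([1024, 64]) == (max positions, d // 2)

c = d // 2
print([c - 2 * (c // 3), c // 3, c // 3])  # [22, 21, 21], the split used in causal_rope_apply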
import torch
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
freqs_i = torch.cat(
[freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
# append to collection
output.append(x_i)
return torch.stack(output).type_as(x)
class WanSFTransformerInfer(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
if self.config.get("cpu_offload", False):
self.device = torch.device("cpu")
else:
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
sf_config = self.config.sf_config
self.local_attn_size = sf_config.local_attn_size
self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560
self.num_frame_per_block = sf_config.num_frame_per_block
self.num_transformer_blocks = sf_config.num_transformer_blocks
self.frame_seq_length = sf_config.frame_seq_length
self._initialize_kv_cache(self.dtype, self.device)
self._initialize_crossattn_cache(self.dtype, self.device)
self.infer_func = self.infer_with_kvcache
def _initialize_kv_cache(self, dtype, device):
"""
Initialize a Per-GPU KV cache for the Wan model.
"""
kv_cache1 = []
if self.local_attn_size != -1:
# Use the local attention size to compute the KV cache size
kv_cache_size = self.local_attn_size * self.frame_seq_length
else:
# Use the default KV cache size
kv_cache_size = 32760
for _ in range(self.num_transformer_blocks):
kv_cache1.append(
{
"k": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device),
"v": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device),
"global_end_index": torch.tensor([0], dtype=torch.long).to(device),
"local_end_index": torch.tensor([0], dtype=torch.long).to(device),
}
)
self.kv_cache1_default = kv_cache1 # always store the clean cache
def _initialize_crossattn_cache(self, dtype, device):
"""
Initialize a Per-GPU cross-attention cache for the Wan model.
"""
crossattn_cache = []
for _ in range(self.num_transformer_blocks):
crossattn_cache.append({"k": torch.zeros((512, 12, 128)).to(dtype).to(device), "v": torch.zeros((512, 12, 128)).to(dtype).to(device), "is_init": False})
self.crossattn_cache_default = crossattn_cache
def infer_with_kvcache(self, blocks, x, pre_infer_out):
self.kv_cache1 = self.kv_cache1_default
self.crossattn_cache = self.crossattn_cache_default
for block_idx in range(len(blocks)):
self.block_idx = block_idx
x = self.infer_block_with_kvcache(blocks[block_idx], x, pre_infer_out)
return x
def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(phase, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
else:
norm1_weight = 1 + scale_msa.squeeze()
norm1_bias = shift_msa.squeeze()
norm1_out = phase.norm1.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.sensitive_layer_dtype)
norm1_out.mul_(norm1_weight[0:1, :]).add_(norm1_bias[0:1, :])
if self.sensitive_layer_dtype != self.infer_dtype: # False
norm1_out = norm1_out.to(self.infer_dtype)
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
q0 = phase.self_attn_q.apply(norm1_out)
k0 = phase.self_attn_k.apply(norm1_out)
q = phase.self_attn_norm_q.apply(q0).view(s, n, d)
k = phase.self_attn_norm_k.apply(k0).view(s, n, d)
v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
seg_index = self.scheduler.seg_index
current_start_frame = seg_index * self.num_frame_per_block
q = causal_rope_apply(q.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0]
k = causal_rope_apply(k.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0]
# Assign new keys/values directly up to current_end
seg_seq_len = self.frame_seq_length * self.num_frame_per_block
local_start_index = seg_index * seg_seq_len
local_end_index = (seg_index + 1) * seg_seq_len
self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k
self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v
attn_k = self.kv_cache1[self.block_idx]["k"][max(0, local_end_index - self.max_attention_size) : local_end_index]
attn_v = self.kv_cache1[self.block_idx]["v"][max(0, local_end_index - self.max_attention_size) : local_end_index]
k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
if self.clean_cuda_cache:
del norm1_out, norm1_weight, norm1_bias
torch.cuda.empty_cache()
if self.config["seq_parallel"]:
attn_out = phase.self_attn_1_parallel.apply(
q=q,
k=attn_k,
v=attn_v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
attention_module=phase.self_attn_1,
seq_p_group=self.seq_p_group,
model_cls=self.config["model_cls"],
)
else:
attn_out = phase.self_attn_1.apply(
q=q,
k=attn_k,
v=attn_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=q.size(0),
max_seqlen_kv=attn_k.size(0),
model_cls=self.config["model_cls"],
)
y = phase.self_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, attn_out
torch.cuda.empty_cache()
return y
def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa):
num_frames = gate_msa.shape[0]
frame_seqlen = x.shape[0] // gate_msa.shape[0]
seg_index = self.scheduler.seg_index
x.add_((y_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) * gate_msa).flatten(0, 1))
norm3_out = phase.norm3.apply(x)
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
context_img = context[:257]
context = context[257:]
else:
context_img = None
if self.sensitive_layer_dtype != self.infer_dtype:
context = context.to(self.infer_dtype)
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
context_img = context_img.to(self.infer_dtype)
n, d = self.num_heads, self.head_dim
q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d)
if seg_index == 0:
k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d)
v = phase.cross_attn_v.apply(context).view(-1, n, d)
self.crossattn_cache[self.block_idx]["k"] = k
self.crossattn_cache[self.block_idx]["v"] = v
else:
k = self.crossattn_cache[self.block_idx]["k"]
v = self.crossattn_cache[self.block_idx]["v"]
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q,
k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
)
attn_out = phase.cross_attn_1.apply(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
)
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
img_attn_out = phase.cross_attn_2.apply(
q=q,
k=k_img,
v=v_img,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=q.size(0),
max_seqlen_kv=k_img.size(0),
model_cls=self.config["model_cls"],
)
attn_out.add_(img_attn_out)
if self.clean_cuda_cache:
del k_img, v_img, img_attn_out
torch.cuda.empty_cache()
attn_out = phase.cross_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, norm3_out, context, context_img
torch.cuda.empty_cache()
return x, attn_out
def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
x.add_(attn_out)
if self.clean_cuda_cache:
del attn_out
torch.cuda.empty_cache()
num_frames = c_shift_msa.shape[0]
frame_seqlen = x.shape[0] // c_shift_msa.shape[0]
norm2_weight = 1 + c_scale_msa
norm2_bias = c_shift_msa
norm2_out = phase.norm2.apply(x)
norm2_out = norm2_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen))
norm2_out.mul_(norm2_weight).add_(norm2_bias)
norm2_out = norm2_out.flatten(0, 1)
y = phase.ffn_0.apply(norm2_out)
if self.clean_cuda_cache:
del norm2_out, x, norm2_weight, norm2_bias
torch.cuda.empty_cache()
y = torch.nn.functional.gelu(y, approximate="tanh")
if self.clean_cuda_cache:
torch.cuda.empty_cache()
y = phase.ffn_2.apply(y)
return y
def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
num_frames = c_gate_msa.shape[0]
frame_seqlen = x.shape[0] // c_gate_msa.shape[0]
y = y.unflatten(dim=0, sizes=(num_frames, frame_seqlen))
x = x.unflatten(dim=0, sizes=(num_frames, frame_seqlen))
x.add_(y * c_gate_msa)
x = x.flatten(0, 1)
if self.clean_cuda_cache:
del y, c_gate_msa
torch.cuda.empty_cache()
return x
def infer_block_with_kvcache(self, block, x, pre_infer_out):
if hasattr(block.compute_phases[0], "before_proj"):
x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
block.compute_phases[0].modulation,
pre_infer_out.embed0,
)
y_out = self.infer_self_attn_with_kvcache(
block.compute_phases[0],
pre_infer_out.grid_sizes.tensor,
x,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
shift_msa,
scale_msa,
)
x, attn_out = self.infer_cross_attn_with_kvcache(
block.compute_phases[1],
x,
pre_infer_out.context,
y_out,
gate_msa,
)
y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
if hasattr(block.compute_phases[2], "after_proj"):
pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
if self.has_post_adapter:
x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out)
return x
def infer_non_blocks(self, weights, x, e):
num_frames = e.shape[0]
frame_seqlen = x.shape[0] // e.shape[0]
x = weights.norm.apply(x)
x = x.unflatten(dim=0, sizes=(num_frames, frame_seqlen))
t = self.scheduler.timestep_input
e = e.unflatten(dim=0, sizes=t.shape).unsqueeze(2)
modulation = weights.head_modulation.tensor
e = (modulation.unsqueeze(1) + e).chunk(2, dim=2)
x.mul_(1 + e[1][0]).add_(e[0][0])
x = x.flatten(0, 1)
x = weights.head.apply(x)
if self.clean_cuda_cache:
del e
torch.cuda.empty_cache()
return x
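The cache bookkeeping in infer_self_attn_with_kvcache writes each segment's keys/values at a fixed offset and then attends over at most max_attention_size trailing tokens. A sketch of that index arithmetic with the default config values (illustration only):

# Index arithmetic used by the rolling KV cache above (default config values).
frame_seq_length = 1560
num_frame_per_block = 3
num_latent_frames = 21
max_attention_size = num_latent_frames * frame_seq_length  # 32760 when local_attn_size == -1

seg_seq_len = frame_seq_length * num_frame_per_block  # 4680 tokens written per segment
for seg_index in range(num_latent_frames // num_frame_per_block):
    local_start = seg_index * seg_seq_len
    local_end = (seg_index + 1) * seg_seq_len
    attn_start = max(0, local_end - max_attention_size)  # 0 here, i.e. full causal history
    print(f"segment {seg_index}: write kv[{local_start}:{local_end}], attend kv[{attn_start}:{local_end}]")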
import os
import torch
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer
from lightx2v.models.networks.wan.infer.self_forcing.transformer_infer import WanSFTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
class WanSFModel(WanModel):
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
self.to_cuda()
def _load_ckpt(self, unified_dtype, sensitive_layer):
sf_config = self.config.sf_config
file_path = os.path.join(self.config.sf_model_path, f"checkpoints/self_forcing_{sf_config.sf_type}.pt")
_weight_dict = torch.load(file_path)["generator_ema"]
weight_dict = {}
for k, v in _weight_dict.items():
name = k[6:]
weight = v.to(torch.bfloat16)
weight_dict.update({name: weight})
del _weight_dict
return weight_dict
def _init_infer_class(self):
self.pre_infer_class = WanSFPreInfer
self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanSFTransformerInfer
@torch.no_grad()
def infer(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.transformer_weights.non_block_weights_to_cuda()
current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block
current_end_frame = (self.scheduler.seg_index + 1) * self.scheduler.num_frame_per_block
noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
self.scheduler.noise_pred[:, current_start_frame:current_end_frame] = noise_pred
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.transformer_weights.non_block_weights_to_cpu()
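_load_ckpt above keeps only the "generator_ema" weights and strips the first six characters of every key before casting to bfloat16. Assuming those keys carry a "model." prefix (an assumption about the Self-Forcing checkpoint layout, not verified here), the remap behaves like this:

import torch

# Hypothetical illustration of the key remap in _load_ckpt; the "model." prefix
# (6 characters) is assumed, not taken from an actual checkpoint.
_weight_dict = {"model.blocks.0.self_attn.q.weight": torch.randn(8, 8)}
weight_dict = {k[6:]: v.to(torch.bfloat16) for k, v in _weight_dict.items()}
print(list(weight_dict), weight_dict["blocks.0.self_attn.q.weight"].dtype)
# ['blocks.0.self_attn.q.weight'] torch.bfloat16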
......@@ -113,6 +113,11 @@ class WanSelfAttention(WeightModule):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
if self.config.get("sf_config", "False"):
self.attn_rms_type = "self_forcing"
else:
self.attn_rms_type = "sgl-kernel"
self.add_module(
"modulation",
TENSOR_REGISTER["Default"](
......@@ -136,6 +141,7 @@ class WanSelfAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module(
"self_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
......@@ -165,7 +171,7 @@ class WanSelfAttention(WeightModule):
)
self.add_module(
"self_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
......@@ -173,7 +179,7 @@ class WanSelfAttention(WeightModule):
)
self.add_module(
"self_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
......@@ -222,6 +228,11 @@ class WanCrossAttention(WeightModule):
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
if self.config.get("sf_config", "False"):
self.attn_rms_type = "self_forcing"
else:
self.attn_rms_type = "sgl-kernel"
self.add_module(
"norm3",
LN_WEIGHT_REGISTER["Default"](
......@@ -269,7 +280,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module(
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
......@@ -277,7 +288,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module(
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
......@@ -285,7 +296,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
......@@ -306,7 +317,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module(
"cross_attn_norm_k_img",
RMS_WEIGHT_REGISTER["sgl-kernel"](
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight",
self.lazy_load,
self.lazy_load_file,
......
import gc
import torch
from loguru import logger
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
torch.manual_seed(42)
@RUNNER_REGISTER("wan2.1_sf")
class WanSFRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
self.vae_cls = WanSFVAE
def load_transformer(self):
model = WanSFModel(
self.config.model_path,
self.config,
self.init_device,
)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
return model
def init_scheduler(self):
self.scheduler = WanSFScheduler(self.config)
def set_target_shape(self):
self.num_output_frames = 21
self.config.target_shape = [16, self.num_output_frames, 60, 104]
def get_video_segment_num(self):
self.video_segment_num = self.scheduler.num_blocks
@ProfilingContext4DebugL1("Run VAE Decoder")
def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents.to(GET_DTYPE()), use_cache=True)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
return images
@ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None):
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile()
total_blocks = self.scheduler.num_blocks
gen_videos = []
for seg_index in range(self.video_segment_num):
logger.info(f"==> segment_index: {seg_index + 1} / {total_blocks}")
total_steps = len(self.scheduler.denoising_step_list)
for step_index in range(total_steps):
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=False)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs)
with ProfilingContext4DebugL1("step_post"):
self.model.scheduler.step_post()
latents = self.model.scheduler.stream_output
gen_videos.append(self.run_vae_decoder(latents))
# rerun with timestep zero to update KV cache using clean context
with ProfilingContext4DebugL1("step_pre_in_rerun"):
self.model.scheduler.step_pre(seg_index=seg_index, step_index=step_index, is_rerun=True)
with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
self.model.infer(self.inputs)
self.gen_video = torch.cat(gen_videos, dim=0)
self.end_run()
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
class WanSFScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.num_frame_per_block = self.config.sf_config.num_frame_per_block
self.num_output_frames = self.config.sf_config.num_output_frames
self.num_blocks = self.num_output_frames // self.num_frame_per_block
self.denoising_step_list = self.config.sf_config.denoising_step_list
self.all_num_frames = [self.num_frame_per_block] * self.num_blocks
self.num_input_frames = 0
self.denoising_strength = 1.0
self.sigma_max = 1.0
self.sigma_min = 0
self.sf_shift = self.config.sf_config.shift
self.inverse_timesteps = False
self.extra_one_step = True
self.reverse_sigmas = False
self.num_inference_steps = self.config.sf_config.num_inference_steps
self.context_noise = 0
def prepare(self, image_encoder_output=None):
self.latents = torch.randn(self.config.target_shape, device=self.device, dtype=self.dtype)
timesteps = []
for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
frame_steps = []
for step_index, current_timestep in enumerate(self.denoising_step_list):
timestep = torch.ones([self.num_frame_per_block], device=self.device, dtype=torch.int64) * current_timestep
frame_steps.append(timestep)
timesteps.append(frame_steps)
self.timesteps = timesteps
self.noise_pred = torch.zeros(self.config.target_shape, device=self.device, dtype=self.dtype)
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
if self.extra_one_step:
self.sigmas_sf = torch.linspace(sigma_start, self.sigma_min, self.num_inference_steps + 1)[:-1]
else:
self.sigmas_sf = torch.linspace(sigma_start, self.sigma_min, self.num_inference_steps)
if self.inverse_timesteps:
self.sigmas_sf = torch.flip(self.sigmas_sf, dims=[0])
self.sigmas_sf = self.sf_shift * self.sigmas_sf / (1 + (self.sf_shift - 1) * self.sigmas_sf)
if self.reverse_sigmas:
self.sigmas_sf = 1 - self.sigmas_sf
self.sigmas_sf = self.sigmas_sf.to(self.device)
self.timesteps_sf = self.sigmas_sf * self.num_train_timesteps
self.timesteps_sf = self.timesteps_sf.to(self.device)
self.stream_output = None
def step_pre(self, seg_index, step_index, is_rerun=False):
self.step_index = step_index
self.seg_index = seg_index
if not GET_DTYPE() == GET_SENSITIVE_DTYPE():
self.latents = self.latents.to(GET_DTYPE())
if not is_rerun:
self.timestep_input = torch.stack([self.timesteps[self.seg_index][self.step_index]])
self.latents_input = self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)]
else:
# rerun with timestep zero to update KV cache using clean context
self.timestep_input = torch.ones_like(torch.stack([self.timesteps[self.seg_index][self.step_index]])) * self.context_noise
self.latents_input = self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)]
def step_post(self):
# convert model outputs
current_start_frame = self.seg_index * self.num_frame_per_block
current_end_frame = (self.seg_index + 1) * self.num_frame_per_block
flow_pred = self.noise_pred[:, current_start_frame:current_end_frame].transpose(0, 1)
xt = self.latents_input.transpose(0, 1)
timestep = self.timestep_input.squeeze(0)
original_dtype = flow_pred.dtype
flow_pred, xt, sigmas, timesteps = map(lambda x: x.double().to(flow_pred.device), [flow_pred, xt, self.sigmas_sf, self.timesteps_sf])
timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
x0_pred = xt - sigma_t * flow_pred
x0_pred = x0_pred.to(original_dtype)
# add noise
if self.step_index < len(self.denoising_step_list) - 1:
timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
noise_next = torch.randn_like(x0_pred)
sample_next = (1 - sigma_next) * x0_pred + sigma_next * noise_next
sample_next = sample_next.type_as(noise_next)
self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)] = sample_next.transpose(0, 1)
else:
self.latents[:, self.seg_index * self.num_frame_per_block : min((self.seg_index + 1) * self.num_frame_per_block, self.num_output_frames)] = x0_pred.transpose(0, 1)
self.stream_output = x0_pred.transpose(0, 1)
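step_post above follows the rectified-flow relation x0 = x_t - sigma_t * v, then re-noises the prediction at the next sigma. A self-contained numeric sketch with toy tensors (not the scheduler's state):

import torch

# Toy sketch of the step_post math: recover x0 from the flow prediction, then
# re-noise it at the next denoising level.
torch.manual_seed(0)
x0_true = torch.randn(3, 16, 8, 8)               # pretend clean latents, 3 frames
noise = torch.randn_like(x0_true)
sigma_t = torch.tensor(0.8).reshape(1, 1, 1, 1)

xt = (1 - sigma_t) * x0_true + sigma_t * noise   # noisy sample at sigma_t
flow_pred = noise - x0_true                      # ideal rectified-flow velocity

x0_pred = xt - sigma_t * flow_pred               # recovers x0_true up to numerics
print(torch.allclose(x0_pred, x0_true, atol=1e-5))

sigma_next = torch.tensor(0.4).reshape(1, 1, 1, 1)
sample_next = (1 - sigma_next) * x0_pred + sigma_next * torch.randn_like(x0_pred)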
......@@ -546,6 +546,7 @@ class WanVAE_(nn.Module):
self.temperal_upsample,
dropout,
)
self.clear_cache()
def forward(self, x):
mu, log_var = self.encode(x)
......@@ -739,6 +740,23 @@ class WanVAE_(nn.Module):
self.clear_cache()
return out
def cached_decode(self, z, scale):
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
......
import torch
from lightx2v.models.video_encoders.hf.wan.vae import _video_vae
class WanSFVAE:
def __init__(
self,
z_dim=16,
vae_pth="cache/vae_step_411000.pth",
dtype=torch.float,
device="cuda",
parallel=False,
use_tiling=False,
cpu_offload=False,
use_2d_split=True,
load_from_rank0=False,
):
self.dtype = dtype
self.device = device
self.parallel = parallel
self.use_tiling = use_tiling
self.cpu_offload = cpu_offload
self.use_2d_split = use_2d_split
mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]
self.mean = torch.tensor(mean, dtype=torch.float32)
self.std = torch.tensor(std, dtype=torch.float32)
self.inv_std = 1.0 / self.std
self.scale = [self.mean, self.inv_std]
# init model
self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
def to_cpu(self):
self.model.encoder = self.model.encoder.to("cpu")
self.model.decoder = self.model.decoder.to("cpu")
self.model = self.model.to("cpu")
self.mean = self.mean.cpu()
self.inv_std = self.inv_std.cpu()
self.scale = [self.mean, self.inv_std]
def to_cuda(self):
self.model.encoder = self.model.encoder.to("cuda")
self.model.decoder = self.model.decoder.to("cuda")
self.model = self.model.to("cuda")
self.mean = self.mean.cuda()
self.inv_std = self.inv_std.cuda()
self.scale = [self.mean, self.inv_std]
def decode(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
# from [batch_size, num_frames, num_channels, height, width]
# to [batch_size, num_channels, num_frames, height, width]
latent = latent.transpose(0, 1).unsqueeze(0)
zs = latent.permute(0, 2, 1, 3, 4)
if use_cache:
assert latent.shape[0] == 1, "Batch size must be 1 when using cache"
device, dtype = latent.device, latent.dtype
scale = [self.mean.to(device=device, dtype=dtype), 1.0 / self.std.to(device=device, dtype=dtype)]
if use_cache:
decode_function = self.model.cached_decode
else:
decode_function = self.model.decode
output = []
for u in zs:
output.append(decode_function(u.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0))
output = torch.stack(output, dim=0)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
output = output.permute(0, 2, 1, 3, 4).squeeze(0)
return output
#!/bin/bash
# set paths first
lightx2v_path=
model_path= # path to Wan2.1-T2V-1.3B
sf_model_path= # path to gdhe17/Self-Forcing
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_sf \
--task t2v \
--model_path $model_path \
--sf_model_path $sf_model_path \
--config_json ${lightx2v_path}/configs/self_forcing/wan_t2v_sf.json \
--prompt 'A stylish woman strolls down a bustling Tokyo street, the warm glow of neon lights and animated city signs casting vibrant reflections. She wears a sleek black leather jacket paired with a flowing red dress and black boots, her black purse slung over her shoulder. Sunglasses perched on her nose and a bold red lipstick add to her confident, casual demeanor. The street is damp and reflective, creating a mirror-like effect that enhances the colorful lights and shadows. Pedestrians move about, adding to the lively atmosphere. The scene is captured in a dynamic medium shot with the woman walking slightly to one side, highlighting her graceful strides.' \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4