Unverified commit 682037cd, authored by gushiqiao and committed by GitHub

[Feat] Add wan2.2 animate model (#339)

parent e251e4dc
{
"infer_steps": 20,
"target_video_length": 77,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"adapter_attn_type": "flash_attn3",
"seed": 42,
"sample_shift": 5.0,
"sample_guide_scale": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"src_pose_path": "/path/to/animate/process_results/src_pose.mp4",
"src_face_path": "/path/to/animate/process_results/src_face.mp4",
"src_ref_images": "/path/to/animate/process_results/src_ref.png",
"refert_num": 1,
"replace_flag": false,
"fps": 30
}
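The config above drives the standard animation mode ("replace_flag": false), where the source pose and face videos are retargeted onto the reference image. The variant below is its replacement-mode counterpart: it sets "replace_flag" to true and additionally points at the background and mask videos produced by the preprocessing step.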
{
"infer_steps": 20,
"target_video_length": 77,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"adapter_attn_type": "flash_attn3",
"seed": 42,
"sample_shift": 5.0,
"sample_guide_scale": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"src_pose_path": "/path/to/replace/process_results/src_pose.mp4",
"src_face_path": "/path/to/replace/process_results/src_face.mp4",
"src_ref_images": "/path/to/replace/process_results/src_ref.png",
"src_bg_path": "/path/to/replace/process_results/src_bg.mp4",
"src_mask_path": "/path/to/replace/process_results/src_mask.mp4",
"refert_num": 1,
"fps": 30,
"replace_flag": true
}
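A minimal sanity-check sketch for these two JSON variants, written only against the key names visible above (the helper and file names are hypothetical, not part of LightX2V):

import json

# Keys shared by both modes, plus the extras replacement mode needs (taken from the configs above).
COMMON_KEYS = {"src_pose_path", "src_face_path", "src_ref_images", "refert_num", "fps", "replace_flag"}
REPLACE_ONLY_KEYS = {"src_bg_path", "src_mask_path"}


def check_animate_config(path):
    """Hypothetical helper: verify an animate/replace config declares the inputs it needs."""
    with open(path) as f:
        cfg = json.load(f)
    missing = COMMON_KEYS - cfg.keys()
    if cfg.get("replace_flag", False):
        missing |= REPLACE_ONLY_KEYS - cfg.keys()
    if missing:
        raise ValueError(f"{path} is missing keys: {sorted(missing)}")
    return cfg


# check_animate_config("wan_animate.json")   # animation mode
# check_animate_config("wan_replace.json")   # replacement mode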
@@ -3,6 +3,9 @@ class WeightModule:
         self._modules = {}
         self._parameters = {}
 
+    def is_empty(self):
+        return len(self._modules) == 0 and len(self._parameters) == 0
+
     def add_module(self, name, module):
         self._modules[name] = module
         setattr(self, name, module)
...
@@ -62,13 +62,24 @@ class FlashAttn3Weight(AttnWeightTemplate):
         max_seqlen_kv=None,
         model_cls=None,
     ):
-        x = flash_attn_varlen_func_v3(
-            q,
-            k,
-            v,
-            cu_seqlens_q,
-            cu_seqlens_kv,
-            max_seqlen_q,
-            max_seqlen_kv,
-        ).reshape(max_seqlen_q, -1)
+        if len(q.shape) == 3:
+            x = flash_attn_varlen_func_v3(
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_kv,
+                max_seqlen_q,
+                max_seqlen_kv,
+            ).reshape(max_seqlen_q, -1)
+        elif len(q.shape) == 4:
+            x = flash_attn_varlen_func_v3(
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_kv,
+                max_seqlen_q,
+                max_seqlen_kv,
+            ).reshape(q.shape[0] * max_seqlen_q, -1)
         return x
@@ -51,14 +51,23 @@ class SageAttn2Weight(AttnWeightTemplate):
             )
             x = torch.cat((x1, x2), dim=1)
             x = x.view(max_seqlen_q, -1)
-        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_moe_distill", "qwen_image"]:
-            x = sageattn(
-                q.unsqueeze(0),
-                k.unsqueeze(0),
-                v.unsqueeze(0),
-                tensor_layout="NHD",
-            )
-            x = x.view(max_seqlen_q, -1)
+        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_animate", "wan2.2_moe_distill", "qwen_image"]:
+            if len(q.shape) == 3:
+                x = sageattn(
+                    q.unsqueeze(0),
+                    k.unsqueeze(0),
+                    v.unsqueeze(0),
+                    tensor_layout="NHD",
+                )
+                x = x.view(max_seqlen_q, -1)
+            elif len(q.shape) == 4:
+                x = sageattn(
+                    q,
+                    k,
+                    v,
+                    tensor_layout="NHD",
+                )
+                x = x.view(q.shape[0] * max_seqlen_q, -1)
         else:
             raise NotImplementedError(f"Model class '{model_cls}' is not implemented in this attention implementation")
         return x
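Both attention hunks above add the same dispatch: 3-D q/k/v ([tokens, heads, dim]) keeps the original packed-sequence path, while 4-D q/k/v ([batch, seq, heads, dim], which is what the face-adapter attention passes in) is flattened to [batch * max_seqlen_q, -1] on the way out. A plain-torch sketch of just that shape handling, with scaled_dot_product_attention standing in for the flash/sage kernels:

import torch
import torch.nn.functional as F


def toy_attention(q, k, v, max_seqlen_q):
    # Reproduces only the output-reshape logic from the hunks above, not the fused kernels.
    if q.dim() == 3:        # [tokens, heads, dim]: one packed sequence
        out = F.scaled_dot_product_attention(
            q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
        ).transpose(0, 1)
        return out.reshape(max_seqlen_q, -1)
    elif q.dim() == 4:      # [batch, seq, heads, dim]: flatten batch and seq afterwards
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
        return out.reshape(q.shape[0] * max_seqlen_q, -1)
    raise ValueError("expected 3-D or 4-D q")


q3 = torch.randn(16, 4, 32)
print(toy_attention(q3, q3, q3, max_seqlen_q=16).shape)   # torch.Size([16, 128])
q4 = torch.randn(2, 16, 4, 32)
print(toy_attention(q4, q4, q4, max_seqlen_q=16).shape)   # torch.Size([32, 128])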
@@ -8,6 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
@@ -50,11 +51,12 @@ def main():
             "wan2.2_audio",
             "wan2.2_moe_distill",
             "qwen_image",
+            "wan2.2_animate",
         ],
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--use_prompt_enhancer", action="store_true")
...
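For reference, an invocation along these lines should exercise the new choices. Only --task, --model_path and --config_json are visible in the hunk above; the script path and the model-class flag name below are assumptions, so adjust them to the repo's actual entry point:

import subprocess

subprocess.run(
    [
        "python", "lightx2v/infer.py",       # assumed entry point, not confirmed by this diff
        "--model_cls", "wan2.2_animate",     # flag name assumed from the choices list above
        "--task", "animate",
        "--model_path", "/path/to/Wan2.2-Animate",
        "--config_json", "/path/to/animate_config.json",
    ],
    check=True,
)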
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
try:
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func # noqa: F401
except ImportError:
flash_attn_func = None
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
max_seqlen_q=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
max_seqlen_q (int): The maximum sequence length in the batch of q; in 'flash' mode it is used
together with batch_size to restore the [b, s, a, d] layout of the output. (default: None)
batch_size (int): Batch size used for that reshape in 'flash' mode. (default: 1)
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
elif mode == "flash":
x = flash_attn_func(
q,
k,
v,
)
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class FaceEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
self.out_proj = nn.Linear(1024, hidden_dim)
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
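A quick shape check for the face encoder above, with made-up sizes (the shipped checkpoint fixes the real in_dim, hidden_dim and num_heads), assuming the classes in this file are importable:

import torch

enc = FaceEncoder(in_dim=512, hidden_dim=2048, num_heads=4)   # illustrative sizes only
motion_feats = torch.randn(1, 77, 512)                        # [batch, frames, in_dim]
out = enc(motion_feats)
# Two stride-2 causal convs shrink 77 frames to 20 steps, and one padding token is appended
# per step, so the result is [batch, t', num_heads + 1, hidden_dim].
print(out.shape)   # torch.Size([1, 20, 5, 2048])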
# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
def custom_qr(input_tensor):
original_dtype = input_tensor.dtype
if original_dtype == torch.bfloat16:
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
return q.to(original_dtype), r.to(original_dtype)
return torch.linalg.qr(input_tensor)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[
:,
:,
max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
]
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
return out
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor**2)
self.register_buffer("kernel", kernel)
self.pad = pad
def forward(self, input):
return upfirdn2d(input, self.kernel, pad=self.pad)
class ScaledLeakyReLU(nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
return F.leaky_relu(input, negative_slope=self.negative_slope)
class EqualConv2d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
def __repr__(self):
return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}, {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate))
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channel))
else:
layers.append(ScaledLeakyReLU(0.2))
super().__init__(*layers)
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / math.sqrt(2)
return out
class EncoderApp(nn.Module):
def __init__(self, size, w_dim=512):
super(EncoderApp, self).__init__()
channels = {4: 512, 8: 512, 16: 512, 32: 512, 64: 256, 128: 128, 256: 64, 512: 32, 1024: 16}
self.w_dim = w_dim
log_size = int(math.log(size, 2))
self.convs = nn.ModuleList()
self.convs.append(ConvLayer(3, channels[size], 1))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
self.convs.append(ResBlock(in_channel, out_channel))
in_channel = out_channel
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
def forward(self, x):
res = []
h = x
for conv in self.convs:
h = conv(h)
res.append(h)
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
class Encoder(nn.Module):
def __init__(self, size, dim=512, dim_motion=20):
super(Encoder, self).__init__()
# appearance network
self.net_app = EncoderApp(size, dim)
# motion network
fc = [EqualLinear(dim, dim)]
for i in range(3):
fc.append(EqualLinear(dim, dim))
fc.append(EqualLinear(dim, dim_motion))
self.fc = nn.Sequential(*fc)
def enc_app(self, x):
h_source = self.net_app(x)
return h_source
def enc_motion(self, x):
h, _ = self.net_app(x)
h_motion = self.fc(h)
return h_motion
class Direction(nn.Module):
def __init__(self, motion_dim):
super(Direction, self).__init__()
self.weight = nn.Parameter(torch.randn(512, motion_dim))
def forward(self, input):
weight = self.weight + 1e-8
Q, R = custom_qr(weight)
if input is None:
return Q
else:
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
out = torch.matmul(input_diag, Q.T)
out = torch.sum(out, dim=1)
return out
class Synthesis(nn.Module):
def __init__(self, motion_dim):
super(Synthesis, self).__init__()
self.direction = Direction(motion_dim)
class Generator(nn.Module):
def __init__(self, size, style_dim=512, motion_dim=20):
super().__init__()
self.enc = Encoder(size, style_dim, motion_dim)
self.dec = Synthesis(motion_dim)
def get_motion(self, img):
# motion_feat = self.enc.enc_motion(img)
motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
with torch.amp.autocast("cuda", dtype=torch.float32):
motion = self.dec.direction(motion_feat)
return motion
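And a matching smoke test for the motion encoder above; size must be one of the resolutions in the EncoderApp channel table (4 through 1024), and the dimensions here are illustrative rather than the shipped ones. get_motion() wraps the same two calls in activation checkpointing and a CUDA autocast context:

import torch

gen = Generator(size=256, style_dim=512, motion_dim=20)
face_frames = torch.randn(2, 3, 256, 256)               # a small batch of face crops
with torch.no_grad():
    motion_code = gen.enc.enc_motion(face_frames)       # [2, 20] low-dimensional motion code
    motion = gen.dec.direction(motion_code)             # [2, 512] motion vector, as in get_motion()
print(motion.shape)   # torch.Size([2, 512])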
from lightx2v.models.networks.wan.infer.animate.pre_infer import WanAnimatePreInfer
from lightx2v.models.networks.wan.infer.animate.transformer_infer import WanAnimateTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.animate.transformer_weights import WanAnimateTransformerWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
class WanAnimateModel(WanModel):
pre_weight_class = WanPreWeights
transformer_weight_class = WanAnimateTransformerWeights
def __init__(self, model_path, config, device):
self.remove_keys = ["face_encoder", "motion_encoder"]
super().__init__(model_path, config, device)
def _init_infer_class(self):
super()._init_infer_class()
self.pre_infer_class = WanAnimatePreInfer
self.transformer_infer_class = WanAnimateTransformerInfer
def set_animate_encoders(self, motion_encoder, face_encoder):
self.pre_infer.set_animate_encoders(motion_encoder, face_encoder)
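remove_keys filters the face_encoder and motion_encoder tensors out of the DiT checkpoint load (see the _load_safetensor_to_dict and _load_quant_ckpt changes further down); those two sub-networks are evidently built separately and attached afterwards via set_animate_encoders(), presumably by the WanAnimateRunner registered in the entry script.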
import math
import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
class WanAnimatePreInfer(WanPreInfer):
def __init__(self, config):
super().__init__(config)
self.encode_bs = 8
def set_animate_encoders(self, motion_encoder, face_encoder):
self.motion_encoder = motion_encoder
self.face_encoder = face_encoder
@torch.no_grad()
def after_patch_embedding(self, weights, x, pose_latents, face_pixel_values):
pose_latents = weights.pose_patch_embedding.apply(pose_latents)
x[:, :, 1:].add_(pose_latents)
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0] / self.encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * self.encode_bs : (i + 1) * self.encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = self.face_encoder(motion_vec.unsqueeze(0).to(GET_DTYPE())).squeeze(0)
pad_face = torch.zeros(1, motion_vec.shape[1], motion_vec.shape[2], dtype=motion_vec.dtype, device="cuda")
motion_vec = torch.cat([pad_face, motion_vec], dim=0)
return x, motion_vec
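after_patch_embedding adds the pose latents onto the latent frames after the first and runs the motion encoder over the face frames in chunks of encode_bs to bound peak memory. The chunking pattern in isolation (the dummy encoder below is just a placeholder for motion_encoder.get_motion):

import math
import torch


def encode_in_chunks(frames, encoder, chunk=8):
    # Same loop shape as after_patch_embedding: slice, encode, concatenate.
    outs = []
    for i in range(math.ceil(frames.shape[0] / chunk)):
        outs.append(encoder(frames[i * chunk : (i + 1) * chunk]))
    return torch.cat(outs)


dummy_encoder = lambda x: x.mean(dim=(1, 2, 3)).unsqueeze(-1)
print(encode_in_chunks(torch.randn(20, 3, 64, 64), dummy_encoder).shape)   # torch.Size([20, 1])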
import torch
from einops import rearrange
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.has_post_adapter = True
self.phases_num = 4
@torch.no_grad()
def infer_post_adapter(self, phase, x, pre_infer_out):
if phase.is_empty():
return x
T = pre_infer_out.motion_vec.shape[0]
x_motion = phase.pre_norm_motion.apply(pre_infer_out.motion_vec)
x_feat = phase.pre_norm_feat.apply(x)
kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
kv = kv.view(T, -1, kv.shape[-1])
q = phase.linear1_q.apply(x_feat)
k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config.num_heads)
q = rearrange(q, "S (H D) -> S H D", H=self.config.num_heads)
q = phase.q_norm.apply(q).view(T, q.shape[0] // T, q.shape[1], q.shape[2])
k = phase.k_norm.apply(k)
attn = phase.adapter_attn.apply(
q=q,
k=k,
v=v,
max_seqlen_q=q.shape[1],
model_cls=self.config["model_cls"],
)
output = phase.linear2.apply(attn)
x = x.add_(output)
return x
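infer_post_adapter is the per-block face adapter: the video latent tokens, grouped into T temporal chunks, cross-attend to that chunk's motion tokens, and the result is added back as a residual. A plain-torch sketch with invented shapes (the real projections live in WanAnimateFuserBlock below; the q/k RMS norms are omitted):

import torch
import torch.nn.functional as F

T, per_t, heads, dim, n_motion = 5, 16, 4, 32, 5   # illustrative sizes
C = heads * dim
x = torch.randn(T * per_t, C)                      # flattened video latent tokens
motion = torch.randn(T, n_motion, C)               # motion tokens per temporal chunk

linear_kv = torch.nn.Linear(C, 2 * C)              # stands in for phase.linear1_kv
linear_q = torch.nn.Linear(C, C)                   # stands in for phase.linear1_q
linear_out = torch.nn.Linear(C, C)                 # stands in for phase.linear2

k, v = linear_kv(motion).chunk(2, dim=-1)
k = k.view(T, n_motion, heads, dim)
v = v.view(T, n_motion, heads, dim)
q = linear_q(x).view(T, per_t, heads, dim)

attn = F.scaled_dot_product_attention(             # SDPA wants [batch, heads, seq, dim]
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2).reshape(T * per_t, C)

x = x + linear_out(attn)                           # residual add, as in infer_post_adapter
print(x.shape)                                     # torch.Size([80, 128])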
@@ -19,4 +19,5 @@ class WanPreInferModuleOutput:
     seq_lens: torch.Tensor
     freqs: torch.Tensor
     context: torch.Tensor
+    motion_vec: torch.Tensor
     adapter_output: Dict[str, Any] = field(default_factory=dict)
@@ -41,7 +41,7 @@ class WanPreInfer:
         else:
             context = inputs["text_encoder_output"]["context_null"]
 
-        if self.task in ["i2v", "flf2v"]:
+        if self.task in ["i2v", "flf2v", "animate"]:
             if self.config.get("use_image_encoder", True):
                 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
@@ -61,6 +61,12 @@ class WanPreInfer:
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
+
+        if hasattr(self, "after_patch_embedding"):
+            x, motion_vec = self.after_patch_embedding(weights, x, inputs["image_encoder_output"]["pose_latents"], inputs["image_encoder_output"]["face_pixel_values"])
+        else:
+            motion_vec = None
+
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
@@ -94,7 +100,7 @@ class WanPreInfer:
             del out
             torch.cuda.empty_cache()
 
-        if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
             if self.task == "flf2v":
                 _, n, d = clip_fea.shape
                 clip_fea = clip_fea.view(2 * n, d)
@@ -125,4 +131,5 @@ class WanPreInfer:
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
+            motion_vec=motion_vec,
         )
@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         x.add_(y_out * gate_msa.squeeze())
         norm3_out = phase.norm3.apply(x)
 
-        if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:
@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.sensitive_layer_dtype != self.infer_dtype:
             context = context.to(self.infer_dtype)
-            if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+            if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(self.infer_dtype)
 
         n, d = self.num_heads, self.head_dim
@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )
 
-        if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
+        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
...
@@ -137,12 +137,19 @@ class WanModel(CompiledMethodsMixin):
         return False
 
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
+        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
         if self.device.type == "cuda" and dist.is_initialized():
             device = torch.device("cuda:{}".format(dist.get_rank()))
         else:
             device = self.device
         with safe_open(file_path, framework="pt", device=str(device)) as f:
-            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
+            return {
+                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
+                for key in f.keys()
+                if not any(remove_key in key for remove_key in remove_keys)
+            }
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
@@ -158,6 +165,7 @@ class WanModel(CompiledMethodsMixin):
         return weight_dict
 
     def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
+        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
         ckpt_path = self.dit_quantized_ckpt
         index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
         if not index_files:
@@ -175,6 +183,9 @@ class WanModel(CompiledMethodsMixin):
             with safe_open(safetensor_path, framework="pt") as f:
                 logger.info(f"Loading weights from {safetensor_path}")
                 for k in f.keys():
+                    if any(remove_key in k for remove_key in remove_keys):
+                        continue
                     if f.get_tensor(k).dtype in [
                         torch.float16,
                         torch.bfloat16,
...
import os
from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
class WanAnimateTransformerWeights(WanTransformerWeights):
def __init__(self, config):
super().__init__(config)
self.adapter_blocks_num = self.blocks_num // 5
for i in range(self.blocks_num):
if i % 5 == 0:
self.blocks[i].compute_phases.append(WanAnimateFuserBlock(self.config, i // 5, "face_adapter.fuser_blocks", self.mm_type))
else:
self.blocks[i].compute_phases.append(WeightModule())
class WanAnimateFuserBlock(WeightModule):
def __init__(self, config, block_index, block_prefix, mm_type):
super().__init__()
self.config = config
lazy_load = config.get("lazy_load", False)
if lazy_load:
lazy_load_path = os.path.join(config.dit_quantized_ckpt, f"{block_prefix[:-1]}_{block_index}.safetensors")
lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
lazy_load_file = None
self.add_module(
"linear1_kv",
MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear1_kv.weight", f"{block_prefix}.{block_index}.linear1_kv.bias", lazy_load, lazy_load_file),
)
self.add_module(
"linear1_q",
MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear1_q.weight", f"{block_prefix}.{block_index}.linear1_q.bias", lazy_load, lazy_load_file),
)
self.add_module(
"linear2",
MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear2.weight", f"{block_prefix}.{block_index}.linear2.bias", lazy_load, lazy_load_file),
)
self.add_module(
"q_norm",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"{block_prefix}.{block_index}.q_norm.weight",
lazy_load,
lazy_load_file,
),
)
self.add_module(
"k_norm",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"{block_prefix}.{block_index}.k_norm.weight",
lazy_load,
lazy_load_file,
),
)
self.add_module(
"pre_norm_feat",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module(
"pre_norm_motion",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module("adapter_attn", ATTN_WEIGHT_REGISTER[config["adapter_attn_type"]]())
@@ -40,7 +40,7 @@ class WanPreWeights(WeightModule):
             MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
         )
 
-        if config.task in ["i2v", "flf2v"] and config.get("use_image_encoder", True):
+        if config.task in ["i2v", "flf2v", "animate"] and config.get("use_image_encoder", True):
             self.add_module(
                 "proj_0",
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
@@ -73,3 +73,8 @@ class WanPreWeights(WeightModule):
                 "emb_pos",
                 TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
             )
+        if config.task == "animate":
+            self.add_module(
+                "pose_patch_embedding",
+                CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size),
+            )
@@ -285,7 +285,7 @@ class WanCrossAttention(WeightModule):
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
 
-        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
            self.add_module(
                "cross_attn_k_img",
                MM_WEIGHT_REGISTER[self.mm_type](
...
@@ -145,7 +145,7 @@ class BaseRunner(ABC):
     def run_segment(self, total_steps=None):
         pass
 
-    def end_run_segment(self):
+    def end_run_segment(self, segment_idx=None):
         pass
 
     def end_run(self):
...