Commit 1336a33d authored by zzg_666

wan2.2
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
target_modules = []
for name, module in transformer.named_modules():
if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
target_modules.append(name)
transformer_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
init_lora_weights=init_lora_weights,
target_modules=target_modules,
)
return transformer_lora_config
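# Illustrative usage (sketch, not part of the original training code): inject the
# returned LoraConfig with peft. A tiny stand-in module is used here; the intended
# target is the WanAnimateModel defined later in this commit (any module whose
# "blocks.*" Linear layers should receive LoRA adapters).
from peft import inject_adapter_in_model
_toy = torch.nn.ModuleDict({
    "blocks": torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(2)]),
    "head": torch.nn.Linear(64, 16),
})
_cfg = get_loraconfig(_toy, rank=8, alpha=8)      # target_modules -> ['blocks.0', 'blocks.1']
_toy = inject_adapter_in_model(_cfg, _toy)
assert any("lora_A" in n for n, _ in _toy.named_parameters())
del _toy, _cfg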
class TensorList(object):
def __init__(self, tensors):
"""
tensors: a list of torch.Tensor objects; they may have different shapes but must share the same ndim, dtype, and device.
"""
assert isinstance(tensors, (list, tuple))
assert all(isinstance(u, torch.Tensor) for u in tensors)
assert len(set([u.ndim for u in tensors])) == 1
assert len(set([u.dtype for u in tensors])) == 1
assert len(set([u.device for u in tensors])) == 1
self.tensors = tensors
def to(self, *args, **kwargs):
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
def size(self, dim):
assert dim == 0, 'only support get the 0th size'
return len(self.tensors)
def pow(self, *args, **kwargs):
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
def squeeze(self, dim):
assert dim != 0
if dim > 0:
dim -= 1
return TensorList([u.squeeze(dim) for u in self.tensors])
def type(self, *args, **kwargs):
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
def type_as(self, other):
assert isinstance(other, (torch.Tensor, TensorList))
if isinstance(other, torch.Tensor):
return TensorList([u.type_as(other) for u in self.tensors])
else:
return TensorList([u.type(other.dtype) for u in self.tensors])
@property
def dtype(self):
return self.tensors[0].dtype
@property
def device(self):
return self.tensors[0].device
@property
def ndim(self):
return 1 + self.tensors[0].ndim
def __getitem__(self, index):
return self.tensors[index]
def __len__(self):
return len(self.tensors)
def __add__(self, other):
return self._apply(other, lambda u, v: u + v)
def __radd__(self, other):
return self._apply(other, lambda u, v: v + u)
def __sub__(self, other):
return self._apply(other, lambda u, v: u - v)
def __rsub__(self, other):
return self._apply(other, lambda u, v: v - u)
def __mul__(self, other):
return self._apply(other, lambda u, v: u * v)
def __rmul__(self, other):
return self._apply(other, lambda u, v: v * u)
def __floordiv__(self, other):
return self._apply(other, lambda u, v: u // v)
def __truediv__(self, other):
return self._apply(other, lambda u, v: u / v)
def __rfloordiv__(self, other):
return self._apply(other, lambda u, v: v // u)
def __rtruediv__(self, other):
return self._apply(other, lambda u, v: v / u)
def __pow__(self, other):
return self._apply(other, lambda u, v: u ** v)
def __rpow__(self, other):
return self._apply(other, lambda u, v: v ** u)
def __neg__(self):
return TensorList([-u for u in self.tensors])
def __iter__(self):
for tensor in self.tensors:
yield tensor
def __repr__(self):
return 'TensorList: \n' + repr(self.tensors)
def _apply(self, other, op):
if isinstance(other, (list, tuple, TensorList)) or (
isinstance(other, torch.Tensor) and (
other.numel() > 1 or other.ndim > 1
)
):
assert len(other) == len(self.tensors)
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
elif isinstance(other, numbers.Number) or (
isinstance(other, torch.Tensor) and (
other.numel() == 1 and other.ndim <= 1
)
):
return TensorList([op(u, other) for u in self.tensors])
else:
raise TypeError(
f'unsupported operand type(s): "TensorList" and "{type(other)}"'
)
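# Quick illustration: a TensorList broadcasts scalar ops over tensors of
# different lengths and supports element-wise ops with an equal-length list.
_a = TensorList([torch.randn(3, 4), torch.randn(5, 4)])
_b = _a * 2.0 + 1.0          # scalar ops map over every tensor
_c = _a + _a                 # element-wise between equal-length TensorLists
assert _a.ndim == 3 and len(_a) == 2 and _c[1].shape == (5, 4)
del _a, _b, _c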
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from ..attention import flash_attention
from ..tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta
__all__ = [
'XLMRobertaCLIP',
'clip_xlm_roberta_vit_h_14',
'CLIPModel',
]
def pos_interpolate(pos, seq_len):
if pos.size(1) == seq_len:
return pos
else:
src_grid = int(math.sqrt(pos.size(1)))
tar_grid = int(math.sqrt(seq_len))
n = pos.size(1) - src_grid * src_grid
return torch.cat([
pos[:, :n],
F.interpolate(
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
0, 3, 1, 2),
size=(tar_grid, tar_grid),
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
],
dim=1)
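# Quick sanity check (illustrative): resize a CLS + 16x16 grid positional
# embedding so it matches a 32x32 patch grid at a higher input resolution.
_pos = torch.randn(1, 1 + 16 * 16, 1280)
assert pos_interpolate(_pos, 1 + 32 * 32).shape == (1, 1025, 1280)
del _pos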
class QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class LayerNorm(nn.LayerNorm):
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
causal=False,
attn_dropout=0.0,
proj_dropout=0.0):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.causal = causal
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
# layers
self.to_qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
# compute attention
p = self.attn_dropout if self.training else 0.0
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
x = x.reshape(b, s, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
return x
class SwiGLU(nn.Module):
def __init__(self, dim, mid_dim):
super().__init__()
self.dim = dim
self.mid_dim = mid_dim
# layers
self.fc1 = nn.Linear(dim, mid_dim)
self.fc2 = nn.Linear(dim, mid_dim)
self.fc3 = nn.Linear(mid_dim, dim)
def forward(self, x):
x = F.silu(self.fc1(x)) * self.fc2(x)
x = self.fc3(x)
return x
class AttentionBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5):
assert activation in ['quick_gelu', 'gelu', 'swi_glu']
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.post_norm = post_norm
self.causal = causal
self.norm_eps = norm_eps
# layers
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
proj_dropout)
self.norm2 = LayerNorm(dim, eps=norm_eps)
if activation == 'swi_glu':
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
else:
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
if self.post_norm:
x = x + self.norm1(self.attn(x))
x = x + self.norm2(self.mlp(x))
else:
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class AttentionPool(nn.Module):
def __init__(self,
dim,
mlp_ratio,
num_heads,
activation='gelu',
proj_dropout=0.0,
norm_eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.proj_dropout = proj_dropout
self.norm_eps = norm_eps
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
def forward(self, x):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
# compute attention
x = flash_attention(q, k, v, version=2)
x = x.reshape(b, 1, c)
# output
x = self.proj(x)
x = F.dropout(x, self.proj_dropout, self.training)
# mlp
x = x + self.mlp(self.norm(x))
return x[:, 0]
class VisionTransformer(nn.Module):
def __init__(self,
image_size=224,
patch_size=16,
dim=768,
mlp_ratio=4,
out_dim=512,
num_heads=12,
num_layers=12,
pool_type='token',
pre_norm=True,
post_norm=False,
activation='quick_gelu',
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
if image_size % patch_size != 0:
print(
'[WARNING] image_size is not divisible by patch_size',
flush=True)
assert pool_type in ('token', 'token_fc', 'attn_pool')
out_dim = out_dim or dim
super().__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size)**2
self.dim = dim
self.mlp_ratio = mlp_ratio
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.pool_type = pool_type
self.post_norm = post_norm
self.norm_eps = norm_eps
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(
3,
dim,
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm)
if pool_type in ('token', 'token_fc'):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(
1, self.num_patches +
(1 if pool_type in ('token', 'token_fc') else 0), dim))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.transformer = nn.Sequential(*[
AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
activation, attn_dropout, proj_dropout, norm_eps)
for _ in range(num_layers)
])
self.post_norm = LayerNorm(dim, eps=norm_eps)
# head
if pool_type == 'token':
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
elif pool_type == 'token_fc':
self.head = nn.Linear(dim, out_dim)
elif pool_type == 'attn_pool':
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
proj_dropout, norm_eps)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
# embeddings
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ('token', 'token_fc'):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
e = pos_interpolate(self.pos_embedding, x.size(1))
else:
e = self.pos_embedding
x = self.dropout(x + e)
if self.pre_norm is not None:
x = self.pre_norm(x)
# transformer
if use_31_block:
x = self.transformer[:-1](x)
return x
else:
x = self.transformer(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop('out_dim')
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(self,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
vision_pre_norm=True,
vision_post_norm=False,
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.text_dim = text_dim
self.text_heads = text_heads
self.text_layers = text_layers
self.text_post_norm = text_post_norm
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def forward(self, imgs, txt_ids):
"""
imgs: [B, 3, H, W] of torch.float32.
- mean: [0.48145466, 0.4578275, 0.40821073]
- std: [0.26862954, 0.26130258, 0.27577711]
txt_ids: [B, L] of torch.long.
Encoded by data.CLIPTokenizer.
"""
xi = self.visual(imgs)
xt = self.textual(txt_ids)
return xi, xt
def param_groups(self):
groups = [{
'params': [
p for n, p in self.named_parameters()
if 'norm' in n or n.endswith('bias')
],
'weight_decay': 0.0
}, {
'params': [
p for n, p in self.named_parameters()
if not ('norm' in n or n.endswith('bias'))
]
}]
return groups
def _clip(pretrained=False,
pretrained_name=None,
model_cls=XLMRobertaCLIP,
return_transforms=False,
return_tokenizer=False,
tokenizer_padding='eos',
dtype=torch.float32,
device='cpu',
**kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
if 'siglip' in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose([
T.Resize((model.image_size, model.image_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std)
])
output += (transforms,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(
pretrained=False,
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
**kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool='token',
activation='gelu',
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
dtype=dtype,
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu'))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path,
seq_len=self.model.max_text_len - 2,
clean='whitespace')
def visual(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([
F.interpolate(
u.transpose(0, 1),
size=size,
mode='bicubic',
align_corners=False) for u in videos
])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.cuda.amp.autocast(dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out
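# Example usage (illustrative only; the checkpoint and tokenizer paths below are
# placeholders, not shipped defaults, so this is shown commented out):
#
#   clip = CLIPModel(
#       dtype=torch.float16,
#       device='cuda',
#       checkpoint_path='/path/to/clip_checkpoint.pth',
#       tokenizer_path='/path/to/xlm-roberta-large')
#   videos = [torch.rand(3, 1, 480, 832, device='cuda').sub_(0.5).mul_(2)]  # [C, T, H, W] in [-1, 1]
#   clip_fea = clip.visual(videos)  # [T, 257, 1280] penultimate-block visual features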
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from ...distributed.util import gather_forward, get_rank, get_world_size
try:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
except ImportError:
flash_attn_func = None
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
max_seqlen_q=None,
batch_size=1,
):
"""
Perform QKV attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d].
v (torch.Tensor): Value tensor with shape [b, s1, a, d].
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate applied to the attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross-attention), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
max_seqlen_q (int): Maximum query sequence length in the batch; used to reshape the flash-attention output.
batch_size (int): Batch size; used to reshape the flash-attention output. (default: 1)
Returns:
torch.Tensor: Output tensor after attention with shape [b, s, a * d].
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
elif mode == "flash":
x = flash_attn_func(
q,
k,
v,
)
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
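# Minimal smoke test of the 'torch' (SDPA) path. Note that attention() does not
# apply the pre-attention layout itself, so for this path q/k/v must already be
# in [b, heads, s, head_dim]; the post-attention transform then restores
# [b, s, heads * head_dim].
_q = torch.randn(2, 4, 16, 32)
_out = attention(_q, _q, _q, mode="torch")
assert _out.shape == (2, 16, 4 * 32)
del _q, _out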
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class FaceEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
self.out_proj = nn.Linear(1024, hidden_dim)
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
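# Illustrative shape check (small dims for brevity; the animate model below uses
# in_dim=512, hidden_dim=5120, num_heads=4): 16 motion frames are temporally
# downsampled by the two stride-2 causal convs to 4 steps, and a learned padding
# token is appended along the per-head axis.
_enc = FaceEncoder(in_dim=32, hidden_dim=64, num_heads=2)
_motion = torch.randn(2, 16, 32)        # [B, T, in_dim]
assert _enc(_motion).shape == (2, 4, 3, 64)
del _enc, _motion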
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
use_context_parallel=False,
) -> torch.Tensor:
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
if use_context_parallel:
q = gather_forward(q, dim=1)
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
# Compute attention.
attn = attention(
q,
k,
v,
max_seqlen_q=q.shape[1],
batch_size=q.shape[0],
)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
if use_context_parallel:
attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.loaders import PeftAdapterMixin
from ...distributed.sequence_parallel import (
distributed_attention,
gather_forward,
get_rank,
get_world_size,
)
from ..model import (
Head,
WanAttentionBlock,
WanLayerNorm,
WanRMSNorm,
WanModel,
WanSelfAttention,
flash_attention,
rope_params,
sinusoidal_embedding_1d,
rope_apply
)
from .face_blocks import FaceEncoder, FaceAdapter
from .motion_encoder import Generator
class HeadAnimate(Head):
def forward(self, x, e):
"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class WanAnimateSelfAttention(WanSelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Tensor): Shape [B, L, C]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAnimateCrossAttention(WanSelfAttention):
def __init__(
self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
use_img_emb=True
):
super().__init__(
dim,
num_heads,
window_size,
qk_norm,
eps
)
self.use_img_emb = use_img_emb
if use_img_emb:
self.k_img = nn.Linear(dim, dim)
self.v_img = nn.Linear(dim, dim)
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
"""
x: [B, L1, C].
context: [B, L2, C].
context_lens: [B].
"""
if self.use_img_emb:
context_img = context[:, :257]
context = context[:, 257:]
else:
context = context
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
if self.use_img_emb:
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
v_img = self.v_img(context_img).view(b, -1, n, d)
img_x = flash_attention(q, k_img, v_img, k_lens=None)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
if self.use_img_emb:
img_x = img_x.flatten(2)
x = x + img_x
x = self.o(x)
return x
class WanAnimateAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
use_img_emb=True):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)
self.norm3 = WanLayerNorm(
dim, eps, elementwise_affine=True
) if cross_attn_norm else nn.Identity()
self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim),
nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim)
)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
)
with amp.autocast(dtype=torch.float32):
x = x + y * e[2]
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
with amp.autocast(dtype=torch.float32):
x = x + y * e[5]
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim),
torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(),
torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim),
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_no_split_modules = ['WanAttentionBlock']
@register_to_config
def __init__(self,
patch_size=(1, 2, 2),
text_len=512,
in_dim=36,
dim=5120,
ffn_dim=13824,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=40,
num_layers=40,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
motion_encoder_dim=512,
use_context_parallel=False,
use_img_emb=True):
super().__init__()
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.motion_encoder_dim = motion_encoder_dim
self.use_context_parallel = use_context_parallel
self.use_img_emb = use_img_emb
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
])
# head
self.head = HeadAnimate(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], dim=1)
self.img_emb = MLPProj(1280, dim)
# initialize weights
self.init_weights()
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
)
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
for x_, pose_latents_ in zip(x, pose_latents):
x_[:, :, 1:] += pose_latents_
b,c,T,h,w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
return x, motion_vec
def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
if block_idx % 5 == 0:
adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
x = residual_out + x
return x
def forward(
self,
x,
t,
clip_fea,
context,
seq_len,
y=None,
pose_latents=None,
face_pixel_values=None
):
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float()
)
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if self.use_img_emb:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
if self.use_context_parallel:
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
for idx, block in enumerate(self.blocks):
x = block(x, **kwargs)
x = self.after_transformer_block(idx, x, motion_vec)
# head
x = self.head(x, e)
if self.use_context_parallel:
x = gather_forward(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
# Wan-animate Preprocessing User Guide
## 1. Introduction
Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the driving video, each has a distinct preprocessing pipeline.
### 1.1 Animation Mode
In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar.
- A simplified version of the pose retargeting pipeline is provided to help developers quickly implement this functionality.
- **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
- Community contributions to improve this feature are welcome.
### 1.2 Replacement Mode
- Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
- **WARNING**: If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
- A simplified version for extracting the character's mask is also provided.
- **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail on multi-person videos (incorrect pose tracking). For multi-person videos, users must either develop their own solution or integrate a suitable open-source tool.
---
## 2. Preprocessing Instructions and Recommendations
### 2.1 Basic Usage
- The preprocessing pipeline requires several additional models: pose detection (mandatory), plus mask extraction and image editing models (optional, as needed). Place them according to the following directory structure:
```
/path/to/your/ckpt_path/
├── det/
│ └── yolov10m.onnx
├── pose2d/
│ └── vitpose_h_wholebody.onnx
├── sam2/
│ └── sam2_hiera_large.pt
└── FLUX.1-Kontext-dev/
```
- `video_path`, `refer_path`, and `save_path` correspond to the paths for the input driving video, the character image, and the preprocessed results.
- When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated.
- The `resolution_area` parameter determines the resolution used for both preprocessing and the generation model; the target size is chosen by total pixel area.
- The `fps` parameter specifies the frame rate for video processing. A lower frame rate can improve generation efficiency, but may cause stuttering or choppiness. An illustrative invocation is sketched below.
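For reference, a typical `animation`-mode preprocessing command might look like the following. The script path and flag spellings are illustrative assumptions that mirror the parameters described above; check the repository's preprocessing entry point for the exact interface.
```
python ./wan/modules/animate/preprocess/preprocess_data.py \
    --ckpt_path /path/to/your/ckpt_path \
    --video_path ./examples/animate/video.mp4 \
    --refer_path ./examples/animate/image.jpeg \
    --save_path ./examples/animate/process_results \
    --resolution_area 1280 720 \
    --fps 30 \
    --retarget_flag \
    --use_flux
```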
---
### 2.2 Animation Mode
- We support three options: no pose retargeting, basic pose retargeting, and enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are selected via the `retarget_flag` and `use_flux` parameters.
- Using basic pose retargeting (specify `retarget_flag` only) requires that both the reference character and the character in the first frame of the driving video are in a front-facing, outstretched pose.
- Otherwise, we recommend enabling enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency may not be maintained, or the pose may be incorrect). It is recommended to check the intermediate results as well as the final generated pose video; both are stored in `save_path`. Users may also substitute a better image editing model, or tune the Flux prompts on their own.
---
### 2.3 Replacement Mode
- Specify `replace_flag` to enable data preprocessing for this mode. The preprocessing additionally produces a mask for the character in the video; its size and shape can be adjusted via the following parameters.
- `iterations` and `k` can make the mask larger, covering more area.
- `w_len` and `h_len` can adjust the mask's shape. Smaller values will make the outline coarser, while larger values will make it finer.
- A smaller, finer-contoured mask preserves more of the original background, but may constrain the character's generation area (given potential appearance differences between the reference and driving characters, this can lead to some shape leakage). A larger, coarser mask gives character generation more flexibility and consistency, but, because it covers more of the background, it may hurt background consistency. We recommend adjusting these parameters based on your specific input data; a conceptual sketch of how they affect the mask follows below.
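As a rough illustration (a conceptual sketch only, not the repository's actual implementation), the mask parameters behave like a morphological dilation followed by a coarse re-gridding of the outline:
```
import cv2
import numpy as np

mask = np.zeros((480, 832), np.uint8)
cv2.circle(mask, (416, 240), 80, 255, -1)                                  # stand-in character mask
kernel = np.ones((7, 7), np.uint8)                                          # `k`: dilation kernel size
grown = cv2.dilate(mask, kernel, iterations=3)                              # `iterations`: grow the mask outward
coarse = cv2.resize(grown, (32, 24), interpolation=cv2.INTER_NEAREST)       # `w_len` x `h_len` coarse grid
blocky = cv2.resize(coarse, (832, 480), interpolation=cv2.INTER_NEAREST)    # coarser, larger outline
```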
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .process_pipepline import ProcessPipeline
from .video_predictor import SAM2VideoPredictor