# Copyright (c) 2025. Your modifications here.
# A wrapper for sam2 functions
from collections import OrderedDict
import torch
from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores
from sam_utils import load_video_frames_v2, load_video_frames
class SAM2VideoPredictor(_SAM2VideoPredictor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.inference_mode()
def init_state(
self,
video_path,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
frame_names=None
):
"""Initialize a inference state."""
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
frame_names=frame_names
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when the user interacts with a frame
# to add clicks or a mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already hold consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@torch.inference_mode()
def init_state_v2(
self,
frames,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
frame_names=None
):
"""Initialize a inference state."""
images, video_height, video_width = load_video_frames_v2(
frames=frames,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
frame_names=frame_names
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
# whether to offload the inference state to CPU memory
# turning on this option saves the GPU memory at the cost of a lower tracking fps
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
# and from 24 to 21 when tracking two objects)
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
# the original video height and width, used for resizing final output scores
inference_state["video_height"] = video_height
inference_state["video_width"] = video_width
inference_state["device"] = torch.device("cuda")
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
inference_state["storage_device"] = torch.device("cuda")
# inputs on each frame
inference_state["point_inputs_per_obj"] = {}
inference_state["mask_inputs_per_obj"] = {}
# visual features on a small number of recently visited frames for quick interactions
inference_state["cached_features"] = {}
# values that don't change across frames (so we only need to hold one copy of them)
inference_state["constants"] = {}
# mapping between client-side object id and model-side object index
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when the user interacts with a frame
# to add clicks or a mask (it's merged into "output_dict" before propagation starts)
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already hold consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
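# Usage sketch (illustrative only, not part of this module): how the two init
# functions above are typically reached. The config/checkpoint paths are
# placeholders, and it is an assumption that the surrounding project builds
# this wrapper subclass via sam2's predictor factory.
#
#   from sam2.build_sam import build_sam2_video_predictor
#   predictor = build_sam2_video_predictor(config_file, ckpt_path, device="cuda")
#   state = predictor.init_state(video_path="frames_dir/", offload_video_to_cpu=True)
#   # or, when frames are already decoded in memory:
#   # state = predictor.init_state_v2(frames=frames)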
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['XLMRoberta', 'xlm_roberta_large']
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self,
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
for _ in range(num_layers)
])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + \
self.type_embedding(torch.zeros_like(ids)) + \
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(
mask.view(b, 1, 1, s).gt(0), 0.0,
torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
def xlm_roberta_large(pretrained=False,
return_tokenizer=False,
device='cpu',
**kwargs):
"""
XLMRobertaLarge adapted from Huggingface.
"""
# params
cfg = dict(
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5)
cfg.update(**kwargs)
# init a model on device
with torch.device(device):
model = XLMRoberta(**cfg)
return model
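# Minimal usage sketch (assumptions: random token ids, CPU; this factory only
# builds the architecture and does not load pretrained weights):
#
#   model = xlm_roberta_large()
#   ids = torch.randint(0, 250002, (2, 16))  # [B, L] token ids, pad_id == 1
#   feats = model(ids)                        # [2, 16, 1024] per-token features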
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
import warnings
__all__ = [
'flash_attention',
'attention',
]
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left, right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. The dtype that q/k/v are cast to when they are not already float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor(
[lq] * b, dtype=torch.int32).to(
device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor(
[lk] * b, dtype=torch.int32).to(
device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'Flash attention 3 is not available; falling back to flash attention 2.'
)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
# output
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
attn_mask = None
q = q.transpose(1, 2).to(dtype)
k = k.transpose(1, 2).to(dtype)
v = v.transpose(1, 2).to(dtype)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out
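# Dispatch note and usage sketch: `attention` routes to the flash-attention
# varlen kernels when flash_attn 2/3 is importable and otherwise falls back to
# torch scaled_dot_product_attention (dropping the q_lens/k_lens padding
# masks). Tensors follow the [B, L, N, C] layout used above; the sizes below
# are illustrative only, and the flash path requires a CUDA device.
#
#   q = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
#   k = torch.randn(1, 512, 16, 128, device="cuda", dtype=torch.bfloat16)
#   v = torch.randn(1, 512, 16, 128, device="cuda", dtype=torch.bfloat16)
#   k_lens = torch.tensor([400], dtype=torch.int32, device="cuda")
#   out = attention(q, k, v, k_lens=k_lens)   # [1, 1024, 16, 128]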
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from .attention import flash_attention
__all__ = ['WanModel']
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
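# Note: for position p and channel i in [0, dim/2), the angular frequency is
# 10000 ** (-i / (dim / 2)); the output concatenates the cos terms followed by
# the sin terms, giving a [len(position), dim] float64 tensor, e.g.
#   sinusoidal_embedding_1d(256, torch.arange(10))  # -> shape [10, 256]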
@torch.amp.autocast('cuda', enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
seq_len, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, seq_len:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
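# RoPE layout note: the c complex pairs per head are split into a temporal
# group of size c - 2 * (c // 3) and two spatial groups of size c // 3 each
# (e.g. head_dim 128 -> c = 64 -> 22/21/21); the three frequency tables are
# broadcast over the (F, H, W) grid and applied as complex rotations, while
# padded tokens beyond F*H*W pass through unrotated.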
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
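# Reference: RMSNorm computes y = x / sqrt(mean(x**2, dim=-1) + eps) * weight,
# with the normalization done in float32 and the result cast back to x's dtype.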
class WanLayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
r"""
Args:
x(Tensor): Shape [B, L, C]
"""
return super().forward(x.float()).type_as(x)
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, seq_lens, grid_sizes, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanCrossAttention(WanSelfAttention):
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.q(x)).view(b, -1, n, d)
k = self.norm_k(self.k(context)).view(b, -1, n, d)
v = self.v(context).view(b, -1, n, d)
# compute attention
x = flash_attention(q, k, v, k_lens=context_lens)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps)
self.norm3 = WanLayerNorm(
dim, eps,
elementwise_affine=True) if cross_attn_norm else nn.Identity()
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
eps)
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
context_lens,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, L1, 6, C]
seq_lens(Tensor): Shape [B], length of each sequence in batch
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
assert e.dtype == torch.float32
with torch.amp.autocast('cuda', dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
seq_lens, grid_sizes, freqs)
with torch.amp.autocast('cuda', dtype=torch.float32):
x = x + y * e[2].squeeze(2)
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(
self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
with torch.amp.autocast('cuda', dtype=torch.float32):
x = x + y * e[5].squeeze(2)
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, L1, C]
"""
assert e.dtype == torch.float32
with torch.amp.autocast('cuda', dtype=torch.float32):
e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
x = (
self.head(
self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
return x
class WanModel(ModelMixin, ConfigMixin):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
ignore_for_config = [
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
]
_no_split_modules = ['WanAttentionBlock']
@register_to_config
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video), 'i2v' (image-to-video), 'ti2v' (text-image-to-video) or 's2v' (speech-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to True):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
assert model_type in ['t2v', 'i2v', 'ti2v', 's2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps) for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
# initialize weights
self.init_weights()
def forward(
self,
x,
t,
context,
seq_len,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == 'i2v':
assert y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
if t.dim() == 1:
t = t.expand(t.size(0), seq_len)
with torch.amp.autocast('cuda', dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t).unflatten(0, (bt, seq_len)).float())
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
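# Shape example (assuming patch_size == (1, 2, 2) and out_dim == 16): a grid
# of (F, H, W) = (21, 30, 52) patches gives L = 21 * 30 * 52 = 32760 tokens of
# size prod(patch_size) * 16 = 64 each; the einsum re-interleaves the
# per-patch pixels so the reconstructed latent has shape [16, 21, 60, 104].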
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
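# Forward-pass sketch (illustrative only; the small config and latent shapes
# below are assumptions, and the attention path needs CUDA with flash-attn
# installed):
#
#   model = WanModel(dim=1536, ffn_dim=8960, num_heads=12, num_layers=2).cuda()
#   x = [torch.randn(16, 21, 60, 104, device="cuda")]  # [C_in, F, H', W'] latents
#   t = torch.tensor([500], device="cuda")
#   context = [torch.randn(77, 4096, device="cuda")]   # [L, text_dim] embeddings
#   out = model(x, t, context, seq_len=32760)          # list of [C_out, F, H', W']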
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .audio_encoder import AudioEncoder
from .model_s2v import WanModel_S2V
__all__ = ['WanModel_S2V', 'AudioEncoder']
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
def get_sample_indices(original_fps,
total_frames,
target_fps,
num_sample,
fixed_start=None):
required_duration = num_sample / target_fps
required_origin_frames = int(np.ceil(required_duration * original_fps))
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if fixed_start is not None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
if max_start < 0:
raise ValueError("video length is too short")
start_frame = np.random.randint(0, max_start + 1)
start_time = start_frame / original_fps
end_time = start_time + required_duration
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
return frame_indices
def linear_interpolation(features, input_fps, output_fps, output_len=None):
"""
features: shape=[1, T, 512]
input_fps: fps for audio, f_a
output_fps: fps for video, f_m
output_len: video length
"""
features = features.transpose(1, 2) # [1, 512, T]
seq_len = features.shape[2] / float(input_fps) # T/f_a
if output_len is None:
output_len = int(seq_len * output_fps) # f_m*T/f_a
output_features = F.interpolate(
features, size=output_len, align_corners=True,
mode='linear') # [1, 512, output_len]
return output_features.transpose(1, 2) # [1, output_len, 512]
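# Resampling arithmetic example: wav2vec2 features arrive at 50 frames per
# second, so 10 s of audio gives T = 500 feature frames; interpolating to the
# 30 fps video rate yields output_len = int(500 / 50 * 30) = 300 frames.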
class AudioEncoder():
def __init__(self, device='cpu', model_id="facebook/wav2vec2-base-960h"):
# load pretrained model
self.processor = Wav2Vec2Processor.from_pretrained(model_id)
self.model = Wav2Vec2ForCTC.from_pretrained(model_id)
self.model = self.model.to(device)
self.video_rate = 30
def extract_audio_feat(self,
audio_path,
return_all_layers=False,
dtype=torch.float32):
audio_input, sample_rate = librosa.load(audio_path, sr=16000)
input_values = self.processor(
audio_input, sampling_rate=sample_rate,
return_tensors="pt").input_values
# INFERENCE
# retrieve logits & take argmax
res = self.model(
input_values.to(self.model.device), output_hidden_states=True)
if return_all_layers:
feat = torch.cat(res.hidden_states)
else:
feat = res.hidden_states[-1]
feat = linear_interpolation(
feat, input_fps=50, output_fps=self.video_rate)
z = feat.to(dtype) # Encoding for the motion
return z
def get_audio_embed_bucket(self,
audio_embed,
stride=2,
batch_frames=12,
m=2):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
bucket_num = min_batch_num * batch_frames
batch_idx = [stride * i for i in range(bucket_num)]
batch_audio_eb = []
for bi in batch_idx:
if bi < audio_frame_num:
audio_sample_stride = 2
chosen_idx = list(
range(bi - m * audio_sample_stride,
bi + (m + 1) * audio_sample_stride,
audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [
audio_frame_num - 1 if c >= audio_frame_num else c
for c in chosen_idx
]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = \
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
dim=0)
return batch_audio_eb, min_batch_num
def get_audio_embed_bucket_fps(self,
audio_embed,
fps=16,
batch_frames=81,
m=0):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
scale = self.video_rate / fps
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
bucket_num = min_batch_num * batch_frames
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps *
self.video_rate) - audio_frame_num
batch_idx = get_sample_indices(
original_fps=self.video_rate,
total_frames=audio_frame_num + padd_audio_num,
target_fps=fps,
num_sample=bucket_num,
fixed_start=0)
batch_audio_eb = []
audio_sample_stride = int(self.video_rate / fps)
for bi in batch_idx:
if bi < audio_frame_num:
chosen_idx = list(
range(bi - m * audio_sample_stride,
bi + (m + 1) * audio_sample_stride,
audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [
audio_frame_num - 1 if c >= audio_frame_num else c
for c in chosen_idx
]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(
start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = \
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
dim=0)
return batch_audio_eb, min_batch_num
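# Usage sketch (the wav path is a placeholder; wav2vec2-base exposes 13 hidden
# states of width 768, so the stacked feature is [13, T, 768]):
#
#   enc = AudioEncoder(device="cpu")
#   feat = enc.extract_audio_feat("speech.wav", return_all_layers=True)
#   bucket, num_batches = enc.get_audio_embed_bucket_fps(feat, fps=16, batch_frames=81)
#   # bucket: [num_batches * 81, 13, 768 * (2 * m + 1)] audio tokens aligned to video frames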
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import Tuple, Union
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.models.attention import AdaLayerNorm
from ..model import WanAttentionBlock, WanCrossAttention
from .auxi_blocks import MotionEncoder_tc
class CausalAudioEncoder(nn.Module):
def __init__(self,
dim=5120,
num_layers=25,
out_dim=2048,
video_rate=8,
num_token=4,
need_global=False):
super().__init__()
self.encoder = MotionEncoder_tc(
in_dim=dim,
hidden_dim=out_dim,
num_heads=num_token,
need_global=need_global)
weight = torch.ones((1, num_layers, 1, 1)) * 0.01
self.weights = torch.nn.Parameter(weight)
self.act = torch.nn.SiLU()
def forward(self, features):
with amp.autocast(dtype=torch.float32):
# features B * num_layers * dim * video_length
weights = self.act(self.weights)
weights_sum = weights.sum(dim=1, keepdims=True)
weighted_feat = ((features * weights) / weights_sum).sum(
dim=1) # b dim f
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
res = self.encoder(weighted_feat) # b f n dim
return res # b f n dim
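# Note: the per-layer weights go through SiLU and are normalized by their sum,
# mixing the num_layers wav2vec features into a single [B, dim, F] tensor that
# MotionEncoder_tc then compresses into a few tokens per frame.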
class AudioCrossAttention(WanCrossAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class AudioInjector_WAN(nn.Module):
def __init__(self,
all_modules,
all_modules_names,
dim=2048,
num_heads=32,
inject_layer=[0, 27],
root_net=None,
enable_adain=False,
adain_dim=2048,
need_adain_ont=False):
super().__init__()
num_injector_layers = len(inject_layer)
self.injected_block_id = {}
audio_injector_id = 0
for mod_name, mod in zip(all_modules_names, all_modules):
if isinstance(mod, WanAttentionBlock):
for inject_id in inject_layer:
if f'transformer_blocks.{inject_id}' in mod_name:
self.injected_block_id[inject_id] = audio_injector_id
audio_injector_id += 1
self.injector = nn.ModuleList([
AudioCrossAttention(
dim=dim,
num_heads=num_heads,
qk_norm=True,
) for _ in range(audio_injector_id)
])
self.injector_pre_norm_feat = nn.ModuleList([
nn.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6,
) for _ in range(audio_injector_id)
])
self.injector_pre_norm_vec = nn.ModuleList([
nn.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6,
) for _ in range(audio_injector_id)
])
if enable_adain:
self.injector_adain_layers = nn.ModuleList([
AdaLayerNorm(
output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)
for _ in range(audio_injector_id)
])
if need_adain_ont:
self.injector_adain_output_layers = nn.ModuleList(
[nn.Linear(dim, dim) for _ in range(audio_injector_id)])
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import importlib.metadata
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import is_torch_version, logging
from einops import rearrange
try:
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
except ImportError:
flash_attn_func = None
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
max_seqlen_q=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
max_seqlen_q (int): The maximum sequence length in the batch of q; used together with
batch_size to reshape the flash-attention output back to [b, s, a, d]. (default: None)
batch_size (int): Batch size of the packed flash-attention inputs. (default: 1)
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
elif mode == "flash":
x = flash_attn_func(
q,
k,
v,
)
# x with shape [(bxs), a, d]
x = x.view(batch_size, max_seqlen_q, x.shape[-2],
x.shape[-1]) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert (
attn_mask
is None), "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(
b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class CausalConv1d(nn.Module):
def __init__(self,
chan_in,
chan_out,
kernel_size=3,
stride=1,
dilation=1,
pad_mode='replicate',
**kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
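# Causality note: padding (kernel_size - 1) steps on the left only means each
# output step sees current and past inputs but never future ones; with
# kernel_size=3 and stride=1 the temporal length is preserved
# (L_out = (L + 2 - 3) / 1 + 1 = L).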
class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads: int,
need_global=True,
dtype=None,
device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.need_global = need_global
self.conv1_local = CausalConv1d(
in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
if need_global:
self.conv1_global = CausalConv1d(
in_dim, hidden_dim // 4, 3, stride=1)
self.norm1 = nn.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
if need_global:
self.final_linear = nn.Linear(hidden_dim, hidden_dim,
**factory_kwargs)
self.norm1 = nn.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm2 = nn.LayerNorm(
hidden_dim // 2,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm3 = nn.LayerNorm(
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, 'b t c -> b c t')
x_ori = x.clone()
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
if not self.need_global:
return x_local
x = self.conv1_global(x_ori)
x = rearrange(x, 'b c t -> b t c')
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = self.final_linear(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
return x, x_local
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
from ...distributed.sequence_parallel import (
distributed_attention,
gather_forward,
get_rank,
get_world_size,
)
from ..model import (
Head,
WanAttentionBlock,
WanLayerNorm,
WanModel,
WanSelfAttention,
flash_attention,
rope_params,
sinusoidal_embedding_1d,
)
from .audio_utils import AudioInjector_WAN, CausalAudioEncoder
from .motioner import FramePackMotioner, MotionerTransformers
from .s2v_utils import rope_precompute
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def torch_dfs(model: nn.Module, parent_name='root'):
module_names, modules = [], []
current_name = parent_name if parent_name else 'root'
module_names.append(current_name)
modules.append(model)
for name, child in model.named_children():
if parent_name:
child_name = f'{parent_name}.{name}'
else:
child_name = name
child_modules, child_names = torch_dfs(child, child_name)
module_names += child_names
modules += child_modules
return modules, module_names
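# Example: torch_dfs(blocks, parent_name='root.transformer_blocks') returns
# (modules, names) in depth-first order with names such as
# 'root.transformer_blocks.0', 'root.transformer_blocks.0.self_attn', ...;
# AudioInjector_WAN matches these names against f'transformer_blocks.{i}' to
# decide which WanAttentionBlocks receive audio cross-attention.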
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs, start=None):
n, c = x.size(2), x.size(3) // 2
# loop over samples
output = []
for i, _ in enumerate(x):
s = x.size(1)
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = freqs[i, :s]
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
@amp.autocast(enabled=False)
def rope_apply_usp(x, grid_sizes, freqs):
s, n, c = x.size(1), x.size(2), x.size(3) // 2
# loop over samples
output = []
for i, _ in enumerate(x):
s = x.size(1)
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = freqs[i]
freqs_i_rank = freqs_i
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
def sp_attn_forward_s2v(self,
x,
seq_lens,
grid_sizes,
freqs,
dtype=torch.bfloat16):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply_usp(q, grid_sizes, freqs)
k = rope_apply_usp(k, grid_sizes, freqs)
x = distributed_attention(
half(q),
half(k),
half(v),
seq_lens,
window_size=self.window_size,
)
# output
x = x.flatten(2)
x = self.o(x)
return x
class Head_S2V(Head):
def forward(self, x, e):
"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, L1, C]
"""
assert e.dtype == torch.float32
with amp.autocast(dtype=torch.float32):
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class WanS2VSelfAttention(WanSelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class WanS2VAttentionBlock(WanAttentionBlock):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6):
super().__init__(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps)
self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size,
qk_norm, eps)
def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens):
assert e[0].dtype == torch.float32
seg_idx = e[1].item()
seg_idx = min(max(0, seg_idx), x.size(1))
seg_idx = [0, seg_idx, x.size(1)]
e = e[0]
modulation = self.modulation.unsqueeze(2)
with amp.autocast(dtype=torch.float32):
e = (modulation + e).chunk(6, dim=1)
assert e[0].dtype == torch.float32
e = [element.squeeze(1) for element in e]
norm_x = self.norm1(x).float()
parts = []
for i in range(2):
parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
(1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
norm_x = torch.cat(parts, dim=1)
# self-attention
y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
with amp.autocast(dtype=torch.float32):
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
norm2_x = self.norm2(x).float()
parts = []
for i in range(2):
parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
(1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
norm2_x = torch.cat(parts, dim=1)
y = self.ffn(norm2_x)
with amp.autocast(dtype=torch.float32):
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
return x
x = cross_attn_ffn(x, context, context_lens, e)
return x
class WanModel_S2V(ModelMixin, ConfigMixin):
ignore_for_config = [
'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm',
'text_dim', 'window_size'
]
_no_split_modules = ['WanS2VAttentionBlock']
@register_to_config
def __init__(
self,
cond_dim=0,
audio_dim=5120,
num_audio_token=4,
enable_adain=False,
adain_mode="attn_norm",
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
zero_init=False,
zero_timestep=False,
enable_motioner=True,
add_last_motion=True,
enable_tsm=False,
trainable_token_pos_emb=False,
motion_token_num=1024,
enable_framepack=False, # Mutually exclusive with enable_motioner
framepack_drop_mode="drop",
model_type='s2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
*args,
**kwargs):
super().__init__()
assert model_type == 's2v'
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
# blocks
self.blocks = nn.ModuleList([
WanS2VAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
cross_attn_norm, eps)
for _ in range(num_layers)
])
# head
self.head = Head_S2V(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
# initialize weights
self.init_weights()
self.use_context_parallel = False # will modify in _configure_model func
if cond_dim > 0:
self.cond_encoder = nn.Conv3d(
cond_dim,
self.dim,
kernel_size=self.patch_size,
stride=self.patch_size)
self.enbale_adain = enable_adain
self.casual_audio_encoder = CausalAudioEncoder(
dim=audio_dim,
out_dim=self.dim,
num_token=num_audio_token,
need_global=enable_adain)
all_modules, all_modules_names = torch_dfs(
self.blocks, parent_name="root.transformer_blocks")
self.audio_injector = AudioInjector_WAN(
all_modules,
all_modules_names,
dim=self.dim,
num_heads=self.num_heads,
inject_layer=audio_inject_layers,
root_net=self,
enable_adain=enable_adain,
adain_dim=self.dim,
need_adain_ont=adain_mode != "attn_norm",
)
self.adain_mode = adain_mode
self.trainable_cond_mask = nn.Embedding(3, self.dim)
if zero_init:
self.zero_init_weights()
self.zero_timestep = zero_timestep # Whether to assign 0 value timestep to ref/motion
# init motioner
if enable_motioner and enable_framepack:
raise ValueError(
"enable_motioner and enable_framepack are mutually exclusive, please set one of them to False"
)
self.enable_motioner = enable_motioner
self.add_last_motion = add_last_motion
if enable_motioner:
motioner_dim = 2048
self.motioner = MotionerTransformers(
patch_size=(2, 4, 4),
dim=motioner_dim,
ffn_dim=motioner_dim,
freq_dim=256,
out_dim=16,
num_heads=16,
num_layers=13,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
motion_token_num=motion_token_num,
enable_tsm=enable_tsm,
motion_stride=4,
expand_ratio=2,
trainable_token_pos_emb=trainable_token_pos_emb,
)
self.zip_motion_out = torch.nn.Sequential(
WanLayerNorm(motioner_dim),
zero_module(nn.Linear(motioner_dim, self.dim)))
self.trainable_token_pos_emb = trainable_token_pos_emb
if trainable_token_pos_emb:
d = self.dim // self.num_heads
x = torch.zeros([1, motion_token_num, self.num_heads, d])
x[..., ::2] = 1
gride_sizes = [[
torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([
1, self.motioner.motion_side_len,
self.motioner.motion_side_len
]).unsqueeze(0).repeat(1, 1),
torch.tensor([
1, self.motioner.motion_side_len,
self.motioner.motion_side_len
]).unsqueeze(0).repeat(1, 1),
]]
token_freqs = rope_apply(x, gride_sizes, self.freqs)
token_freqs = token_freqs[0, :,
0].reshape(motion_token_num, -1, 2)
token_freqs = token_freqs * 0.01
self.token_freqs = torch.nn.Parameter(token_freqs)
self.enable_framepack = enable_framepack
if enable_framepack:
self.frame_packer = FramePackMotioner(
inner_dim=self.dim,
num_heads=self.num_heads,
zip_frame_buckets=[1, 2, 16],
drop_mode=framepack_drop_mode)
def zero_init_weights(self):
with torch.no_grad():
self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
if hasattr(self, "cond_encoder"):
self.cond_encoder = zero_module(self.cond_encoder)
for i in range(self.audio_injector.injector.__len__()):
self.audio_injector.injector[i].o = zero_module(
self.audio_injector.injector[i].o)
if self.enbale_adain:
self.audio_injector.injector_adain_layers[
i].linear = zero_module(
self.audio_injector.injector_adain_layers[i].linear)
def process_motion(self, motion_latents, drop_motion_frames=False):
if drop_motion_frames or motion_latents[0].shape[1] == 0:
return [], []
self.lat_motion_frames = motion_latents[0].shape[1]
mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
batch_size = len(mot)
mot_remb = []
flattern_mot = []
for bs in range(batch_size):
height, width = mot[bs].shape[3], mot[bs].shape[4]
flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
motion_grid_sizes = [[
torch.tensor([-self.lat_motion_frames, 0,
0]).unsqueeze(0).repeat(1, 1),
torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.lat_motion_frames, height,
width]).unsqueeze(0).repeat(1, 1)
]]
motion_rope_emb = rope_precompute(
flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
self.dim // self.num_heads),
motion_grid_sizes,
self.freqs,
start=None)
mot_remb.append(motion_rope_emb)
flattern_mot.append(flat_mot)
return flattern_mot, mot_remb
def process_motion_frame_pack(self,
motion_latents,
drop_motion_frames=False,
add_last_motion=2):
flattern_mot, mot_remb = self.frame_packer(motion_latents,
add_last_motion)
if drop_motion_frames:
return [m[:, :0] for m in flattern_mot
], [m[:, :0] for m in mot_remb]
else:
return flattern_mot, mot_remb
def process_motion_transformer_motioner(self,
motion_latents,
drop_motion_frames=False,
add_last_motion=True):
batch_size, height, width = len(
motion_latents), motion_latents[0].shape[2] // self.patch_size[
1], motion_latents[0].shape[3] // self.patch_size[2]
freqs = self.freqs
device = self.patch_embedding.weight.device
if freqs.device != device:
freqs = freqs.to(device)
if self.trainable_token_pos_emb:
with amp.autocast(dtype=torch.float64):
token_freqs = self.token_freqs.to(torch.float64)
token_freqs = token_freqs / token_freqs.norm(
dim=-1, keepdim=True)
freqs = [freqs, torch.view_as_complex(token_freqs)]
if not drop_motion_frames and add_last_motion:
last_motion_latent = [u[:, -1:] for u in motion_latents]
last_mot = [
self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent
]
last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
last_mot = torch.cat(last_mot)
gride_sizes = [[
torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor([0, height,
width]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor([1, height,
width]).unsqueeze(0).repeat(batch_size, 1)
]]
else:
last_mot = torch.zeros([batch_size, 0, self.dim],
device=motion_latents[0].device,
dtype=motion_latents[0].dtype)
gride_sizes = []
zip_motion = self.motioner(motion_latents)
zip_motion = self.zip_motion_out(zip_motion)
if drop_motion_frames:
zip_motion = zip_motion * 0.0
zip_motion_grid_sizes = [[
torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor([
0, self.motioner.motion_side_len, self.motioner.motion_side_len
]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor(
[1 if not self.trainable_token_pos_emb else -1, height,
width]).unsqueeze(0).repeat(batch_size, 1),
]]
mot = torch.cat([last_mot, zip_motion], dim=1)
gride_sizes = gride_sizes + zip_motion_grid_sizes
motion_rope_emb = rope_precompute(
mot.detach().view(batch_size, mot.shape[1], self.num_heads,
self.dim // self.num_heads),
gride_sizes,
freqs,
start=None)
return [m.unsqueeze(0) for m in mot
], [r.unsqueeze(0) for r in motion_rope_emb]
def inject_motion(self,
x,
seq_lens,
rope_embs,
mask_input,
motion_latents,
drop_motion_frames=False,
add_last_motion=True):
# inject the motion-frame tokens into the hidden states
if self.enable_motioner:
mot, mot_remb = self.process_motion_transformer_motioner(
motion_latents,
drop_motion_frames=drop_motion_frames,
add_last_motion=add_last_motion)
elif self.enable_framepack:
mot, mot_remb = self.process_motion_frame_pack(
motion_latents,
drop_motion_frames=drop_motion_frames,
add_last_motion=add_last_motion)
else:
mot, mot_remb = self.process_motion(
motion_latents, drop_motion_frames=drop_motion_frames)
if len(mot) > 0:
x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
dtype=torch.long)
rope_embs = [
torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
]
mask_input = [
torch.cat([
m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
device=m.device,
dtype=m.dtype)
],
dim=1) for m, u in zip(mask_input, x)
]
return x, seq_lens, rope_embs, mask_input
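# Note on mask_input (a reading aid, inferred from this file): after `inject_motion`
# every sequence carries a per-token type id that `forward` feeds to
# self.trainable_cond_mask (defined elsewhere) as an additive embedding:
# 0 = noisy video tokens, 1 = reference tokens, 2 = the motion tokens appended above.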
def after_transformer_block(self, block_idx, hidden_states):
if block_idx in self.audio_injector.injected_block_id.keys():
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
audio_emb = self.merged_audio_emb # b f n c
num_frames = audio_emb.shape[1]
if self.use_context_parallel:
hidden_states = gather_forward(hidden_states, dim=1)
input_hidden_states = hidden_states[:, :self.
original_seq_len].clone(
) # b (f h w) c
input_hidden_states = rearrange(
input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
if self.enbale_adain and self.adain_mode == "attn_norm":
audio_emb_global = self.audio_emb_global
audio_emb_global = rearrange(audio_emb_global,
"b t n c -> (b t) n c")
adain_hidden_states = self.audio_injector.injector_adain_layers[
audio_attn_id](
input_hidden_states, temb=audio_emb_global[:, 0])
attn_hidden_states = adain_hidden_states
else:
attn_hidden_states = self.audio_injector.injector_pre_norm_feat[
audio_attn_id](
input_hidden_states)
audio_emb = rearrange(
audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.audio_injector.injector[audio_attn_id](
x=attn_hidden_states,
context=attn_audio_emb,
context_lens=torch.ones(
attn_hidden_states.shape[0],
dtype=torch.long,
device=attn_hidden_states.device) * attn_audio_emb.shape[1])
residual_out = rearrange(
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states[:, :self.
original_seq_len] = hidden_states[:, :self.
original_seq_len] + residual_out
if self.use_context_parallel:
hidden_states = torch.chunk(
hidden_states, get_world_size(), dim=1)[get_rank()]
return hidden_states
def forward(
self,
x,
t,
context,
seq_len,
ref_latents,
motion_latents,
cond_states,
audio_input=None,
motion_frames=[17, 5],
add_last_motion=2,
drop_motion_frames=False,
*extra_args,
**extra_kwargs):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
seq_len: A list of video token lens, no need for this model.
ref_latents A list of reference image for each video with shape [C, 1, H, W].
motion_latents A list of motion frames for each video with shape [C, T_m, H, W].
cond_states A list of condition frames (i.e. pose) each with shape [C, T, H, W].
audio_input The input audio embedding [B, num_wav2vec_layer, C_a, T_a].
motion_frames The number of motion frames and motion latents frames encoded by vae, i.e. [17, 5]
add_last_motion For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added.
For frame packing, the behavior depends on the value of add_last_motion:
add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included.
add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included.
add_last_motion = 2: All motion-related latents are used.
drop_motion_frames Bool, whether drop the motion frames info
"""
add_last_motion = self.add_last_motion * add_last_motion
audio_input = torch.cat([
audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input
],
dim=-1)
audio_emb_res = self.casual_audio_encoder(audio_input)
if self.enbale_adain:
audio_emb_global, audio_emb = audio_emb_res
self.audio_emb_global = audio_emb_global[:,
motion_frames[1]:].clone()
else:
audio_emb = audio_emb_res
self.merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
device = self.patch_embedding.weight.device
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
# cond states
cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
x = [x_ + pose for x_, pose in zip(x, cond)]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
original_grid_sizes = deepcopy(grid_sizes)
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
# ref and motion
self.lat_motion_frames = motion_latents[0].shape[1]
ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
batch_size = len(ref)
height, width = ref[0].shape[3], ref[0].shape[4]
ref_grid_sizes = [[
torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size,
1), # the start index
torch.tensor([31, height,
width]).unsqueeze(0).repeat(batch_size,
1), # the end index
torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),
] # the range
]
ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w
self.original_seq_len = seq_lens[0]
seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref],
dtype=torch.long)
grid_sizes = grid_sizes + ref_grid_sizes
x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]
# Initialize masks that indicate the noisy latent, the reference latent, and the motion latent.
# At this point only the first two (noisy and reference latents) are marked;
# the motion latent is marked later, inside `inject_motion`.
mask_input = [
torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
for u in x
]
for i in range(len(mask_input)):
mask_input[i][:, self.original_seq_len:] = 1
# compute the rope embeddings for the input
x = torch.cat(x)
b, s, n, d = x.size(0), x.size(
1), self.num_heads, self.dim // self.num_heads
self.pre_compute_freqs = rope_precompute(
x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
x = [u.unsqueeze(0) for u in x]
self.pre_compute_freqs = [
u.unsqueeze(0) for u in self.pre_compute_freqs
]
x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
x,
seq_lens,
self.pre_compute_freqs,
mask_input,
motion_latents,
drop_motion_frames=drop_motion_frames,
add_last_motion=add_last_motion)
x = torch.cat(x, dim=0)
self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
mask_input = torch.cat(mask_input, dim=0)
x = x + self.trainable_cond_mask(mask_input).to(x.dtype)
# time embeddings
if self.zero_timestep:
t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
if self.zero_timestep:
e = e[:-1]
zero_e0 = e0[-1:]
e0 = e0[:-1]
token_len = x.shape[1]
e0 = torch.cat([
e0.unsqueeze(2),
zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
],
dim=2)
e0 = [e0, self.original_seq_len]
else:
e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
e0 = [e0, 0]
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# grad ckpt args
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs, **kwargs):
if return_dict is not None:
return module(*inputs, **kwargs, return_dict=return_dict)
else:
return module(*inputs, **kwargs)
return custom_forward
if self.use_context_parallel:
# sharded tensors for long context attn
sp_rank = get_rank()
x = torch.chunk(x, get_world_size(), dim=1)
sq_size = [u.shape[1] for u in x]
sq_start_size = sum(sq_size[:sp_rank])
x = x[sp_rank]
# Determine the application range of the time embedding in e0[0] for each sequence:
# - tokens before seg_idx use e0[0][:, :, 0]
# - tokens from seg_idx onward use e0[0][:, :, 1]
sp_size = x.shape[1]
seg_idx = e0[1] - sq_start_size
e0[1] = seg_idx
self.pre_compute_freqs = torch.chunk(
self.pre_compute_freqs, get_world_size(), dim=1)
self.pre_compute_freqs = self.pre_compute_freqs[sp_rank]
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.pre_compute_freqs,
context=context,
context_lens=context_lens)
for idx, block in enumerate(self.blocks):
x = block(x, **kwargs)
x = self.after_transformer_block(idx, x)
# Context Parallel
if self.use_context_parallel:
x = gather_forward(x.contiguous(), dim=1)
# unpatchify
x = x[:, :self.original_seq_len]
# head
x = self.head(x, e)
x = self.unpatchify(x, original_grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
"""
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
r"""
Initialize model parameters using Xavier initialization.
"""
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
for m in self.text_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
for m in self.time_embedding.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=.02)
# init output layer
nn.init.zeros_(self.head.head.weight)
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import Any, Dict, List, Literal, Optional, Union
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import BaseOutput, is_torch_version
from einops import rearrange, repeat
from ..model import flash_attention
from .s2v_utils import rope_precompute
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
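# Shape sketch: for a 1-D `position` of length B, the output is [B, dim] in float64,
# with the first dim // 2 columns holding cosines and the last dim // 2 holding sines
# of position / 10000**(k / half) for k = 0..half-1.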
@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
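# Shape sketch: rope_params(L, d) returns a complex tensor of shape [L, d // 2]
# (torch.polar of unit magnitudes and float64 angles, i.e. complex128), encoding
# e^{i * pos * theta^(-2k/d)} for each position/frequency pair.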
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs, start=None):
n, c = x.size(2), x.size(3) // 2
# split freqs
if type(freqs) is list:
trainable_freqs = freqs[1]
freqs = freqs[0]
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = x.clone()
seq_bucket = [0]
if not type(grid_sizes) is list:
grid_sizes = [grid_sizes]
for g in grid_sizes:
if not type(g) is list:
g = [torch.zeros_like(g), g]
batch_size = g[0].shape[0]
for i in range(batch_size):
if start is None:
f_o, h_o, w_o = g[0][i]
else:
f_o, h_o, w_o = start[i]
f, h, w = g[1][i]
t_f, t_h, t_w = g[2][i]
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
seq_len = int(seq_f * seq_h * seq_w)
if seq_len > 0:
if t_f > 0:
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
t_h / seq_h).item(), (t_w / seq_w).item()
if f_o >= 0:
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
seq_f).astype(int).tolist()
else:
f_sam = np.linspace(-f_o.item(),
(-t_f - f_o).item() + 1,
seq_f).astype(int).tolist()
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
seq_h).astype(int).tolist()
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
seq_w).astype(int).tolist()
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
f_sam].conj()
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
freqs_i = torch.cat([
freqs_0.expand(seq_f, seq_h, seq_w, -1),
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
seq_f, seq_h, seq_w, -1),
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
seq_f, seq_h, seq_w, -1),
],
dim=-1).reshape(seq_len, 1, -1)
elif t_f < 0:
freqs_i = trainable_freqs.unsqueeze(1)
# apply rotary embedding
# precompute multipliers
x_i = torch.view_as_complex(
x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to(
torch.float64).reshape(seq_len, n, -1, 2))
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i
seq_bucket.append(seq_bucket[-1] + seq_len)
return output.float()
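# Note: rope_apply rotates q/k directly (multiplying by the complex factors), while
# rope_precompute in s2v_utils walks the same grid bookkeeping but only stores the
# complex factors, so the model can reuse them across blocks via freqs=self.pre_compute_freqs.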
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
return self._norm(x.float()).type_as(x) * self.weight
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
class LayerNorm(nn.LayerNorm):
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x):
return super().forward(x.float()).type_as(x)
class SelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, seq_lens, grid_sizes, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
x = flash_attention(
q=rope_apply(q, grid_sizes, freqs),
k=rope_apply(k, grid_sizes, freqs),
v=v,
k_lens=seq_lens,
window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
class SwinSelfAttention(SelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
assert b == 1, 'Only support batch_size 1'
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
T, H, W = grid_sizes[0].tolist()
q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
ref_q = q[-1:]
q = q[:-1]
ref_k = repeat(
k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
k = k[:-1]
k = torch.cat([k[:1], k, k[-1:]])
k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1) # (bt) (3hw) n d
ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1)
v = v[:-1]
v = torch.cat([v[:1], v, v[-1:]])
v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)
# q: b (t h w) n d
# k: b (t h w) n d
out = flash_attention(
q=q,
k=k,
v=v,
# k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long),
window_size=self.window_size)
out = torch.cat([out, ref_v[:1]], axis=0)
out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
x = out
# output
x = x.flatten(2)
x = self.o(x)
return x
# Fix the reference frame RoPE to 1, H, W.
# Set the current frame RoPE to 1.
# Set the previous frame RoPE to 0.
class CasualSelfAttention(SelfAttention):
def forward(self, x, seq_lens, grid_sizes, freqs):
shifting = 3
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
assert b == 1, 'Only support batch_size 1'
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
T, H, W = grid_sizes[0].tolist()
q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
ref_q = q[-1:]
q = q[:-1]
grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)
start = [[shifting, 0, 0]] * q.shape[0]
q = rope_apply(q, grid_sizes, freqs, start=start)
ref_k = k[-1:]
grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)
# start = [[shifting, H, W]]
start = [[shifting + 10, 0, 0]]
ref_k = rope_apply(ref_k, grid_sizes, freqs, start)
ref_k = repeat(
ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
k = k[:-1]
k = torch.cat([*([k[:1]] * shifting), k])
cat_k = []
for i in range(shifting):
cat_k.append(k[i:i - shifting])
cat_k.append(k[shifting:])
k = torch.cat(cat_k, dim=1) # (bt) (3hw) n d
grid_sizes = torch.tensor(
[[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)
k = rope_apply(k, grid_sizes, freqs)
k = torch.cat([k, ref_k], dim=1)
ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0]) # t hw n d
v = v[:-1]
v = torch.cat([*([v[:1]] * shifting), v])
cat_v = []
for i in range(shifting):
cat_v.append(v[i:i - shifting])
cat_v.append(v[shifting:])
v = torch.cat(cat_v, dim=1) # (bt) (3hw) n d
v = torch.cat([v, ref_v], dim=1)
# q: b (t h w) n d
# k: b (t h w) n d
outs = []
for i in range(q.shape[0]):
out = flash_attention(
q=q[i:i + 1],
k=k[i:i + 1],
v=v[i:i + 1],
window_size=self.window_size)
outs.append(out)
out = torch.cat(outs, dim=0)
out = torch.cat([out, ref_v[:1]], axis=0)
out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
x = out
# output
x = x.flatten(2)
x = self.o(x)
return x
class MotionerAttentionBlock(nn.Module):
def __init__(self,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
self_attn_block="SelfAttention"):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = LayerNorm(dim, eps)
if self_attn_block == "SelfAttention":
self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm,
eps)
elif self_attn_block == "SwinSelfAttention":
self.self_attn = SwinSelfAttention(dim, num_heads, window_size,
qk_norm, eps)
elif self_attn_block == "CasualSelfAttention":
self.self_attn = CasualSelfAttention(dim, num_heads, window_size,
qk_norm, eps)
self.norm2 = LayerNorm(dim, eps)
self.ffn = nn.Sequential(
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
nn.Linear(ffn_dim, dim))
def forward(
self,
x,
seq_lens,
grid_sizes,
freqs,
):
# self-attention
y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)
x = x + y
y = self.ffn(self.norm2(x).float())
x = x + y
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = LayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
def forward(self, x):
x = self.head(self.norm(x))
return x
class MotionerTransformers(nn.Module, PeftAdapterMixin):
def __init__(
self,
patch_size=(1, 2, 2),
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
self_attn_block="SelfAttention",
motion_token_num=1024,
enable_tsm=False,
motion_stride=4,
expand_ratio=2,
trainable_token_pos_emb=False,
):
super().__init__()
self.patch_size = patch_size
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
self.enable_tsm = enable_tsm
self.motion_stride = motion_stride
self.expand_ratio = expand_ratio
self.sample_c = self.patch_size[0]
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
# blocks
self.blocks = nn.ModuleList([
MotionerAttentionBlock(
dim,
ffn_dim,
num_heads,
window_size,
qk_norm,
cross_attn_norm,
eps,
self_attn_block=self_attn_block) for _ in range(num_layers)
])
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
self.gradient_checkpointing = False
self.motion_side_len = int(math.sqrt(motion_token_num))
assert self.motion_side_len**2 == motion_token_num
self.token = nn.Parameter(
torch.zeros(1, motion_token_num, dim).contiguous())
self.trainable_token_pos_emb = trainable_token_pos_emb
if trainable_token_pos_emb:
x = torch.zeros([1, motion_token_num, num_heads, d])
x[..., ::2] = 1
gride_sizes = [[
torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([1, self.motion_side_len,
self.motion_side_len]).unsqueeze(0).repeat(1, 1),
torch.tensor([1, self.motion_side_len,
self.motion_side_len]).unsqueeze(0).repeat(1, 1),
]]
token_freqs = rope_apply(x, gride_sizes, self.freqs)
token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
token_freqs = token_freqs * 0.01
self.token_freqs = torch.nn.Parameter(token_freqs)
def after_patch_embedding(self, x):
return x
def forward(
self,
x,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
# params
motion_frames = x[0].shape[1]
device = self.patch_embedding.weight.device
freqs = self.freqs
if freqs.device != device:
freqs = freqs.to(device)
if self.trainable_token_pos_emb:
with amp.autocast(dtype=torch.float64):
token_freqs = self.token_freqs.to(torch.float64)
token_freqs = token_freqs / token_freqs.norm(
dim=-1, keepdim=True)
freqs = [freqs, torch.view_as_complex(token_freqs)]
if self.enable_tsm:
sample_idx = [
sample_indices(
u.shape[1],
stride=self.motion_stride,
expand_ratio=self.expand_ratio,
c=self.sample_c) for u in x
]
x = [
torch.flip(torch.flip(u, [1])[:, idx], [1])
for idx, u in zip(sample_idx, x)
]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
x = self.after_patch_embedding(x)
seq_f, seq_h, seq_w = x[0].shape[-3:]
batch_size = len(x)
if not self.enable_tsm:
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
grid_sizes = [[
torch.zeros_like(grid_sizes), grid_sizes, grid_sizes
]]
seq_f = 0
else:
grid_sizes = []
for idx in sample_idx[0][::-1][::self.sample_c]:
tsm_frame_grid_sizes = [[
torch.tensor([idx, 0,
0]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor([idx + 1, seq_h,
seq_w]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor([1, seq_h,
seq_w]).unsqueeze(0).repeat(batch_size, 1),
]]
grid_sizes += tsm_frame_grid_sizes
seq_f = sample_idx[0][-1] + 1
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
x = torch.cat([u for u in x])
batch_size = len(x)
token_grid_sizes = [[
torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor(
[seq_f + 1, self.motion_side_len,
self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),
torch.tensor(
[1 if not self.trainable_token_pos_emb else -1, seq_h,
seq_w]).unsqueeze(0).repeat(batch_size, 1),
] # The third row gives the target range that the RoPE embedding should cover.
]
grid_sizes = grid_sizes + token_grid_sizes
token_unpatch_grid_sizes = torch.stack([
torch.tensor([1, 32, 32], dtype=torch.long)
for b in range(batch_size)
])
token_len = self.token.shape[1]
token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()
seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],
dtype=torch.long)
x = torch.cat([x, token], dim=1)
# arguments
kwargs = dict(
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=freqs,
)
# grad ckpt args
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs, **kwargs):
if return_dict is not None:
return module(*inputs, **kwargs, return_dict=return_dict)
else:
return module(*inputs, **kwargs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = ({
"use_reentrant": False
} if is_torch_version(">=", "1.11.0") else {})
for idx, block in enumerate(self.blocks):
if self.training and self.gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x,
**kwargs,
**ckpt_kwargs,
)
else:
x = block(x, **kwargs)
# head
out = x[:, -token_len:]
return out
def unpatchify(self, x, grid_sizes):
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[:math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
def init_weights(self):
# basic init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# init embeddings
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
class FramePackMotioner(nn.Module):
def __init__(
self,
inner_dim=1024,
num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
zip_frame_buckets=[
1, 2, 16
], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.proj = nn.Conv3d(
16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(
16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(
16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
self.zip_frame_buckets = torch.tensor(
zip_frame_buckets, dtype=torch.long)
self.inner_dim = inner_dim
self.num_heads = num_heads
assert (inner_dim %
num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
d = inner_dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
self.drop_mode = drop_mode
def forward(self, motion_latents, add_last_motion=2):
motion_frames = motion_latents[0].shape[1]
mot = []
mot_remb = []
for m in motion_latents:
lat_height, lat_width = m.shape[2], m.shape[3]
padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height,
lat_width).to(
device=m.device, dtype=m.dtype)
overlap_frame = min(padd_lat.shape[1], m.shape[1])
if overlap_frame > 0:
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
if add_last_motion < 2 and self.drop_mode != "drop":
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.
__len__() -
add_last_motion -
1].sum()
padd_lat[:, -zero_end_frame:] = 0
padd_lat = padd_lat.unsqueeze(0)
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum(
):, :, :].split(
list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1
# patchfy
clean_latents_post = self.proj(clean_latents_post).flatten(
2).transpose(1, 2)
clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(
2).transpose(1, 2)
clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(
2).transpose(1, 2)
if add_last_motion < 2 and self.drop_mode == "drop":
clean_latents_post = clean_latents_post[:, :
0] if add_last_motion < 2 else clean_latents_post
clean_latents_2x = clean_latents_2x[:, :
0] if add_last_motion < 1 else clean_latents_2x
motion_lat = torch.cat(
[clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
# rope
start_time_id = -(self.zip_frame_buckets[:1].sum())
end_time_id = start_time_id + self.zip_frame_buckets[0]
grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
[
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
]
start_time_id = -(self.zip_frame_buckets[:2].sum())
end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
[
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
]
start_time_id = -(self.zip_frame_buckets[:3].sum())
end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
grid_sizes_4x = [[
torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 8,
lat_width // 8]).unsqueeze(0).repeat(1, 1),
torch.tensor([
self.zip_frame_buckets[2], lat_height // 2, lat_width // 2
]).unsqueeze(0).repeat(1, 1),
]]
grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
motion_rope_emb = rope_precompute(
motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads,
self.inner_dim // self.num_heads),
grid_sizes,
self.freqs,
start=None)
mot.append(motion_lat)
mot_remb.append(motion_rope_emb)
return mot, mot_remb
def sample_indices(N, stride, expand_ratio, c):
indices = []
current_start = 0
while current_start < N:
bucket_width = int(stride * (expand_ratio**(len(indices) / stride)))
interval = int(bucket_width / stride * c)
current_end = min(N, current_start + bucket_width)
bucket_samples = []
for i in range(current_end - 1, current_start - 1, -interval):
for near in range(c):
bucket_samples.append(i - near)
indices += bucket_samples[::-1]
current_start += bucket_width
return indices
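# Quick sanity check (a sketch, not part of the original tests): with c=1,
# sample_indices(8, stride=4, expand_ratio=2, c=1) should yield [0, 1, 2, 3, 5, 7] --
# dense sampling in the first bucket, then every other index as the bucket width
# doubles. In MotionerTransformers the input is time-reversed before indexing, so the
# dense end corresponds to the most recent frames.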
if __name__ == '__main__':
device = "cuda"
model = FramePackMotioner(inner_dim=1024).to(device)
batch_size = 2
num_frame, height, width = (28, 32, 32)
single_input = torch.ones([16, num_frame, height, width], device=device)
for i in range(num_frame):
single_input[:, num_frame - 1 - i] *= i
x = [single_input] * batch_size
model.forward(x)
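# With the default zip_frame_buckets=[1, 2, 16] and the 32x32 latents above, each
# returned motion_lat should carry 16*16 + 8*8 + 4*4*4 = 384 tokens of width inner_dim,
# with a matching per-token RoPE tensor in mot_remb (rough count, assuming all three
# packed groups are kept, i.e. add_last_motion=2).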
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
import torch
def rope_precompute(x, grid_sizes, freqs, start=None):
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
# split freqs
if type(freqs) is list:
trainable_freqs = freqs[1]
freqs = freqs[0]
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = torch.view_as_complex(x.detach().reshape(b, s, n, -1,
2).to(torch.float64))
seq_bucket = [0]
if not type(grid_sizes) is list:
grid_sizes = [grid_sizes]
for g in grid_sizes:
if not type(g) is list:
g = [torch.zeros_like(g), g]
batch_size = g[0].shape[0]
for i in range(batch_size):
if start is None:
f_o, h_o, w_o = g[0][i]
else:
f_o, h_o, w_o = start[i]
f, h, w = g[1][i]
t_f, t_h, t_w = g[2][i]
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
seq_len = int(seq_f * seq_h * seq_w)
if seq_len > 0:
if t_f > 0:
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
t_h / seq_h).item(), (t_w / seq_w).item()
# Sample seq_f integer frame positions spanning the target range [f_o, f_o + t_f - 1];
# the f_o < 0 branch below mirrors the range and conjugates the frequencies.
if f_o >= 0:
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
seq_f).astype(int).tolist()
else:
f_sam = np.linspace(-f_o.item(),
(-t_f - f_o).item() + 1,
seq_f).astype(int).tolist()
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
seq_h).astype(int).tolist()
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
seq_w).astype(int).tolist()
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
f_sam].conj()
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
freqs_i = torch.cat([
freqs_0.expand(seq_f, seq_h, seq_w, -1),
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
seq_f, seq_h, seq_w, -1),
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
seq_f, seq_h, seq_w, -1),
],
dim=-1).reshape(seq_len, 1, -1)
elif t_f < 0:
freqs_i = trainable_freqs.unsqueeze(1)
# apply rotary embedding
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
seq_bucket.append(seq_bucket[-1] + seq_len)
return output
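# Grid-size convention (a reading aid, inferred from the callers in the transformer):
# each grid entry is a list of three [f, h, w] rows per sample -- the start offset,
# the end index, and the target range the RoPE embedding should cover; a negative
# target f selects the trainable token frequencies instead of the 3-D rotary table.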
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
__all__ = [
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
]
def fp16_clamp(x):
if x.dtype == torch.float16 and torch.isinf(x).any():
clamp = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp, max=clamp)
return x
def init_weights(m):
if isinstance(m, T5LayerNorm):
nn.init.ones_(m.weight)
elif isinstance(m, T5Model):
nn.init.normal_(m.token_embedding.weight, std=1.0)
elif isinstance(m, T5FeedForward):
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
elif isinstance(m, T5Attention):
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
elif isinstance(m, T5RelativeEmbedding):
nn.init.normal_(
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
class GELU(nn.Module):
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
x = x.type_as(self.weight)
return self.weight * x
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
# layers
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
"""
x: [B, L1, C].
context: [B, L2, C] or None.
mask: [B, L2] or [B, L1, L2] or None.
"""
# check inputs
context = x if context is None else context
b, n, c = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).view(b, -1, n, c)
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)
# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
attn_bias += pos_bias
if mask is not None:
assert mask.ndim in [2, 3]
mask = mask.view(b, 1, 1,
-1) if mask.ndim == 2 else mask.unsqueeze(1)
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
# compute attention (T5 does not use scaling)
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum('bnij,bjnc->binc', attn, v)
# output
x = x.reshape(b, -1, n * c)
x = self.o(x)
x = self.dropout(x)
return x
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
# layers
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x) * self.gate(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class T5SelfAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.ffn(self.norm2(x)))
return x
class T5CrossAttention(nn.Module):
def __init__(self,
dim,
dim_attn,
dim_ffn,
num_heads,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5CrossAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm2 = T5LayerNorm(dim)
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
self.norm3 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False)
def forward(self,
x,
mask=None,
encoder_states=None,
encoder_mask=None,
pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(
x.size(1), x.size(1))
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
x = fp16_clamp(x + self.cross_attn(
self.norm2(x), context=encoder_states, mask=encoder_mask))
x = fp16_clamp(x + self.ffn(self.norm3(x)))
return x
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
def forward(self, lq, lk):
device = self.embedding.weight.device
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
# torch.arange(lq).unsqueeze(1).to(device)
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
torch.arange(lq, device=device).unsqueeze(1)
rel_pos = self._relative_position_bucket(rel_pos)
rel_pos_embeds = self.embedding(rel_pos)
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
0) # [1, N, Lq, Lk]
return rel_pos_embeds.contiguous()
def _relative_position_bucket(self, rel_pos):
# preprocess
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).long() * num_buckets
rel_pos = torch.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = 0
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
# embeddings for small and large positions
max_exact = num_buckets // 2
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
math.log(self.max_dist / max_exact) *
(num_buckets - max_exact)).long()
rel_pos_large = torch.min(
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
return rel_buckets
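# Bucketing sketch: with bidirectional=True the buckets are split by sign
# (num_buckets // 2 per direction); distances below num_buckets // 4 (half of the
# per-direction buckets) each get their own bucket, and larger distances are mapped
# logarithmically up to max_dist, clamped to the last per-direction bucket.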
class T5Encoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Encoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None):
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Decoder(nn.Module):
def __init__(self,
vocab,
dim,
dim_attn,
dim_ffn,
num_heads,
num_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Decoder, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.num_layers = num_layers
self.num_buckets = num_buckets
self.shared_pos = shared_pos
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(
num_buckets, num_heads, bidirectional=False) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
shared_pos, dropout) for _ in range(num_layers)
])
self.norm = T5LayerNorm(dim)
# initialize weights
self.apply(init_weights)
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
b, s = ids.size()
# causal mask
if mask is None:
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
elif mask.ndim == 2:
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
# layers
x = self.token_embedding(ids)
x = self.dropout(x)
e = self.pos_embedding(x.size(1),
x.size(1)) if self.shared_pos else None
for block in self.blocks:
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
x = self.norm(x)
x = self.dropout(x)
return x
class T5Model(nn.Module):
def __init__(self,
vocab_size,
dim,
dim_attn,
dim_ffn,
num_heads,
encoder_layers,
decoder_layers,
num_buckets,
shared_pos=True,
dropout=0.1):
super(T5Model, self).__init__()
self.vocab_size = vocab_size
self.dim = dim
self.dim_attn = dim_attn
self.dim_ffn = dim_ffn
self.num_heads = num_heads
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_buckets = num_buckets
# layers
self.token_embedding = nn.Embedding(vocab_size, dim)
self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, encoder_layers, num_buckets,
shared_pos, dropout)
self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
num_heads, decoder_layers, num_buckets,
shared_pos, dropout)
self.head = nn.Linear(dim, vocab_size, bias=False)
# initialize weights
self.apply(init_weights)
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
x = self.encoder(encoder_ids, encoder_mask)
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
x = self.head(x)
return x
def _t5(name,
encoder_only=False,
decoder_only=False,
return_tokenizer=False,
tokenizer_kwargs={},
dtype=torch.float32,
device='cpu',
**kwargs):
# sanity check
assert not (encoder_only and decoder_only)
# params
if encoder_only:
model_cls = T5Encoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('encoder_layers')
_ = kwargs.pop('decoder_layers')
elif decoder_only:
model_cls = T5Decoder
kwargs['vocab'] = kwargs.pop('vocab_size')
kwargs['num_layers'] = kwargs.pop('decoder_layers')
_ = kwargs.pop('encoder_layers')
else:
model_cls = T5Model
# init model
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
from .tokenizers import HuggingfaceTokenizer
tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
return model, tokenizer
else:
return model
def umt5_xxl(**kwargs):
cfg = dict(
vocab_size=256384,
dim=4096,
dim_attn=4096,
dim_ffn=10240,
num_heads=64,
encoder_layers=24,
decoder_layers=24,
num_buckets=32,
shared_pos=False,
dropout=0.1)
cfg.update(**kwargs)
return _t5('umt5-xxl', **cfg)
class T5EncoderModel:
def __init__(
self,
text_len,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
else:
self.model.to(self.device)
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
name=tokenizer_path, seq_len=text_len, clean='whitespace')
def __call__(self, texts, device):
ids, mask = self.tokenizer(
texts, return_mask=True, add_special_tokens=True)
ids = ids.to(device)
mask = mask.to(device)
seq_lens = mask.gt(0).sum(dim=1).long()
context = self.model(ids, mask)
return [u[:v] for u, v in zip(context, seq_lens)]
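# Minimal usage sketch (paths below are placeholders, not shipped with this file):
#   t5 = T5EncoderModel(
#       text_len=512,
#       checkpoint_path='path/to/umt5-xxl-enc.pth',
#       tokenizer_path='google/umt5-xxl')
#   context = t5(['a person talking to the camera'], device=torch.device('cuda'))
#   # -> a list with one [len_i, 4096] tensor per prompt, trimmed to its real length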
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ['HuggingfaceTokenizer']
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace('_', ' ')
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(str.maketrans('', '', string.punctuation))
for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans('', '', string.punctuation))
text = text.lower()
text = re.sub(r'\s+', ' ', text)
return text.strip()
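# Example of the cleaning helpers (a sketch): canonicalize('Hello,  World_Test!')
# -> 'hello world test' (underscores become spaces, punctuation is stripped,
# text is lower-cased and whitespace is collapsed).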
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop('return_mask', False)
# arguments
_kwargs = {'return_tensors': 'pt'}
if self.seq_len is not None:
_kwargs.update({
'padding': 'max_length',
'truncation': True,
'max_length': self.seq_len
})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == 'whitespace':
text = whitespace_clean(basic_clean(text))
elif self.clean == 'lower':
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
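# Usage sketch (the tokenizer name is whatever `tokenizer_path` points to, e.g. the
# 'google/umt5-xxl' used by the T5 text encoder):
#   tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
#   ids, mask = tok(['a person talking to the camera'], return_mask=True,
#                   add_special_tokens=True)
#   # ids, mask: LongTensors of shape [1, 512], padded/truncated to seq_len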