Commit 1336a33d authored by zzg_666's avatar zzg_666
Browse files

wan2.2

parents
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
__all__ = [
'Wan2_1_VAE',
]
CACHE_T = 2
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## 对encode输入的x,按时间拆分为1、4、4、4....
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
"""
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
"""
# params
cfg = dict(
dim=96,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0)
cfg.update(**kwargs)
# init model
with torch.device('meta'):
model = WanVAE_(**cfg)
# load checkpoint
logging.info(f'loading {pretrained_path}')
model.load_state_dict(
torch.load(pretrained_path, map_location=device), assign=True)
return model
class Wan2_1_VAE:
def __init__(self,
z_dim=16,
vae_pth='cache/vae_step_411000.pth',
dtype=torch.float,
device="cuda"):
self.dtype = dtype
self.device = device
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.mean = torch.tensor(mean, dtype=dtype, device=device)
self.std = torch.tensor(std, dtype=dtype, device=device)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]
def decode(self, zs):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
__all__ = [
"Wan2_2_VAE",
]
CACHE_T = 2
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
2 * self.padding[0],
0,
)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
self.scale * self.gamma + self.bias)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in (
"none",
"upsample2d",
"upsample3d",
"downsample2d",
"downsample3d",
)
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim, 3, padding=1),
# nn.Conv2d(dim, dim//2, 3, padding=1)
)
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] != "Rep"):
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
feat_cache[idx] == "Rep"):
cache_x = torch.cat(
[
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight.detach().clone()
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
conv.weight = nn.Parameter(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data.detach().clone()
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight = nn.Parameter(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim else nn.Identity())
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = (
self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(3, dim=-1))
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
return x + identity
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1:, :, :]
return x
class Down_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_downsample=False,
down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
downsamples = []
for _ in range(mult):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
self.downsamples = nn.Sequential(*downsamples)
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class Up_ResidualBlock(nn.Module):
def __init__(self,
in_dim,
out_dim,
dropout,
mult,
temperal_upsample=False,
up_flag=False):
super().__init__()
# Shortcut path with upsample
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path with residual blocks and upsample
upsamples = []
for _ in range(mult):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
# Add the final upsample block
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = nn.Sequential(*upsamples)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
else:
return x_main
class Encoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_down_flag = (
temperal_downsample[i]
if i < len(temperal_downsample) else False)
downsamples.append(
Down_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks,
temperal_downsample=t_down_flag,
down_flag=i != len(dim_mult) - 1,
))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout),
)
# # output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout),
)
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up_flag = temperal_upsample[i] if i < len(
temperal_upsample) else False
upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
dropout=dropout,
mult=num_res_blocks + 1,
temperal_upsample=t_up_flag,
up_flag=i != len(dim_mult) - 1,
))
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False),
nn.SiLU(),
CausalConv3d(out_dim, 12, 3, padding=1),
)
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx, first_chunk)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device),
cache_x,
],
dim=2,
)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(
self,
dim=160,
dec_dim=256,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(
dim,
z_dim * 2,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_downsample,
dropout,
)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dec_dim,
z_dim,
dim_mult,
num_res_blocks,
attn_scales,
self.temperal_upsample,
dropout,
)
def forward(self, x, scale=[0, 1]):
mu = self.encode(x, scale)
x_recon = self.decode(mu, scale)
return x_recon, mu
def encode(self, x, scale):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale):
self.clear_cache()
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
first_chunk=True,
)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx,
)
out = torch.cat([out, out_], 2)
out = unpatchify(out, patch_size=2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
# params
cfg = dict(
dim=dim,
z_dim=z_dim,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, True],
dropout=0.0,
)
cfg.update(**kwargs)
# init model
with torch.device("meta"):
model = WanVAE_(**cfg)
# load checkpoint
logging.info(f"loading {pretrained_path}")
model.load_state_dict(
torch.load(pretrained_path, map_location=device), assign=True)
return model
class Wan2_2_VAE:
def __init__(
self,
z_dim=48,
c_dim=160,
vae_pth=None,
dim_mult=[1, 2, 4, 4],
temperal_downsample=[False, True, True],
dtype=torch.float,
device="cuda",
):
self.dtype = dtype
self.device = device
mean = torch.tensor(
[
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
],
dtype=dtype,
device=device,
)
std = torch.tensor(
[
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
],
dtype=dtype,
device=device,
)
self.scale = [mean, 1.0 / std]
# init model
self.model = (
_video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
dim=c_dim,
dim_mult=dim_mult,
temperal_downsample=temperal_downsample,
).eval().requires_grad_(False).to(device))
def encode(self, videos):
try:
if not isinstance(videos, list):
raise TypeError("videos should be a list")
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0),
self.scale).float().squeeze(0)
for u in videos
]
except TypeError as e:
logging.info(e)
return None
def decode(self, zs):
try:
if not isinstance(zs, list):
raise TypeError("zs should be a list")
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1,
1).squeeze(0)
for u in zs
]
except TypeError as e:
logging.info(e)
return None
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from decord import VideoReader
from PIL import Image
from safetensors import safe_open
from torchvision import transforms
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.s2v.audio_encoder import AudioEncoder
from .modules.s2v.model_s2v import WanModel_S2V, sp_attn_forward_s2v
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
def load_safetensors(path):
tensors = {}
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors
class WanS2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
):
r"""
Initializes the image-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_sp (`bool`, *optional*, defaults to False):
Enable distribution strategy of sequence parallel.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.init_on_cpu = init_on_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
if t5_fsdp or dit_fsdp or use_sp:
self.init_on_cpu = False
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
)
self.vae = Wan2_1_VAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
if not dit_fsdp:
self.noise_model = WanModel_S2V.from_pretrained(
checkpoint_dir,
torch_dtype=self.param_dtype,
device_map=self.device)
else:
self.noise_model = WanModel_S2V.from_pretrained(
checkpoint_dir, torch_dtype=self.param_dtype)
self.noise_model = self._configure_model(
model=self.noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
self.audio_encoder = AudioEncoder(
model_id=os.path.join(checkpoint_dir,
"wav2vec2-large-xlsr-53-english"))
if use_sp:
self.sp_size = get_world_size()
else:
self.sp_size = 1
self.sample_neg_prompt = config.sample_neg_prompt
self.motion_frames = config.transformer.motion_frames
self.drop_first_motion = config.drop_first_motion
self.fps = config.sample_fps
self.audio_sample_m = 0
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype):
"""
Configures a model object. This includes setting evaluation modes,
applying distributed parallel strategy, and handling device placement.
Args:
model (torch.nn.Module):
The model instance to configure.
use_sp (`bool`):
Enable distribution strategy of sequence parallel.
dit_fsdp (`bool`):
Enable FSDP sharding for DiT model.
shard_fn (callable):
The function to apply FSDP sharding.
convert_model_dtype (`bool`):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
Returns:
torch.nn.Module:
The configured model.
"""
model.eval().requires_grad_(False)
if use_sp:
for block in model.blocks:
block.self_attn.forward = types.MethodType(
sp_attn_forward_s2v, block.self_attn)
model.use_context_parallel = True
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
model = shard_fn(model)
else:
if convert_model_dtype:
model.to(self.param_dtype)
if not self.init_on_cpu:
model.to(self.device)
return model
def get_size_less_than_area(self,
height,
width,
target_area=1024 * 704,
divisor=64):
if height * width <= target_area:
# If the original image area is already less than or equal to the target,
# no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
max_upper_area = target_area
min_scale = 0.1
max_scale = 1.0
else:
# Resize to fit within the target area and then pad to multiples of `divisor`
max_upper_area = target_area # Maximum allowed total pixel count after padding
d = divisor - 1
b = d * (height + width)
a = height * width
c = d**2 - max_upper_area
# Calculate scale boundaries using quadratic equation
min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (
2 * a) # Scale when maximum padding is applied
max_scale = math.sqrt(max_upper_area /
(height * width)) # Scale without any padding
# We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
# Use binary search-like iteration to find this scale
find_it = False
for i in range(100):
scale = max_scale - (max_scale - min_scale) * i / 100
new_height, new_width = int(height * scale), int(width * scale)
# Pad to make dimensions divisible by 64
pad_height = (64 - new_height % 64) % 64
pad_width = (64 - new_width % 64) % 64
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = pad_width // 2
pad_right = pad_width - pad_left
padded_height, padded_width = new_height + pad_height, new_width + pad_width
if padded_height * padded_width <= max_upper_area:
find_it = True
break
if find_it:
return padded_height, padded_width
else:
# Fallback: calculate target dimensions based on aspect ratio and divisor alignment
aspect_ratio = width / height
target_width = int(
(target_area * aspect_ratio)**0.5 // divisor * divisor)
target_height = int(
(target_area / aspect_ratio)**0.5 // divisor * divisor)
# Ensure the result is not larger than the original resolution
if target_width >= width or target_height >= height:
target_width = int(width // divisor * divisor)
target_height = int(height // divisor * divisor)
return target_height, target_width
def prepare_default_cond_input(self,
map_shape=[3, 12, 64, 64],
motion_frames=5,
lat_motion_frames=2,
enable_mano=False,
enable_kp=False,
enable_pose=False):
default_value = [1.0, -1.0, -1.0]
cond_enable = [enable_mano, enable_kp, enable_pose]
cond = []
for d, c in zip(default_value, cond_enable):
if c:
map_value = torch.ones(
map_shape, dtype=self.param_dtype, device=self.device) * d
cond_lat = torch.cat([
map_value[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1),
map_value
],
dim=2)
cond_lat = torch.stack(
self.vae.encode(cond_lat.to(
self.param_dtype)))[:, :, lat_motion_frames:].to(
self.param_dtype)
cond.append(cond_lat)
if len(cond) >= 1:
cond = torch.cat(cond, dim=1)
else:
cond = None
return cond
def encode_audio(self, audio_path, infer_frames):
z = self.audio_encoder.extract_audio_feat(
audio_path, return_all_layers=True)
audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps(
z, fps=self.fps, batch_frames=infer_frames, m=self.audio_sample_m)
audio_embed_bucket = audio_embed_bucket.to(self.device,
self.param_dtype)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
return audio_embed_bucket, num_repeat
def read_last_n_frames(self,
video_path,
n_frames,
target_fps=16,
reverse=False):
"""
Read the last `n_frames` from a video at the specified frame rate.
Parameters:
video_path (str): Path to the video file.
n_frames (int): Number of frames to read.
target_fps (int, optional): Target sampling frame rate. Defaults to 16.
reverse (bool, optional): Whether to read frames in reverse order.
If True, reads the first `n_frames` instead of the last ones.
Returns:
np.ndarray: A NumPy array of shape [n_frames, H, W, 3], representing the sampled video frames.
"""
vr = VideoReader(video_path)
original_fps = vr.get_avg_fps()
total_frames = len(vr)
interval = max(1, round(original_fps / target_fps))
required_span = (n_frames - 1) * interval
start_frame = max(0, total_frames - required_span -
1) if not reverse else 0
sampled_indices = []
for i in range(n_frames):
indice = start_frame + i * interval
if indice >= total_frames:
break
else:
sampled_indices.append(indice)
return vr.get_batch(sampled_indices).asnumpy()
def load_pose_cond(self, pose_video, num_repeat, infer_frames, size):
HEIGHT, WIDTH = size
if not pose_video is None:
pose_seq = self.read_last_n_frames(
pose_video,
n_frames=infer_frames * num_repeat,
target_fps=self.fps,
reverse=True)
resize_opreat = transforms.Resize(min(HEIGHT, WIDTH))
crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH))
tensor_trans = transforms.ToTensor()
cond_tensor = torch.from_numpy(pose_seq)
cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0
cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute(
1, 0, 2, 3).unsqueeze(0)
padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2]
cond_tensor = torch.cat([
cond_tensor,
- torch.ones([1, 3, padding_frame_num, HEIGHT, WIDTH])
],
dim=2)
cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2)
else:
cond_tensors = [-torch.ones([1, 3, infer_frames, HEIGHT, WIDTH])]
COND = []
for r in range(len(cond_tensors)):
cond = cond_tensors[r]
cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond],
dim=2)
cond_lat = torch.stack(
self.vae.encode(
cond.to(dtype=self.param_dtype,
device=self.device)))[:, :,
1:].cpu() # for mem save
COND.append(cond_lat)
return COND
def get_gen_size(self, size, max_area, ref_image_path, pre_video_path):
if not size is None:
HEIGHT, WIDTH = size
else:
if pre_video_path:
ref_image = self.read_last_n_frames(
pre_video_path, n_frames=1)[0]
else:
ref_image = np.array(Image.open(ref_image_path).convert('RGB'))
HEIGHT, WIDTH = ref_image.shape[:2]
HEIGHT, WIDTH = self.get_size_less_than_area(
HEIGHT, WIDTH, target_area=max_area)
return (HEIGHT, WIDTH)
def generate(
self,
input_prompt,
ref_image_path,
audio_path,
enable_tts,
tts_prompt_audio,
tts_prompt_text,
tts_text,
num_repeat=1,
pose_video=None,
max_area=720 * 1280,
infer_frames=80,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
init_first_frame=False,
):
r"""
Generates video frames from input image and text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
ref_image_path ('str'):
Input image path
audio_path ('str'):
Audio for video driven
num_repeat ('int'):
Number of clips to generate; will be automatically adjusted based on the audio length
pose_video ('str'):
If provided, uses a sequence of poses to drive the generated video
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
infer_frames (`int`, *optional*, defaults to 80):
How many frames to generate per clips. The number should be 4n
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
If tuple, the first guide_scale will be used for low noise model and
the second guide_scale will be used for high noise model.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
init_first_frame (`bool`, *optional*, defaults to False):
Whether to use the reference image as the first frame (i.e., standard image-to-video generation)
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from max_area)
- W: Frame width from max_area)
"""
# preprocess
size = self.get_gen_size(
size=None,
max_area=max_area,
ref_image_path=ref_image_path,
pre_video_path=None)
HEIGHT, WIDTH = size
channel = 3
resize_opreat = transforms.Resize(min(HEIGHT, WIDTH))
crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH))
tensor_trans = transforms.ToTensor()
ref_image = None
motion_latents = None
if ref_image is None:
ref_image = np.array(Image.open(ref_image_path).convert('RGB'))
if motion_latents is None:
motion_latents = torch.zeros(
[1, channel, self.motion_frames, HEIGHT, WIDTH],
dtype=self.param_dtype,
device=self.device)
# extract audio emb
if enable_tts is True:
audio_path = self.tts(tts_prompt_audio, tts_prompt_text, tts_text)
audio_emb, nr = self.encode_audio(audio_path, infer_frames=infer_frames)
if num_repeat is None or num_repeat > nr:
num_repeat = nr
lat_motion_frames = (self.motion_frames + 3) // 4
model_pic = crop_opreat(resize_opreat(Image.fromarray(ref_image)))
ref_pixel_values = tensor_trans(model_pic)
ref_pixel_values = ref_pixel_values.unsqueeze(1).unsqueeze(
0) * 2 - 1.0 # b c 1 h w
ref_pixel_values = ref_pixel_values.to(
dtype=self.vae.dtype, device=self.vae.device)
ref_latents = torch.stack(self.vae.encode(ref_pixel_values))
# encode the motion latents
videos_last_frames = motion_latents.detach()
drop_first_motion = self.drop_first_motion
if init_first_frame:
drop_first_motion = False
motion_latents[:, :, -6:] = ref_pixel_values
motion_latents = torch.stack(self.vae.encode(motion_latents))
# get pose cond input if need
COND = self.load_pose_cond(
pose_video=pose_video,
num_repeat=num_repeat,
infer_frames=infer_frames,
size=size)
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
out = []
# evaluation mode
with (
torch.amp.autocast('cuda', dtype=self.param_dtype),
torch.no_grad(),
):
for r in range(num_repeat):
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed + r)
lat_target_frames = (infer_frames + 3 + self.motion_frames
) // 4 - lat_motion_frames
target_shape = [lat_target_frames, HEIGHT // 8, WIDTH // 8]
noise = [
torch.randn(
16,
target_shape[0],
target_shape[1],
target_shape[2],
dtype=self.param_dtype,
device=self.device,
generator=seed_g)
]
max_seq_len = np.prod(target_shape) // 4
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
latents = deepcopy(noise)
with torch.no_grad():
left_idx = r * infer_frames
right_idx = r * infer_frames + infer_frames
cond_latents = COND[r] if pose_video else COND[0] * 0
cond_latents = cond_latents.to(
dtype=self.param_dtype, device=self.device)
audio_input = audio_emb[..., left_idx:right_idx]
input_motion_latents = motion_latents.clone()
arg_c = {
'context': context[0:1],
'seq_len': max_seq_len,
'cond_states': cond_latents,
"motion_latents": input_motion_latents,
'ref_latents': ref_latents,
"audio_input": audio_input,
"motion_frames": [self.motion_frames, lat_motion_frames],
"drop_motion_frames": drop_first_motion and r == 0,
}
if guide_scale > 1:
arg_null = {
'context': context_null[0:1],
'seq_len': max_seq_len,
'cond_states': cond_latents,
"motion_latents": input_motion_latents,
'ref_latents': ref_latents,
"audio_input": 0.0 * audio_input,
"motion_frames": [
self.motion_frames, lat_motion_frames
],
"drop_motion_frames": drop_first_motion and r == 0,
}
if offload_model or self.init_on_cpu:
self.noise_model.to(self.device)
torch.cuda.empty_cache()
for i, t in enumerate(tqdm(timesteps)):
latent_model_input = latents[0:1]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.noise_model(
latent_model_input, t=timestep, **arg_c)
if guide_scale > 1:
noise_pred_uncond = self.noise_model(
latent_model_input, t=timestep, **arg_null)
noise_pred = [
u + guide_scale * (c - u)
for c, u in zip(noise_pred_cond, noise_pred_uncond)
]
else:
noise_pred = noise_pred_cond
temp_x0 = sample_scheduler.step(
noise_pred[0].unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents[0] = temp_x0.squeeze(0)
if offload_model:
self.noise_model.cpu()
torch.cuda.synchronize()
torch.cuda.empty_cache()
latents = torch.stack(latents)
if not (drop_first_motion and r == 0):
decode_latents = torch.cat([motion_latents, latents], dim=2)
else:
decode_latents = torch.cat([ref_latents, latents], dim=2)
image = torch.stack(self.vae.decode(decode_latents))
image = image[:, :, -(infer_frames):]
if (drop_first_motion and r == 0):
image = image[:, :, 3:]
overlap_frames_num = min(self.motion_frames, image.shape[2])
videos_last_frames = torch.cat([
videos_last_frames[:, :, overlap_frames_num:],
image[:, :, -overlap_frames_num:]
],
dim=2)
videos_last_frames = videos_last_frames.to(
dtype=motion_latents.dtype, device=motion_latents.device)
motion_latents = torch.stack(
self.vae.encode(videos_last_frames))
out.append(image.cpu())
videos = torch.cat(out, dim=2)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
def tts(self, tts_prompt_audio, tts_prompt_text, tts_text):
if not hasattr(self, 'cosyvoice'):
self.load_tts()
speech_list = []
from cosyvoice.utils.file_utils import load_wav
import torchaudio
prompt_speech_16k = load_wav(tts_prompt_audio, 16000)
if tts_prompt_text is not None:
for i in self.cosyvoice.inference_zero_shot(tts_text, tts_prompt_text, prompt_speech_16k):
speech_list.append(i['tts_speech'])
else:
for i in self.cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k):
speech_list.append(i['tts_speech'])
torchaudio.save('tts.wav', torch.concat(speech_list, dim=1), self.cosyvoice.sample_rate)
return 'tts.wav'
def load_tts(self):
if not os.path.exists('CosyVoice'):
from wan.utils.utils import download_cosyvoice_repo
download_cosyvoice_repo('CosyVoice')
if not os.path.exists('CosyVoice2-0.5B'):
from wan.utils.utils import download_cosyvoice_model
download_cosyvoice_model('CosyVoice2-0.5B', 'CosyVoice2-0.5B')
sys.path.append('CosyVoice')
sys.path.append('CosyVoice/third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice2
self.cosyvoice = CosyVoice2('CosyVoice2-0.5B')
\ No newline at end of file
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanT2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_sp (`bool`, *optional*, defaults to False):
Enable distribution strategy of sequence parallel.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.init_on_cpu = init_on_cpu
self.num_train_timesteps = config.num_train_timesteps
self.boundary = config.boundary
self.param_dtype = config.param_dtype
if t5_fsdp or dit_fsdp or use_sp:
self.init_on_cpu = False
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = Wan2_1_VAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.low_noise_model = WanModel.from_pretrained(
checkpoint_dir, subfolder=config.low_noise_checkpoint)
self.low_noise_model = self._configure_model(
model=self.low_noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
self.high_noise_model = WanModel.from_pretrained(
checkpoint_dir, subfolder=config.high_noise_checkpoint)
self.high_noise_model = self._configure_model(
model=self.high_noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
if use_sp:
self.sp_size = get_world_size()
else:
self.sp_size = 1
self.sample_neg_prompt = config.sample_neg_prompt
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype):
"""
Configures a model object. This includes setting evaluation modes,
applying distributed parallel strategy, and handling device placement.
Args:
model (torch.nn.Module):
The model instance to configure.
use_sp (`bool`):
Enable distribution strategy of sequence parallel.
dit_fsdp (`bool`):
Enable FSDP sharding for DiT model.
shard_fn (callable):
The function to apply FSDP sharding.
convert_model_dtype (`bool`):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
Returns:
torch.nn.Module:
The configured model.
"""
model.eval().requires_grad_(False)
if use_sp:
for block in model.blocks:
block.self_attn.forward = types.MethodType(
sp_attn_forward, block.self_attn)
model.forward = types.MethodType(sp_dit_forward, model)
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
model = shard_fn(model)
else:
if convert_model_dtype:
model.to(self.param_dtype)
if not self.init_on_cpu:
model.to(self.device)
return model
def _prepare_model_for_timestep(self, t, boundary, offload_model):
r"""
Prepares and returns the required model for the current timestep.
Args:
t (torch.Tensor):
current timestep.
boundary (`int`):
The timestep threshold. If `t` is at or above this value,
the `high_noise_model` is considered as the required model.
offload_model (`bool`):
A flag intended to control the offloading behavior.
Returns:
torch.nn.Module:
The active model on the target device for the current timestep.
"""
if t.item() >= boundary:
required_model_name = 'high_noise_model'
offload_model_name = 'low_noise_model'
else:
required_model_name = 'low_noise_model'
offload_model_name = 'high_noise_model'
if offload_model or self.init_on_cpu:
if next(getattr(
self,
offload_model_name).parameters()).device.type == 'cuda':
getattr(self, offload_model_name).to('cpu')
if next(getattr(
self,
required_model_name).parameters()).device.type == 'cpu':
getattr(self, required_model_name).to(self.device)
return getattr(self, required_model_name)
def generate(self,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (`tuple[int]`, *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
If tuple, the first guide_scale will be used for low noise model and
the second guide_scale will be used for high noise model.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width from size)
"""
# preprocess
guide_scale = (guide_scale, guide_scale) if isinstance(
guide_scale, float) else guide_scale
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
@contextmanager
def noop_no_sync():
yield
no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
noop_no_sync)
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
noop_no_sync)
# evaluation mode
with (
torch.amp.autocast('cuda', dtype=self.param_dtype),
torch.no_grad(),
no_sync_low_noise(),
no_sync_high_noise(),
):
boundary = self.boundary * self.num_train_timesteps
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
model = self._prepare_model_for_timestep(
t, boundary, offload_model)
sample_guide_scale = guide_scale[1] if t.item(
) >= boundary else guide_scale[0]
noise_pred_cond = model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = model(
latent_model_input, t=timestep, **arg_null)[0]
noise_pred = noise_pred_uncond + sample_guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.low_noise_model.cpu()
self.high_noise_model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_2 import Wan2_2_VAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .utils.utils import best_output_size, masks_like
class WanTI2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_sp (`bool`, *optional*, defaults to False):
Enable distribution strategy of sequence parallel.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.init_on_cpu = init_on_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
if t5_fsdp or dit_fsdp or use_sp:
self.init_on_cpu = False
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = Wan2_2_VAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
self.model = self._configure_model(
model=self.model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
if use_sp:
self.sp_size = get_world_size()
else:
self.sp_size = 1
self.sample_neg_prompt = config.sample_neg_prompt
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype):
"""
Configures a model object. This includes setting evaluation modes,
applying distributed parallel strategy, and handling device placement.
Args:
model (torch.nn.Module):
The model instance to configure.
use_sp (`bool`):
Enable distribution strategy of sequence parallel.
dit_fsdp (`bool`):
Enable FSDP sharding for DiT model.
shard_fn (callable):
The function to apply FSDP sharding.
convert_model_dtype (`bool`):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
Returns:
torch.nn.Module:
The configured model.
"""
model.eval().requires_grad_(False)
if use_sp:
for block in model.blocks:
block.self_attn.forward = types.MethodType(
sp_attn_forward, block.self_attn)
model.forward = types.MethodType(sp_dit_forward, model)
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
model = shard_fn(model)
else:
if convert_model_dtype:
model.to(self.param_dtype)
if not self.init_on_cpu:
model.to(self.device)
return model
def generate(self,
input_prompt,
img=None,
size=(1280, 704),
max_area=704 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
img (PIL.Image.Image):
Input image tensor. Shape: [3, H, W]
size (`tuple[int]`, *optional*, defaults to (1280,704)):
Controls video resolution, (width,height).
max_area (`int`, *optional*, defaults to 704*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width from size)
"""
# i2v
if img is not None:
return self.i2v(
input_prompt=input_prompt,
img=img,
max_area=max_area,
frame_num=frame_num,
shift=shift,
sample_solver=sample_solver,
sampling_steps=sampling_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=offload_model)
# t2v
return self.t2v(
input_prompt=input_prompt,
size=size,
frame_num=frame_num,
shift=shift,
sample_solver=sample_solver,
sampling_steps=sampling_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=offload_model)
def t2v(self,
input_prompt,
size=(1280, 704),
frame_num=121,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (`tuple[int]`, *optional*, defaults to (1280,704)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 121):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width from size)
"""
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with (
torch.amp.autocast('cuda', dtype=self.param_dtype),
torch.no_grad(),
no_sync(),
):
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
mask1, mask2 = masks_like(noise, zero=False)
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if offload_model or self.init_on_cpu:
self.model.to(self.device)
torch.cuda.empty_cache()
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
temp_ts = torch.cat([
temp_ts,
temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
])
timestep = temp_ts.unsqueeze(0)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.synchronize()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
def i2v(self,
input_prompt,
img,
max_area=704 * 1280,
frame_num=121,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from input image and text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image tensor. Shape: [3, H, W]
max_area (`int`, *optional*, defaults to 704*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 121):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (121)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
# preprocess
ih, iw = img.height, img.width
dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[
2] * self.vae_stride[2]
ow, oh = best_output_size(iw, ih, dw, dh, max_area)
scale = max(ow / iw, oh / ih)
img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
# center-crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
assert img.width == ow and img.height == oh
# to tensor
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
F = frame_num
seq_len = ((F - 1) // self.vae_stride[0] + 1) * (
oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (
self.patch_size[1] * self.patch_size[2])
seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
oh // self.vae_stride[1],
ow // self.vae_stride[2],
dtype=torch.float32,
generator=seed_g,
device=self.device)
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
z = self.vae.encode([img])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with (
torch.amp.autocast('cuda', dtype=self.param_dtype),
torch.no_grad(),
no_sync(),
):
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
mask1, mask2 = masks_like([noise], zero=True)
latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
arg_c = {
'context': [context[0]],
'seq_len': seq_len,
}
arg_null = {
'context': context_null,
'seq_len': seq_len,
}
if offload_model or self.init_on_cpu:
self.model.to(self.device)
torch.cuda.empty_cache()
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
temp_ts = torch.cat([
temp_ts,
temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
])
timestep = temp_ts.unsqueeze(0)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
x0 = [latent]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.synchronize()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent, x0
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
]
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
# Convert dpm solver for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import inspect
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor
if is_scipy_available():
pass
def get_sampling_sigmas(sampling_steps, shift):
sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
sigma = (shift * sigma / (1 + (shift - 1) * sigma))
return sigma
def retrieve_timesteps(
scheduler,
num_inference_steps=None,
device=None,
timesteps=None,
sigmas=None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None:
accepts_timesteps = "timesteps" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
and used in multistep updates.
prediction_type (`str`, defaults to "flow_prediction"):
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
the flow of the diffusion process.
shift (`float`, *optional*, defaults to 1.0):
A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
process.
use_dynamic_shifting (`bool`, defaults to `False`):
Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
applied on the fly.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
saturation and improve photorealism.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
paper, and the `dpmsolver++` type implements the algorithms in the
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, *optional*, defaults to "zero"):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
variance_type (`str`, *optional*):
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
contains the predicted Gaussian variance.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting=False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
invert_sigmas: bool = False,
):
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
deprecation_message)
# settings for DPM-Solver
if algorithm_type not in [
"dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
raise NotImplementedError(
f"{algorithm_type} is not implemented for {self.__class__}")
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}")
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"
] and final_sigmas_type == "zero":
raise ValueError(
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
)
# setable values
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps,
num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 +
(shift - 1) * sigmas) # pyright: ignore
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
# self.sigmas = self.sigmas.to(
# "cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
Total number of the spacing of the time steps.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(
" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min,
num_inference_steps +
1).copy()[:-1] # pyright: ignore
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
else:
if shift is None:
shift = self.config.shift
sigmas = shift * sigmas / (1 +
(shift - 1) * sigmas) # pyright: ignore
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) /
self.alphas_cumprod[0])**0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = np.concatenate([sigmas, [sigma_last]
]).astype(np.float32) # pyright: ignore
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self._step_index = None
self._begin_index = None
# self.sigmas = self.sigmas.to(
# "cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float(
) # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(
abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(
1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(
sample, -s, s
) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
<Tip>
The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
prediction and data prediction models.
</Tip>
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(
"missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
)
if self.config.thresholding:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
One step for the first-order DPMSolver (equivalent to DDIM).
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
"prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(
" missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
self.step_index] # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
h = lambda_t - lambda_s
if self.config.algorithm_type == "dpmsolver++":
x_t = (sigma_t /
sigma_s) * sample - (alpha_t *
(torch.exp(-h) - 1.0)) * model_output
elif self.config.algorithm_type == "dpmsolver":
x_t = (alpha_t /
alpha_s) * sample - (sigma_t *
(torch.exp(h) - 1.0)) * model_output
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
(alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
x_t = ((alpha_t / alpha_s) * sample - 2.0 *
(sigma_t * (torch.exp(h) - 1.0)) * model_output +
sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
return x_t # pyright: ignore
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
One step for the second-order multistep DPMSolver.
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop(
"timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
"prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(
" missing `sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1 = (
self.sigmas[self.step_index + 1], # pyright: ignore
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1], # pyright: ignore
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
m0, m1 = model_output_list[-1], model_output_list[-2]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = ((sigma_t / sigma_s0) * sample -
(alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
(alpha_t * (torch.exp(-h) - 1.0)) * D1)
elif self.config.solver_type == "heun":
x_t = ((sigma_t / sigma_s0) * sample -
(alpha_t * (torch.exp(-h) - 1.0)) * D0 +
(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = ((alpha_t / alpha_s0) * sample -
(sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
(sigma_t * (torch.exp(h) - 1.0)) * D1)
elif self.config.solver_type == "heun":
x_t = ((alpha_t / alpha_s0) * sample -
(sigma_t * (torch.exp(h) - 1.0)) * D0 -
(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
elif self.config.algorithm_type == "sde-dpmsolver++":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
(alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
(alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
elif self.config.solver_type == "heun":
x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
(alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
(alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
(-2.0 * h) + 1.0)) * D1 +
sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
elif self.config.algorithm_type == "sde-dpmsolver":
assert noise is not None
if self.config.solver_type == "midpoint":
x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
(sigma_t * (torch.exp(h) - 1.0)) * D0 -
(sigma_t * (torch.exp(h) - 1.0)) * D1 +
sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
elif self.config.solver_type == "heun":
x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
(sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
return x_t # pyright: ignore
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
def multistep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
One step for the third-order multistep DPMSolver.
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
timestep_list = args[0] if len(args) > 0 else kwargs.pop(
"timestep_list", None)
prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
"prev_timestep", None)
if sample is None:
if len(args) > 2:
sample = args[2]
else:
raise ValueError(
" missing`sample` as a required keyward argument")
if timestep_list is not None:
deprecate(
"timestep_list",
"1.0.0",
"Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
self.sigmas[self.step_index + 1], # pyright: ignore
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1], # pyright: ignore
self.sigmas[self.step_index - 2], # pyright: ignore
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
m0, m1, m2 = model_output_list[-1], model_output_list[
-2], model_output_list[-3]
h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
r0, r1 = h_0 / h, h_1 / h
D0 = m0
D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = ((sigma_t / sigma_s0) * sample -
(alpha_t * (torch.exp(-h) - 1.0)) * D0 +
(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *
(torch.exp(h) - 1.0)) * D0 -
(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
return x_t # pyright: ignore
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
# Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
def step(
self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep DPMSolver.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final or
(self.config.lower_order_final and len(self.timesteps) < 15) or
self.config.final_sigmas_type == "zero")
lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
self.config.lower_order_final and
len(self.timesteps) < 15)
model_output = self.convert_model_output(model_output, sample=sample)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.model_outputs[-1] = model_output
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"
] and variance_noise is None:
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=torch.float32)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(
device=model_output.device,
dtype=torch.float32) # pyright: ignore
else:
noise = None
if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
prev_sample = self.dpm_solver_first_order_update(
model_output, sample=sample, noise=noise)
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
prev_sample = self.multistep_dpm_solver_second_order_update(
self.model_outputs, sample=sample, noise=noise)
else:
prev_sample = self.multistep_dpm_solver_third_order_update(
self.model_outputs, sample=sample)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# Cast sample back to expected dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1 # pyright: ignore
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, *args,
**kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(
device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(
timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(
original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(
original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
# Convert unipc for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
if is_scipy_available():
import scipy.stats
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, default `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling.
prediction_type (`str`, defaults to "flow_prediction"):
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
the flow of the diffusion process.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0.
solver_type (`str`, default `bh2`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`):
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
usually disabled during the first few steps.
solver_p (`SchedulerMixin`, default `None`):
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: Optional[float] = 1.0,
use_dynamic_shifting=False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
):
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.register_to_config(solver_type="bh2")
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}")
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
alphas = np.linspace(1, 1 / num_train_timesteps,
num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 +
(shift - 1) * sigmas) # pyright: ignore
self.sigmas = sigmas
self.timesteps = sigmas * num_train_timesteps
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to(
"cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
def set_timesteps(
self,
num_inference_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
Total number of the spacing of the time steps.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(
" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
sigmas = np.linspace(self.sigma_max, self.sigma_min,
num_inference_steps +
1).copy()[:-1] # pyright: ignore
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
else:
if shift is None:
shift = self.config.shift
sigmas = shift * sigmas / (1 +
(shift - 1) * sigmas) # pyright: ignore
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) /
self.alphas_cumprod[0])**0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = np.concatenate([sigmas, [sigma_last]
]).astype(np.float32) # pyright: ignore
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to(
"cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float(
) # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(
abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(
1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(
sample, -s, s
) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def _sigma_to_alpha_sigma_t(self, sigma):
return 1 - sigma, sigma
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(
"missing `sample` as a required keyward argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
order: int = None, # pyright: ignore
**kwargs,
) -> torch.Tensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
"prev_timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(
" missing `sample` as a required keyward argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
raise ValueError(
" missing `order` as a required keyward argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
self.step_index] # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i # pyright: ignore
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk) # pyright: ignore
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1],
b[:-1]).to(device).to(x.dtype)
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
D1s) # pyright: ignore
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
D1s) # pyright: ignore
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.Tensor,
*args,
last_sample: torch.Tensor = None,
this_sample: torch.Tensor = None,
order: int = None, # pyright: ignore
**kwargs,
) -> torch.Tensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.Tensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns:
`torch.Tensor`:
The corrected sample tensor at the current timestep.
"""
this_timestep = args[0] if len(args) > 0 else kwargs.pop(
"this_timestep", None)
if last_sample is None:
if len(args) > 1:
last_sample = args[1]
else:
raise ValueError(
" missing`last_sample` as a required keyward argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
raise ValueError(
" missing`this_sample` as a required keyward argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(
" missing`order` as a required keyward argument")
if this_timestep is not None:
deprecate(
"this_timestep",
"1.0.0",
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
self.step_index - 1] # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1) # pyright: ignore
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk) # pyright: ignore
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.to(x.dtype)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(self,
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
generator=None) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep UniPC.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
use_corrector = (
self.step_index > 0 and
self.step_index - 1 not in self.disable_corrector and
self.last_sample is not None # pyright: ignore
)
model_output_convert = self.convert_model_output(
model_output, sample=sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep # pyright: ignore
if self.config.lower_order_final:
this_order = min(self.config.solver_order,
len(self.timesteps) -
self.step_index) # pyright: ignore
else:
this_order = self.config.solver_order
self.this_order = min(this_order,
self.lower_order_nums + 1) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1 # pyright: ignore
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.Tensor, *args,
**kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(
device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(
timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(
original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(
original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps)
for t in timesteps
]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import json
import logging
import math
import os
import random
import sys
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
import dashscope
import torch
from PIL import Image
try:
from flash_attn import flash_attn_varlen_func
FLASH_VER = 2
except ModuleNotFoundError:
flash_attn_varlen_func = None # in compatible with CPU machines
FLASH_VER = None
from .system_prompt import *
DEFAULT_SYS_PROMPTS = {
"t2v-A14B": {
"zh": T2V_A14B_ZH_SYS_PROMPT,
"en": T2V_A14B_EN_SYS_PROMPT,
},
"i2v-A14B": {
"zh": I2V_A14B_ZH_SYS_PROMPT,
"en": I2V_A14B_EN_SYS_PROMPT,
"empty": {
"zh": I2V_A14B_EMPTY_ZH_SYS_PROMPT,
"en": I2V_A14B_EMPTY_EN_SYS_PROMPT,
}
},
"ti2v-5B": {
"t2v": {
"zh": T2V_A14B_ZH_SYS_PROMPT,
"en": T2V_A14B_EN_SYS_PROMPT,
},
"i2v": {
"zh": I2V_A14B_ZH_SYS_PROMPT,
"en": I2V_A14B_EN_SYS_PROMPT,
}
},
}
@dataclass
class PromptOutput(object):
status: bool
prompt: str
seed: int
system_prompt: str
message: str
def add_custom_field(self, key: str, value) -> None:
self.__setattr__(key, value)
class PromptExpander:
def __init__(self, model_name, task, is_vl=False, device=0, **kwargs):
self.model_name = model_name
self.task = task
self.is_vl = is_vl
self.device = device
def extend_with_img(self,
prompt,
system_prompt,
image=None,
seed=-1,
*args,
**kwargs):
pass
def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
pass
def decide_system_prompt(self, tar_lang="zh", prompt=None):
assert self.task is not None
if "ti2v" in self.task:
if self.is_vl:
return DEFAULT_SYS_PROMPTS[self.task]["i2v"][tar_lang]
else:
return DEFAULT_SYS_PROMPTS[self.task]["t2v"][tar_lang]
if "i2v" in self.task and len(prompt) == 0:
return DEFAULT_SYS_PROMPTS[self.task]["empty"][tar_lang]
return DEFAULT_SYS_PROMPTS[self.task][tar_lang]
def __call__(self,
prompt,
system_prompt=None,
tar_lang="zh",
image=None,
seed=-1,
*args,
**kwargs):
if system_prompt is None:
system_prompt = self.decide_system_prompt(
tar_lang=tar_lang, prompt=prompt)
if seed < 0:
seed = random.randint(0, sys.maxsize)
if image is not None and self.is_vl:
return self.extend_with_img(
prompt, system_prompt, image=image, seed=seed, *args, **kwargs)
elif not self.is_vl:
return self.extend(prompt, system_prompt, seed, *args, **kwargs)
else:
raise NotImplementedError
class DashScopePromptExpander(PromptExpander):
def __init__(self,
api_key=None,
model_name=None,
task=None,
max_image_size=512 * 512,
retry_times=4,
is_vl=False,
**kwargs):
'''
Args:
api_key: The API key for Dash Scope authentication and access to related services.
model_name: Model name, 'qwen-plus' for extending prompts, 'qwen-vl-max' for extending prompt-images.
task: Task name. This is required to determine the default system prompt.
max_image_size: The maximum size of the image; unit unspecified (e.g., pixels, KB). Please specify the unit based on actual usage.
retry_times: Number of retry attempts in case of request failure.
is_vl: A flag indicating whether the task involves visual-language processing.
**kwargs: Additional keyword arguments that can be passed to the function or method.
'''
if model_name is None:
model_name = 'qwen-plus' if not is_vl else 'qwen-vl-max'
super().__init__(model_name, task, is_vl, **kwargs)
if api_key is not None:
dashscope.api_key = api_key
elif 'DASH_API_KEY' in os.environ and os.environ[
'DASH_API_KEY'] is not None:
dashscope.api_key = os.environ['DASH_API_KEY']
else:
raise ValueError("DASH_API_KEY is not set")
if 'DASH_API_URL' in os.environ and os.environ[
'DASH_API_URL'] is not None:
dashscope.base_http_api_url = os.environ['DASH_API_URL']
else:
dashscope.base_http_api_url = 'https://dashscope.aliyuncs.com/api/v1'
self.api_key = api_key
self.max_image_size = max_image_size
self.model = model_name
self.retry_times = retry_times
def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
messages = [{
'role': 'system',
'content': system_prompt
}, {
'role': 'user',
'content': prompt
}]
exception = None
for _ in range(self.retry_times):
try:
response = dashscope.Generation.call(
self.model,
messages=messages,
seed=seed,
result_format='message', # set the result to be "message" format.
)
assert response.status_code == HTTPStatus.OK, response
expanded_prompt = response['output']['choices'][0]['message'][
'content']
return PromptOutput(
status=True,
prompt=expanded_prompt,
seed=seed,
system_prompt=system_prompt,
message=json.dumps(response, ensure_ascii=False))
except Exception as e:
exception = e
return PromptOutput(
status=False,
prompt=prompt,
seed=seed,
system_prompt=system_prompt,
message=str(exception))
def extend_with_img(self,
prompt,
system_prompt,
image: Union[Image.Image, str] = None,
seed=-1,
*args,
**kwargs):
if isinstance(image, str):
image = Image.open(image).convert('RGB')
w = image.width
h = image.height
area = min(w * h, self.max_image_size)
aspect_ratio = h / w
resized_h = round(math.sqrt(area * aspect_ratio))
resized_w = round(math.sqrt(area / aspect_ratio))
image = image.resize((resized_w, resized_h))
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
image.save(f.name)
fname = f.name
image_path = f"file://{f.name}"
prompt = f"{prompt}"
messages = [
{
'role': 'system',
'content': [{
"text": system_prompt
}]
},
{
'role': 'user',
'content': [{
"text": prompt
}, {
"image": image_path
}]
},
]
response = None
result_prompt = prompt
exception = None
status = False
for _ in range(self.retry_times):
try:
response = dashscope.MultiModalConversation.call(
self.model,
messages=messages,
seed=seed,
result_format='message', # set the result to be "message" format.
)
assert response.status_code == HTTPStatus.OK, response
result_prompt = response['output']['choices'][0]['message'][
'content'][0]['text'].replace('\n', '\\n')
status = True
break
except Exception as e:
exception = e
result_prompt = result_prompt.replace('\n', '\\n')
os.remove(fname)
return PromptOutput(
status=status,
prompt=result_prompt,
seed=seed,
system_prompt=system_prompt,
message=str(exception) if not status else json.dumps(
response, ensure_ascii=False))
class QwenPromptExpander(PromptExpander):
model_dict = {
"QwenVL2.5_3B": "Qwen/Qwen2.5-VL-3B-Instruct",
"QwenVL2.5_7B": "Qwen/Qwen2.5-VL-7B-Instruct",
"Qwen2.5_3B": "Qwen/Qwen2.5-3B-Instruct",
"Qwen2.5_7B": "Qwen/Qwen2.5-7B-Instruct",
"Qwen2.5_14B": "Qwen/Qwen2.5-14B-Instruct",
}
def __init__(self,
model_name=None,
task=None,
device=0,
is_vl=False,
**kwargs):
'''
Args:
model_name: Use predefined model names such as 'QwenVL2.5_7B' and 'Qwen2.5_14B',
which are specific versions of the Qwen model. Alternatively, you can use the
local path to a downloaded model or the model name from Hugging Face."
Detailed Breakdown:
Predefined Model Names:
* 'QwenVL2.5_7B' and 'Qwen2.5_14B' are specific versions of the Qwen model.
Local Path:
* You can provide the path to a model that you have downloaded locally.
Hugging Face Model Name:
* You can also specify the model name from Hugging Face's model hub.
task: Task name. This is required to determine the default system prompt.
is_vl: A flag indicating whether the task involves visual-language processing.
**kwargs: Additional keyword arguments that can be passed to the function or method.
'''
if model_name is None:
model_name = 'Qwen2.5_14B' if not is_vl else 'QwenVL2.5_7B'
super().__init__(model_name, task, is_vl, device, **kwargs)
if (not os.path.exists(self.model_name)) and (self.model_name
in self.model_dict):
self.model_name = self.model_dict[self.model_name]
if self.is_vl:
# default: Load the model on the available device(s)
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
)
try:
from .qwen_vl_utils import process_vision_info
except:
from qwen_vl_utils import process_vision_info
self.process_vision_info = process_vision_info
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
self.processor = AutoProcessor.from_pretrained(
self.model_name,
min_pixels=min_pixels,
max_pixels=max_pixels,
use_fast=True)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16 if FLASH_VER == 2 else
torch.float16 if "AWQ" in self.model_name else "auto",
attn_implementation="flash_attention_2"
if FLASH_VER == 2 else None,
device_map="cpu")
else:
from transformers import AutoModelForCausalLM, AutoTokenizer
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16
if "AWQ" in self.model_name else "auto",
attn_implementation="flash_attention_2"
if FLASH_VER == 2 else None,
device_map="cpu")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
self.model = self.model.to(self.device)
messages = [{
"role": "system",
"content": system_prompt
}, {
"role": "user",
"content": prompt
}]
text = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.tokenizer([text],
return_tensors="pt").to(self.model.device)
generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(
model_inputs.input_ids, generated_ids)
]
expanded_prompt = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True)[0]
self.model = self.model.to("cpu")
return PromptOutput(
status=True,
prompt=expanded_prompt,
seed=seed,
system_prompt=system_prompt,
message=json.dumps({"content": expanded_prompt},
ensure_ascii=False))
def extend_with_img(self,
prompt,
system_prompt,
image: Union[Image.Image, str] = None,
seed=-1,
*args,
**kwargs):
self.model = self.model.to(self.device)
messages = [{
'role': 'system',
'content': [{
"type": "text",
"text": system_prompt
}]
}, {
"role":
"user",
"content": [
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": prompt
},
],
}]
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.device)
# Inference: Generation of the output
generated_ids = self.model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
expanded_prompt = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
self.model = self.model.to("cpu")
return PromptOutput(
status=True,
prompt=expanded_prompt,
seed=seed,
system_prompt=system_prompt,
message=json.dumps({"content": expanded_prompt},
ensure_ascii=False))
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
seed = 100
prompt = "夏日海滩度假风格,一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松,表情悠闲,直视镜头。背景是模糊的海滩景色,海水清澈,远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松,仿佛在享受海风和阳光。近景特写,强调猫咪的细节和海滩的清新氛围。"
en_prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
image = "./examples/i2v_input.JPG"
def test(method,
prompt,
model_name,
task,
image=None,
en_prompt=None,
seed=None):
prompt_expander = method(
model_name=model_name, task=task, is_vl=image is not None)
result = prompt_expander(prompt, image=image, tar_lang="zh")
logging.info(f"zh prompt -> zh: {result.prompt}")
result = prompt_expander(prompt, image=image, tar_lang="en")
logging.info(f"zh prompt -> en: {result.prompt}")
if en_prompt is not None:
result = prompt_expander(en_prompt, image=image, tar_lang="zh")
logging.info(f"en prompt -> zh: {result.prompt}")
result = prompt_expander(en_prompt, image=image, tar_lang="en")
logging.info(f"en prompt -> en: {result.prompt}")
ds_model_name = None
ds_vl_model_name = None
qwen_model_name = None
qwen_vl_model_name = None
for task in ["t2v-A14B", "i2v-A14B", "ti2v-5B"]:
# test prompt extend
if "t2v" in task or "ti2v" in task:
# test dashscope api
logging.info(f"-" * 40)
logging.info(f"Testing {task} dashscope prompt extend")
test(
DashScopePromptExpander,
prompt,
ds_model_name,
task,
image=None,
en_prompt=en_prompt,
seed=seed)
# test qwen api
logging.info(f"-" * 40)
logging.info(f"Testing {task} qwen prompt extend")
test(
QwenPromptExpander,
prompt,
qwen_model_name,
task,
image=None,
en_prompt=en_prompt,
seed=seed)
# test prompt-image extend
if "i2v" in task:
# test dashscope api
logging.info(f"-" * 40)
logging.info(f"Testing {task} dashscope vl prompt extend")
test(
DashScopePromptExpander,
prompt,
ds_vl_model_name,
task,
image=image,
en_prompt=en_prompt,
seed=seed)
# test qwen api
logging.info(f"-" * 40)
logging.info(f"Testing {task} qwen vl prompt extend")
test(
QwenPromptExpander,
prompt,
qwen_vl_model_name,
task,
image=image,
en_prompt=en_prompt,
seed=seed)
# test empty prompt extend
if "i2v-A14B" in task:
# test dashscope api
logging.info(f"-" * 40)
logging.info(f"Testing {task} dashscope vl empty prompt extend")
test(
DashScopePromptExpander,
"",
ds_vl_model_name,
task,
image=image,
en_prompt=None,
seed=seed)
# test qwen api
logging.info(f"-" * 40)
logging.info(f"Testing {task} qwen vl empty prompt extend")
test(
QwenPromptExpander,
"",
qwen_vl_model_name,
task,
image=image,
en_prompt=None,
seed=seed)
# Copied from https://github.com/kq-chen/qwen-vl-utils
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from __future__ import annotations
import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO
import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
logger = logging.getLogger(__name__)
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def fetch_image(ele: dict[str, str | Image.Image],
size_factor: int = IMAGE_FACTOR) -> Image.Image:
if "image" in ele:
image = ele["image"]
else:
image = ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
image_obj = Image.open(requests.get(image, stream=True).raw)
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = image_obj.convert("RGB")
## resize
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=size_factor,
)
else:
width, height = image.size
min_pixels = ele.get("min_pixels", MIN_PIXELS)
max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def smart_nframes(
ele: dict,
total_frames: int,
video_fps: int | float,
) -> int:
"""calculate the number of frames for video used for model inputs.
Args:
ele (dict): a dict contains the configuration of video.
support either `fps` or `nframes`:
- nframes: the number of frames to extract for model inputs.
- fps: the fps to extract frames for model inputs.
- min_frames: the minimum number of frames of the video, only used when fps is provided.
- max_frames: the maximum number of frames of the video, only used when fps is provided.
total_frames (int): the original total number of frames of the video.
video_fps (int | float): the original fps of the video.
Raises:
ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
Returns:
int: the number of frames for video used for model inputs.
"""
assert not ("fps" in ele and
"nframes" in ele), "Only accept either `fps` or `nframes`"
if "nframes" in ele:
nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
else:
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(
ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(
ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
FRAME_FACTOR)
nframes = total_frames / video_fps * fps
nframes = min(max(nframes, min_frames), max_frames)
nframes = round_by_factor(nframes, FRAME_FACTOR)
if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
raise ValueError(
f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
)
return nframes
def _read_video_torchvision(ele: dict,) -> torch.Tensor:
"""read video using torchvision.io.read_video
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
video_path = ele["video"]
if version.parse(torchvision.__version__) < version.parse("0.19.0"):
if "http://" in video_path or "https://" in video_path:
warnings.warn(
"torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0."
)
if "file://" in video_path:
video_path = video_path[7:]
st = time.time()
video, audio, info = io.read_video(
video_path,
start_pts=ele.get("video_start", 0.0),
end_pts=ele.get("video_end", None),
pts_unit="sec",
output_format="TCHW",
)
total_frames, video_fps = video.size(0), info["video_fps"]
logger.info(
f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long()
video = video[idx]
return video
def is_decord_available() -> bool:
import importlib.util
return importlib.util.find_spec("decord") is not None
def _read_video_decord(ele: dict,) -> torch.Tensor:
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
# TODO: support start_pts and end_pts
if 'video_start' in ele or 'video_end' in ele:
raise NotImplementedError(
"not support start_pts and end_pts in decord for now.")
total_frames, video_fps = len(vr), vr.get_avg_fps()
logger.info(
f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
)
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
return video
VIDEO_READER_BACKENDS = {
"decord": _read_video_decord,
"torchvision": _read_video_torchvision,
}
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
elif is_decord_available():
video_reader_backend = "decord"
else:
video_reader_backend = "torchvision"
logger.info(
f"qwen-vl-utils using {video_reader_backend} to read video.",
file=sys.stderr)
return video_reader_backend
def fetch_video(
ele: dict,
image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
nframes, _, height, width = video.shape
min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
max_pixels = max(
min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
int(min_pixels * 1.05))
max_pixels = ele.get("max_pixels", max_pixels)
if "resized_height" in ele and "resized_width" in ele:
resized_height, resized_width = smart_resize(
ele["resized_height"],
ele["resized_width"],
factor=image_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=image_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
return video
else:
assert isinstance(ele["video"], (list, tuple))
process_info = ele.copy()
process_info.pop("type", None)
process_info.pop("video", None)
images = [
fetch_image({
"image": video_element,
**process_info
},
size_factor=image_factor)
for video_element in ele["video"]
]
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
if len(images) < nframes:
images.extend([images[-1]] * (nframes - len(images)))
return images
def extract_vision_info(
conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if ("image" in ele or "image_url" in ele or
"video" in ele or
ele["type"] in ("image", "image_url", "video")):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
conversations: list[dict] | list[list[dict]],
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
None]:
vision_infos = extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
elif "video" in vision_info:
video_inputs.append(fetch_video(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
return image_inputs, video_inputs
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
T2V_A14B_ZH_SYS_PROMPT = \
''' 你是一位电影导演,旨在为用户输入的原始prompt添加电影元素,改写为优质Prompt,使其完整、具有表现力。
任务要求:
1. 对于用户输入的prompt,在不改变prompt的原意(如主体、动作)前提下,从下列电影美学设定中选择部分合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中,让画面变得更美,注意,可以任选,不必每项都有
时间:["白天", "夜晚", "黎明", "日出"], 可以不选, 如果prompt没有特别说明则选白天 !
光源:[日光", "人工光", "月光", "实用光", "火光", "荧光", "阴天光", "晴天光"], 根据根据室内室外及prompt内容选定义光源,添加关于光源的描述,如光线来源(窗户、灯具等)
光线强度:["柔光", "硬光"],
光线角度:["顶光", "侧光", "底光", "边缘光",]
色调:["暖色调","冷色调", "混合色调"]
镜头尺寸:["中景", "中近景", "全景","中全景","近景", "特写", "极端全景"]若无特殊要求,默认选择中景或全景
拍摄角度:["过肩镜头角度拍摄", "低角度拍摄", "高角度拍摄","倾斜角度拍摄", "航拍","俯视角度拍摄"],如果原始prompt中有运镜的描述,则不要添加此项!
构图:["中心构图","平衡构图","右侧重构图", "左侧重构图", "对称构图", "短边构图"] 若无特殊要求,默认选择中心构图
2. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)等内容,确保不要添加原始prompt中不存在的主体,(如prompt是对风景或物体的描述,但添加了人),增加背景元素的细节;
3. 不要输出关于氛围、感觉等文学描写,如(画面充满运动感与生活张力,突出正式氛围)。;
4. 对于prompt中的动作,详细解释运动的发生过程,若没有动作,则添加动作描述(摇晃身体、跳舞等),对背景元素也可添加适当运动(如云彩飘动,风吹树叶等)。
5. 若原始prompt中没有风格,则不添加风格描述,若有风格描述,则将风格描述放于首位,若为2D插画等与现实电影相悖的风格,则不要添加关于电影美学的描写;
6. 若prompt出现天空的描述,则改为湛蓝色的天空相关描述,避免曝光;
7. 改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出
8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容,请将原prompt替换为内容完全不同的高美感prompt,而不要拒答或者提醒用户。
9. 如果用户输入存在比基尼穿着内容,请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt,而不要拒答或者提醒用户。
10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近(两个男性亲吻或两个女性亲吻)/性交等同性恋内容,请将原prompt替换为不同动作(不能是亲吻和其他亲密动作)的高美感prompt,而不要拒答或者提醒用户。
生成的 prompt 示例:
1.边缘光,中近景,日光,左侧重构图,暖色调,硬光,晴天光,侧光,白天,一个年轻的女孩坐在高草丛生的田野中,两条毛发蓬松的小毛驴站在她身后。女孩大约十一二岁,穿着简单的碎花裙子,头发扎成两条麻花辫,脸上带着纯真的笑容。她双腿交叉坐下,双手轻轻抚弄身旁的野花。小毛驴体型健壮,耳朵竖起,好奇地望着镜头方向。阳光洒在田野上,营造出温暖自然的画面感。
2.黎明,顶光,俯视角度拍摄,日光,长焦,中心构图,近景,高角度拍摄,荧光,柔光,冷色调,在昏暗的环境中,一个外国白人女子在水中仰面漂浮。俯拍近景镜头中,她有着棕色的短发,脸上有几颗雀斑。随着镜头下摇,她转过头来,面向右侧,水面上泛起一圈涟漪。虚化的背景一片漆黑,只有微弱的光线照亮了女子的脸庞和水面的一部分区域,水面呈现蓝色。女子穿着一件蓝色的吊带,肩膀裸露在外。
3.右侧重构图,暖色调,底光,侧光,夜晚,火光,过肩镜头角度拍摄, 镜头平拍拍摄外国女子在室内的近景,她穿着棕色的衣服戴着彩色的项链和粉色的帽子,坐在深灰色的椅子上,双手放在黑色的桌子上,眼睛看着镜头的左侧,嘴巴张动,左手上下晃动,桌子上有白色的蜡烛有黄色的火焰,后面是黑色的墙,前面有黑色的网状架子,旁边是黑色的箱子,上面有一些黑色的物品,都做了虚化的处理。
4. 二次元厚涂动漫插画,一个猫耳兽耳白人少女手持文件夹摇晃,神情略带不满。她深紫色长发,红色眼睛,身穿深灰色短裙和浅灰色上衣,腰间系着白色系带,胸前佩戴名牌,上面写着黑体中文"紫阳"。淡黄色调室内背景,隐约可见一些家具轮廓。少女头顶有一个粉色光圈。线条流畅的日系赛璐璐风格。近景半身略俯视视角。
'''
T2V_A14B_EN_SYS_PROMPT = \
'''你是一位电影导演,旨在为用户输入的原始prompt添加电影元素,改写为优质(英文)Prompt,使其完整、具有表现力注意,输出必须是英文!
任务要求:
1. 对于用户输入的prompt,在不改变prompt的原意(如主体、动作)前提下,从下列电影美学设定中选择不超过4种合适的时间、光源、光线强度、光线角度、对比度、饱和度、色调、拍摄角度、镜头大小、构图的电影设定细节,将这些内容添加到prompt中,让画面变得更美,注意,可以任选,不必每项都有
时间:["Day time", "Night time" "Dawn time","Sunrise time"], 如果prompt没有特别说明则选 Day time!!!
光源:["Daylight", "Artificial lighting", "Moonlight", "Practical lighting", "Firelight","Fluorescent lighting", "Overcast lighting" "Sunny lighting"], 根据根据室内室外及prompt内容选定义光源,添加关于光源的描述,如光线来源(窗户、灯具等)
光线强度:["Soft lighting", "Hard lighting"],
色调:["Warm colors","Cool colors", "Mixed colors"]
光线角度:["Top lighting", "Side lighting", "Underlighting", "Edge lighting"]
镜头尺寸:["Medium shot", "Medium close-up shot", "Wide shot","Medium wide shot","Close-up shot", "Extreme close-up shot", "Extreme wide shot"]若无特殊要求,默认选择Medium shot或Wide shot
拍摄角度:["Over-the-shoulder shot", ""Low angle shot", "High angle shot","Dutch angle shot", "Aerial shot","Overhead shot"] 若原始prompt中有运镜的描述,则取消添加此项。
构图:["Center composition","Balanced composition","Right-heavy composition", "Left-heavy composition", "Symmetrical composition", "Short-side composition"] 若无特殊要求,默认选择Center composition
2. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)等内容,确保不要添加原始prompt中不存在的主体,(如prompt是对风景或物体的描述,但添加了人),增加背景元素的细节;
3. 不要输出关于氛围、感觉等文学描写,如(画面充满运动感与生活张力,突出正式氛围)。;
4. 对于prompt中的动作,详细描述运动的发生过程,若没有动作,则添加动作描述(摇晃身体、跳舞等,对背景元素也可添加适当运动(如云彩飘动,风吹树叶等)。
5. 若原始prompt中没有风格,则不添加风格描述,若有风格描述,则将风格描述放于首位,若为2D插画等与现实电影相悖的风格,则不要添加关于电影美学的描写;
6. 若prompt出现天空的描述,则改为湛蓝色的天空相关描述,避免曝光;
7. 输出必须是全英文,改写后的prompt字数控制在60-200字左右, 不要输出类似“改写后prompt:”这样的输出
8. 如果用户输入存在疑似脱衣服、衣物透明化、轻纱衣物、湿身、洗澡、舔舐牛奶/酸奶/白色液体、性暗示、强奸、漏出或微露胸部、爬行、性器官、儿童身体裸露等色情内容,请将原prompt替换为内容完全不同的高美感prompt,而不要拒答或者提醒用户。
9. 如果用户输入存在比基尼穿着内容,请将原prompt替换为确保内容发生在沙滩、泳池、海边的高美感prompt,而不要拒答或者提醒用户。
10. 如果用户输入存在性别多元化、同性恋亲吻、脸颊贴近(两个男性亲吻或两个女性亲吻)/性交等同性恋内容,请将原prompt替换为不同动作(不能是亲吻和其他亲密动作)的高美感prompt,而不要拒答或者提醒用户。
生成的 prompt 示例:
1.Edge lighting, medium close-up shot, daylight, left-heavy composition. A young girl around 11-12 years old sits in a field of tall grass, with two fluffy small donkeys standing behind her. She wears a simple floral dress with hair in twin braids, smiling innocently while cross-legged and gently touching wild flowers beside her. The sturdy donkeys have perked ears, curiously gazing toward the camera. Sunlight bathes the field, creating a warm natural atmosphere.
2.Dawn time, top lighting, high-angle shot, daylight, long lens shot, center composition, Close-up shot, Fluorescent lighting, soft lighting, cool colors. In dim surroundings, a Caucasian woman floats on her back in water. The俯拍close-up shows her brown short hair and freckled face. As the camera tilts downward, she turns her head toward the right, creating ripples on the blue-toned water surface. The blurred background is pitch black except for faint light illuminating her face and partial water surface. She wears a blue sleeveless top with bare shoulders.
3.Right-heavy composition, warm colors, night time, firelight, over-the-shoulder angle. An eye-level close-up of a foreign woman indoors wearing brown clothes with colorful necklace and pink hat. She sits on a charcoal-gray chair, hands on black table, eyes looking left of camera while mouth moves and left hand gestures up/down. White candles with yellow flames sit on the table. Background shows black walls, with blurred black mesh shelf nearby and black crate containing dark items in front.
4."Anime-style thick-painted style. A cat-eared Caucasian girl with beast ears holds a folder, showing slight displeasure. Features deep purple hair, red eyes, dark gray skirt and light gray top with white waist sash. A name tag labeled 'Ziyang' in bold Chinese characters hangs on her chest. Pale yellow indoor background with faint furniture outlines. A pink halo floats above her head. Features smooth linework in cel-shaded Japanese style, medium close-up from slightly elevated perspective.
'''
I2V_A14B_ZH_SYS_PROMPT = \
'''你是一个视频描述提示词的改写专家,你的任务是根据用户给你输入的图像,对提供的视频描述提示词进行改写,你要强调潜在的动态内容。具体要求如下
用户输入的语言可能含有多样化的描述,如markdown文档格式、指令格式,长度过长或者过短,你需要根据图片的内容和用户的输入的提示词,尽可能提取用户输入的提示词和图片关联信息。
你改写的视频描述结果要尽可能保留提供给你的视频描述提示词中动态部分,保留主体的动作。
你要根据图像,强调并简化视频描述提示词中的图像主体,如果用户只提供了动作,你要根据图像内容合理补充,如“跳舞”补充称“一个女孩在跳舞”
如果用户输入的提示词过长,你需要提炼潜在的动作过程
如果用户输入的提示词过短,综合用户输入的提示词以及画面内容,合理的增加潜在的运动信息
你要根据图像,保留并强调视频描述提示词中关于运镜手段的描述,如“镜头上摇”,“镜头从左到右”,“镜头从右到左”等等,你要保留,如“镜头拍摄两个男人打斗,他们先是躺在地上,随后镜头向上移动,拍摄他们站起来,接着镜头向左移动,左边男人拿着一个蓝色的东西,右边男人上前抢夺,两人激烈地来回争抢。”。
你需要给出对视频描述的动态内容,不要添加对于静态场景的描述,如果用户输入的描述已经在画面中出现,则移除这些描述
改写后的prompt字数控制在100字以下
无论用户输入那种语言,你都需要输出中文
改写后 prompt 示例:
1. 镜头后拉,拍摄两个外国男人,走在楼梯上,镜头左侧的男人右手搀扶着镜头右侧的男人。
2. 一只黑色的小松鼠专注地吃着东西,偶尔抬头看看四周。
3. 男子说着话,表情从微笑逐渐转变为闭眼,然后睁开眼睛,最后是闭眼微笑,他的手势活跃,在说话时做出一系列的手势。
4. 一个人正在用尺子和笔进行测量的特写,右手用一支黑色水性笔在纸上画出一条直线。
5. 一辆车模型在木板上形式,车辆从画面的右侧向左侧移动,经过一片草地和一些木制结构。
6. 镜头左移后前推,拍摄一个人坐在防波堤上。
7. 男子说着话,他的表情和手势随着对话内容的变化而变化,但整体场景保持不变。
8. 镜头左移后前推,拍摄一个人坐在防波堤上。
9. 带着珍珠项链的女子看向画面右侧并说着话。
请直接输出改写后的文本,不要进行多余的回复。'''
I2V_A14B_EN_SYS_PROMPT = \
'''You are an expert in rewriting video description prompts. Your task is to rewrite the provided video description prompts based on the images given by users, emphasizing potential dynamic content. Specific requirements are as follows:
The user's input language may include diverse descriptions, such as markdown format, instruction format, or be too long or too short. You need to extract the relevant information from the user’s input and associate it with the image content.
Your rewritten video description should retain the dynamic parts of the provided prompts, focusing on the main subject's actions. Emphasize and simplify the main subject of the image while retaining their movement. If the user only provides an action (e.g., "dancing"), supplement it reasonably based on the image content (e.g., "a girl is dancing").
If the user’s input prompt is too long, refine it to capture the essential action process. If the input is too short, add reasonable motion-related details based on the image content.
Retain and emphasize descriptions of camera movements, such as "the camera pans up," "the camera moves from left to right," or "the camera moves from right to left." For example: "The camera captures two men fighting. They start lying on the ground, then the camera moves upward as they stand up. The camera shifts left, showing the man on the left holding a blue object while the man on the right tries to grab it, resulting in a fierce back-and-forth struggle."
Focus on dynamic content in the video description and avoid adding static scene descriptions. If the user’s input already describes elements visible in the image, remove those static descriptions.
Limit the rewritten prompt to 100 words or less. Regardless of the input language, your output must be in English.
Examples of rewritten prompts:
The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.
A black squirrel focuses on eating, occasionally looking around.
A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.
A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.
A model car moves on a wooden board, traveling from right to left across grass and wooden structures.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A woman wearing a pearl necklace looks to the right and speaks.
Output only the rewritten text without additional responses.'''
I2V_A14B_EMPTY_ZH_SYS_PROMPT = \
'''你是一个视频描述提示词的撰写专家,你的任务是根据用户给你输入的图像,发挥合理的想象,让这张图动起来,你要强调潜在的动态内容。具体要求如下
你需要根据图片的内容想象出运动的主体
你输出的结果应强调图片中的动态部分,保留主体的动作。
你需要给出对视频描述的动态内容,不要有过多的对于静态场景的描述
输出的prompt字数控制在100字以下
你需要输出中文
prompt 示例:
1. 镜头后拉,拍摄两个外国男人,走在楼梯上,镜头左侧的男人右手搀扶着镜头右侧的男人。
2. 一只黑色的小松鼠专注地吃着东西,偶尔抬头看看四周。
3. 男子说着话,表情从微笑逐渐转变为闭眼,然后睁开眼睛,最后是闭眼微笑,他的手势活跃,在说话时做出一系列的手势。
4. 一个人正在用尺子和笔进行测量的特写,右手用一支黑色水性笔在纸上画出一条直线。
5. 一辆车模型在木板上形式,车辆从画面的右侧向左侧移动,经过一片草地和一些木制结构。
6. 镜头左移后前推,拍摄一个人坐在防波堤上。
7. 男子说着话,他的表情和手势随着对话内容的变化而变化,但整体场景保持不变。
8. 镜头左移后前推,拍摄一个人坐在防波堤上。
9. 带着珍珠项链的女子看向画面右侧并说着话。
请直接输出文本,不要进行多余的回复。'''
I2V_A14B_EMPTY_EN_SYS_PROMPT = \
'''You are an expert in writing video description prompts. Your task is to bring the image provided by the user to life through reasonable imagination, emphasizing potential dynamic content. Specific requirements are as follows:
You need to imagine the moving subject based on the content of the image.
Your output should emphasize the dynamic parts of the image and retain the main subject’s actions.
Focus only on describing dynamic content; avoid excessive descriptions of static scenes.
Limit the output prompt to 100 words or less.
The output must be in English.
Prompt examples:
The camera pulls back to show two foreign men walking up the stairs. The man on the left supports the man on the right with his right hand.
A black squirrel focuses on eating, occasionally looking around.
A man talks, his expression shifting from smiling to closing his eyes, reopening them, and finally smiling with closed eyes. His gestures are lively, making various hand motions while speaking.
A close-up of someone measuring with a ruler and pen, drawing a straight line on paper with a black marker in their right hand.
A model car moves on a wooden board, traveling from right to left across grass and wooden structures.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A man speaks, his expressions and gestures changing with the conversation, while the overall scene remains constant.
The camera moves left, then pushes forward to capture a person sitting on a breakwater.
A woman wearing a pearl necklace looks to the right and speaks.
Output only the text without additional responses.'''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import binascii
import logging
import os
import os.path as osp
import shutil
import subprocess
import imageio
import torch
import torchvision
__all__ = ['save_video', 'save_image', 'str2bool']
def rand_name(length=8, suffix=''):
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
if suffix:
if not suffix.startswith('.'):
suffix = '.' + suffix
name += suffix
return name
def merge_video_audio(video_path: str, audio_path: str):
"""
Merge the video and audio into a new video, with the duration set to the shorter of the two,
and overwrite the original video file.
Parameters:
video_path (str): Path to the original video file
audio_path (str): Path to the audio file
"""
# set logging
logging.basicConfig(level=logging.INFO)
# check
if not os.path.exists(video_path):
raise FileNotFoundError(f"video file {video_path} does not exist")
if not os.path.exists(audio_path):
raise FileNotFoundError(f"audio file {audio_path} does not exist")
base, ext = os.path.splitext(video_path)
temp_output = f"{base}_temp{ext}"
try:
# create ffmpeg command
command = [
'ffmpeg',
'-y', # overwrite
'-i',
video_path,
'-i',
audio_path,
'-c:v',
'copy', # copy video stream
'-c:a',
'aac', # use AAC audio encoder
'-b:a',
'192k', # set audio bitrate (optional)
'-map',
'0:v:0', # select the first video stream
'-map',
'1:a:0', # select the first audio stream
'-shortest', # choose the shortest duration
temp_output
]
# execute the command
logging.info("Start merging video and audio...")
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# check result
if result.returncode != 0:
error_msg = f"FFmpeg execute failed: {result.stderr}"
logging.error(error_msg)
raise RuntimeError(error_msg)
shutil.move(temp_output, video_path)
logging.info(f"Merge completed, saved to {video_path}")
except Exception as e:
if os.path.exists(temp_output):
os.remove(temp_output)
logging.error(f"merge_video_audio failed with error: {e}")
def save_video(tensor,
save_file=None,
fps=30,
suffix='.mp4',
nrow=8,
normalize=True,
value_range=(-1, 1)):
# cache file
cache_file = osp.join('/tmp', rand_name(
suffix=suffix)) if save_file is None else save_file
# save to cache
try:
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = torch.stack([
torchvision.utils.make_grid(
u, nrow=nrow, normalize=normalize, value_range=value_range)
for u in tensor.unbind(2)
],
dim=1).permute(1, 2, 3, 0)
tensor = (tensor * 255).type(torch.uint8).cpu()
# write video
writer = imageio.get_writer(
cache_file, fps=fps, codec='libx264', quality=8)
for frame in tensor.numpy():
writer.append_data(frame)
writer.close()
except Exception as e:
logging.info(f'save_video failed, error: {e}')
def save_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1)):
# cache file
suffix = osp.splitext(save_file)[1]
if suffix.lower() not in [
'.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
]:
suffix = '.png'
# save to cache
try:
tensor = tensor.clamp(min(value_range), max(value_range))
torchvision.utils.save_image(
tensor,
save_file,
nrow=nrow,
normalize=normalize,
value_range=value_range)
return save_file
except Exception as e:
logging.info(f'save_image failed, error: {e}')
def str2bool(v):
"""
Convert a string to a boolean.
Supported true values: 'yes', 'true', 't', 'y', '1'
Supported false values: 'no', 'false', 'f', 'n', '0'
Args:
v (str): String to convert.
Returns:
bool: Converted boolean value.
Raises:
argparse.ArgumentTypeError: If the value cannot be converted to boolean.
"""
if isinstance(v, bool):
return v
v_lower = v.lower()
if v_lower in ('yes', 'true', 't', 'y', '1'):
return True
elif v_lower in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
def masks_like(tensor, zero=False, generator=None, p=0.2):
assert isinstance(tensor, list)
out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
if zero:
if generator is not None:
for u, v in zip(out1, out2):
random_num = torch.rand(
1, generator=generator, device=generator.device).item()
if random_num < p:
u[:, 0] = torch.normal(
mean=-3.5,
std=0.5,
size=(1,),
device=u.device,
generator=generator).expand_as(u[:, 0]).exp()
v[:, 0] = torch.zeros_like(v[:, 0])
else:
u[:, 0] = u[:, 0]
v[:, 0] = v[:, 0]
else:
for u, v in zip(out1, out2):
u[:, 0] = torch.zeros_like(u[:, 0])
v[:, 0] = torch.zeros_like(v[:, 0])
return out1, out2
def best_output_size(w, h, dw, dh, expected_area):
# float output size
ratio = w / h
ow = (expected_area * ratio)**0.5
oh = expected_area / ow
# process width first
ow1 = int(ow // dw * dw)
oh1 = int(expected_area / ow1 // dh * dh)
assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area
ratio1 = ow1 / oh1
# process height first
oh2 = int(oh // dh * dh)
ow2 = int(expected_area / oh2 // dw * dw)
assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area
ratio2 = ow2 / oh2
# compare ratios
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2,
ratio2 / ratio):
return ow1, oh1
else:
return ow2, oh2
def download_cosyvoice_repo(repo_path):
try:
import git
except ImportError:
raise ImportError('failed to import git, please run pip install GitPython')
repo = git.Repo.clone_from('https://github.com/FunAudioLLM/CosyVoice.git', repo_path, multi_options=['--recursive'], branch='main')
def download_cosyvoice_model(model_name, model_path):
from modelscope import snapshot_download
snapshot_download('iic/{}'.format(model_name), local_dir=model_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment