Commit 1336a33d authored by zzg_666's avatar zzg_666
Browse files

wan2.2

parents
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
def shard_model(
model,
device_id,
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
process_group=None,
sharding_strategy=ShardingStrategy.FULL_SHARD,
sync_module_states=True,
use_lora=False
):
model = FSDP(
module=model,
process_group=process_group,
sharding_strategy=sharding_strategy,
auto_wrap_policy=partial(
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
mixed_precision=MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype),
device_id=device_id,
sync_module_states=sync_module_states,
use_orig_params=True if use_lora else False)
return model
def free_model(model):
for m in model.modules():
if isinstance(m, FSDP):
_free_storage(m._handle.flat_param.data)
del model
gc.collect()
torch.cuda.empty_cache()
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from ..modules.model import sinusoidal_embedding_1d
from .ulysses import distributed_attention
from .util import gather_forward, get_rank, get_world_size
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(
pad_size,
s1,
s2,
dtype=original_tensor.dtype,
device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
s, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_world_size()
sp_rank = get_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
def sp_dit_forward(
self,
x,
t,
context,
seq_len,
y=None,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
if self.model_type == 'i2v':
assert y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
])
# time embeddings
if t.dim() == 1:
t = t.expand(t.size(0), seq_len)
with torch.amp.autocast('cuda', dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t).unflatten(0, (bt, seq_len)).float())
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# Context Parallel
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# Context Parallel
x = gather_forward(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
x = distributed_attention(
half(q),
half(k),
half(v),
seq_lens,
window_size=self.window_size,
)
# output
x = x.flatten(2)
x = self.o(x)
return x
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
from ..modules.attention import flash_attention
from .util import all_to_all
def distributed_attention(
q,
k,
v,
seq_lens,
window_size=(-1, -1),
):
"""
Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
please refer to https://arxiv.org/pdf/2309.14509
Args:
q: [B, Lq // p, Nq, C1].
k: [B, Lk // p, Nk, C1].
v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
seq_lens: [B], length of each sequence in batch
window_size: (left right). If not (-1, -1), apply sliding window local attention.
"""
if not dist.is_initialized():
raise ValueError("distributed group should be initialized.")
b = q.shape[0]
# gather q/k/v sequence
q = all_to_all(q, scatter_dim=2, gather_dim=1)
k = all_to_all(k, scatter_dim=2, gather_dim=1)
v = all_to_all(v, scatter_dim=2, gather_dim=1)
# apply attention
x = flash_attention(
q,
k,
v,
k_lens=seq_lens,
window_size=window_size,
)
# scatter q/k/v sequence
x = all_to_all(x, scatter_dim=1, gather_dim=2)
return x
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist
def init_distributed_group():
"""r initialize sequence parallel group.
"""
if not dist.is_initialized():
dist.init_process_group(backend='nccl')
def get_rank():
return dist.get_rank()
def get_world_size():
return dist.get_world_size()
def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
"""
`scatter` along one dimension and `gather` along another.
"""
world_size = get_world_size()
if world_size > 1:
inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
outputs = [torch.empty_like(u) for u in inputs]
dist.all_to_all(outputs, inputs, group=group, **kwargs)
x = torch.cat(outputs, dim=gather_dim).contiguous()
return x
def all_gather(tensor):
world_size = dist.get_world_size()
if world_size == 1:
return [tensor]
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, tensor)
return tensor_list
def gather_forward(input, dim):
# skip if world_size == 1
world_size = dist.get_world_size()
if world_size == 1:
return input
# gather sequence
output = all_gather(input)
return torch.cat(output, dim=dim).contiguous()
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanI2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
):
r"""
Initializes the image-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_sp (`bool`, *optional*, defaults to False):
Enable distribution strategy of sequence parallel.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
convert_model_dtype (`bool`, *optional*, defaults to False):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.init_on_cpu = init_on_cpu
self.num_train_timesteps = config.num_train_timesteps
self.boundary = config.boundary
self.param_dtype = config.param_dtype
if t5_fsdp or dit_fsdp or use_sp:
self.init_on_cpu = False
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = Wan2_1_VAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.low_noise_model = WanModel.from_pretrained(
checkpoint_dir, subfolder=config.low_noise_checkpoint)
self.low_noise_model = self._configure_model(
model=self.low_noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
self.high_noise_model = WanModel.from_pretrained(
checkpoint_dir, subfolder=config.high_noise_checkpoint)
self.high_noise_model = self._configure_model(
model=self.high_noise_model,
use_sp=use_sp,
dit_fsdp=dit_fsdp,
shard_fn=shard_fn,
convert_model_dtype=convert_model_dtype)
if use_sp:
self.sp_size = get_world_size()
else:
self.sp_size = 1
self.sample_neg_prompt = config.sample_neg_prompt
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype):
"""
Configures a model object. This includes setting evaluation modes,
applying distributed parallel strategy, and handling device placement.
Args:
model (torch.nn.Module):
The model instance to configure.
use_sp (`bool`):
Enable distribution strategy of sequence parallel.
dit_fsdp (`bool`):
Enable FSDP sharding for DiT model.
shard_fn (callable):
The function to apply FSDP sharding.
convert_model_dtype (`bool`):
Convert DiT model parameters dtype to 'config.param_dtype'.
Only works without FSDP.
Returns:
torch.nn.Module:
The configured model.
"""
model.eval().requires_grad_(False)
if use_sp:
for block in model.blocks:
block.self_attn.forward = types.MethodType(
sp_attn_forward, block.self_attn)
model.forward = types.MethodType(sp_dit_forward, model)
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
model = shard_fn(model)
else:
if convert_model_dtype:
model.to(self.param_dtype)
if not self.init_on_cpu:
model.to(self.device)
return model
def _prepare_model_for_timestep(self, t, boundary, offload_model):
r"""
Prepares and returns the required model for the current timestep.
Args:
t (torch.Tensor):
current timestep.
boundary (`int`):
The timestep threshold. If `t` is at or above this value,
the `high_noise_model` is considered as the required model.
offload_model (`bool`):
A flag intended to control the offloading behavior.
Returns:
torch.nn.Module:
The active model on the target device for the current timestep.
"""
if t.item() >= boundary:
required_model_name = 'high_noise_model'
offload_model_name = 'low_noise_model'
else:
required_model_name = 'low_noise_model'
offload_model_name = 'high_noise_model'
if offload_model or self.init_on_cpu:
if next(getattr(
self,
offload_model_name).parameters()).device.type == 'cuda':
getattr(self, offload_model_name).to('cpu')
if next(getattr(
self,
required_model_name).parameters()).device.type == 'cpu':
getattr(self, required_model_name).to(self.device)
return getattr(self, required_model_name)
def generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from input image and text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image tensor. Shape: [3, H, W]
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
If tuple, the first guide_scale will be used for low noise model and
the second guide_scale will be used for high noise model.
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from max_area)
- W: Frame width from max_area)
"""
# preprocess
guide_scale = (guide_scale, guide_scale) if isinstance(
guide_scale, float) else guide_scale
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // self.vae_stride[0] + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
noop_no_sync)
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
noop_no_sync)
# evaluation mode
with (
torch.amp.autocast('cuda', dtype=self.param_dtype),
torch.no_grad(),
no_sync_low_noise(),
no_sync_high_noise(),
):
boundary = self.boundary * self.num_train_timesteps
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'seq_len': max_seq_len,
'y': [y],
}
arg_null = {
'context': context_null,
'seq_len': max_seq_len,
'y': [y],
}
if offload_model:
torch.cuda.empty_cache()
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
model = self._prepare_model_for_timestep(
t, boundary, offload_model)
sample_guide_scale = guide_scale[1] if t.item(
) >= boundary else guide_scale[0]
noise_pred_cond = model(
latent_model_input, t=timestep, **arg_c)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = model(
latent_model_input, t=timestep, **arg_null)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + sample_guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent]
del latent_model_input, timestep
if offload_model:
self.low_noise_model.cpu()
self.high_noise_model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent, x0
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE
from .vae2_2 import Wan2_2_VAE
__all__ = [
'Wan2_1_VAE',
'Wan2_2_VAE',
'WanModel',
'T5Model',
'T5Encoder',
'T5Decoder',
'T5EncoderModel',
'HuggingfaceTokenizer',
'flash_attention',
]
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model_animate import WanAnimateModel
from .clip import CLIPModel
__all__ = ['WanAnimateModel', 'CLIPModel']
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment