"vscode:/vscode.git/clone" did not exist on "3a2631ba0ffb07a8b1ea53636224cf0cd8d26949"
Unverified Commit d8e48058 authored by dg845, committed by GitHub

[WIP]Add Wan2.2 Animate Pipeline (Continuation of #12442 by tolgacangoz) (#12526)





---------
Co-authored-by: Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
parent 44c31016
......@@ -387,6 +387,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/wan_animate_transformer_3d
title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers
......
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# WanAnimateTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import WanAnimateTransformer3DModel

transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## WanAnimateTransformer3DModel
[[autodoc]] WanAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
......@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
> [!TIP]
> Click on the Wan models in the right sidebar for more examples of video generation.
......@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
pipeline.to("cuda")
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
......@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
)
prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
......@@ -249,6 +250,220 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
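For example, a minimal sketch of first-frame conditioning under this convention might look like the following (the resolution, frame count, and file path are illustrative, and the full VACE examples linked above show the complete pipeline call):
```python
import PIL.Image
from diffusers.utils import load_image
# Illustrative values and paths; not pipeline requirements
height, width, num_frames = 480, 832, 81
first_frame = load_image("path/to/first_frame.png").resize((width, height))
# The conditioning frame comes first; gray placeholder frames stand in for everything the model should generate
video = [first_frame] + [PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 1)
# Black mask = use this frame for conditioning, white mask = let the model generate this frame
mask = [PIL.Image.new("L", (width, height), 0)] + [PIL.Image.new("L", (width, height), 255)] * (num_frames - 1)
```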
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
Project page: https://humanaigc.github.io/wan-animate
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
#### Usage
The Wan-Animate pipeline supports two modes of operation:
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
##### Prerequisites
Before using the pipeline, you need to preprocess your reference video to extract:
- **Pose video**: Contains skeletal keypoints representing body motion
- **Face video**: Contains facial feature representations for expression control
For replacement mode, you additionally need:
- **Background video**: The original video containing the scene
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
> [!NOTE]
> The preprocessing tools are available in the original Wan-Animate repository. Integration of these preprocessing steps into Diffusers is planned for a future release.
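Once these inputs have been produced, they can be loaded with `load_video` and sanity-checked before calling the pipeline. The sketch below is illustrative: the paths are placeholders, and the equal-length check encodes the assumption that the pose, face, background, and mask videos are frame-aligned with one another.
```python
from diffusers.utils import load_video
pose_video = load_video("path/to/pose_video.mp4")              # list of PIL frames (skeletal keypoints)
face_video = load_video("path/to/face_video.mp4")              # list of PIL frames (facial features)
background_video = load_video("path/to/background_video.mp4")  # replacement mode only
mask_video = load_video("path/to/mask_video.mp4")              # replacement mode only
# Assumption: all conditioning videos describe the same frames, so their lengths should match
lengths = {len(pose_video), len(face_video), len(background_video), len(mask_video)}
if len(lengths) != 1:
    raise ValueError(f"Conditioning videos are not frame-aligned, got lengths {lengths}")
```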
The example below demonstrates how to use the Wan-Animate pipeline:
<hfoptions id="Animate usage">
<hfoption id="Animation mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Load character image and preprocessed videos
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
# Resize image to match VAE constraints
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
# Generate animated video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=81,
guidance_scale=5.0,
mode="animation", # Animation mode (default)
).frames[0]
export_to_video(output, "animated_character.mp4", fps=16)
```
</hfoption>
<hfoption id="Replacement mode">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Load all required inputs for replacement mode
image = load_image("path/to/new_character.jpg")
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
background_video = load_video("path/to/background_video.mp4") # Original scene
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
# Resize image to match video dimensions
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
# Replace character in background video
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
background_video=background_video,
mask_video=mask_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=81,
guidance_scale=5.0,
mode="replacement", # Replacement mode
).frames[0]
export_to_video(output, "character_replaced.mp4", fps=16)
```
</hfoption>
<hfoption id="Advanced options">
```python
import numpy as np
import torch
from diffusers import AutoencoderKLWan, WanAnimatePipeline
from diffusers.utils import export_to_video, load_image, load_video
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanAnimatePipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = load_image("path/to/character.jpg")
pose_video = load_video("path/to/pose_video.mp4")
face_video = load_video("path/to/face_video.mp4")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
image, height, width = aspect_ratio_resize(image, pipe)
prompt = "A person dancing energetically in a studio"
negative_prompt = "blurry, low quality"
# Advanced: Use temporal guidance and custom callback
def callback_fn(pipe, step_index, timestep, callback_kwargs):
# You can modify latents or other tensors here
print(f"Step {step_index}, Timestep {timestep}")
return callback_kwargs
output = pipe(
image=image,
pose_video=pose_video,
face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=81,
num_inference_steps=50,
guidance_scale=5.0,
num_frames_for_temporal_guidance=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
callback_on_step_end=callback_fn,
callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
export_to_video(output, "animated_advanced.mp4", fps=16)
```
</hfoption>
</hfoptions>
#### Key Parameters
- **mode**: Choose between `"animation"` (default) or `"replacement"`
- **num_frames_for_temporal_guidance**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt
- **num_frames**: Total number of frames to generate. `num_frames - 1` should be divisible by `vae_scale_factor_temporal` (default: 4), as in the 81-frame examples above; see the sketch below
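As a quick illustration of the `num_frames` constraint, the helper below (ours, not part of the pipeline API) rounds an arbitrary frame count to the nearest valid value:
```python
def nearest_valid_num_frames(requested: int, vae_scale_factor_temporal: int = 4) -> int:
    # Wan-style pipelines expect (num_frames - 1) to be a multiple of the temporal VAE
    # downsampling factor, e.g. 81 = 4 * 20 + 1 as in the examples above.
    return round((requested - 1) / vae_scale_factor_temporal) * vae_scale_factor_temporal + 1
num_frames = nearest_valid_num_frames(80)  # -> 81
```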
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
......@@ -281,10 +496,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
......@@ -359,6 +574,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- all
- __call__
## WanAnimatePipeline
[[autodoc]] WanAnimatePipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
\ No newline at end of file
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
......@@ -6,11 +6,20 @@ import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from transformers import (
AutoProcessor,
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionModel,
CLIPVisionModelWithProjection,
UMT5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
......@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"after_proj": "proj_out",
}
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"cross_attn.k_img": "attn2.to_k_img",
"cross_attn.v_img": "attn2.to_v_img",
"cross_attn.norm_k_img": "attn2.norm_k_img",
# After cross_attn -> attn2 rename, we need to rename the img keys
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
# Motion encoder mappings
# The name mapping is complicated for the convolutional part so we handle that in its own function
"motion_encoder.enc.fc": "motion_encoder.motion_network",
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
"face_encoder.conv2.conv": "face_encoder.conv2",
"face_encoder.conv3.conv": "face_encoder.conv3",
# Face adapter mappings are handled in a separate function
}
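# NOTE: as with the rename dicts above, these entries are applied to each original key in insertion order as
# plain substring replacements, so the ordering matters (e.g. the norm__placeholder swap and the
# attn2.to_k_img -> attn2.add_k_proj chain both rely on earlier renames having already been applied).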
# TODO: Verify this and simplify if possible.
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
"""
Convert all motion encoder weights for the Animate model.
In the original model:
- All Linear layers in fc use EqualLinear
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
- Blur kernels are stored as buffers in Sequential modules
- ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
Conversion strategy:
1. Drop .kernel buffers (blur kernels)
2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
"""
# Skip if not a weight, bias, or kernel
if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
return
# Handle Blur kernel buffers from original implementation.
# After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
# Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
if ".kernel" in key and "motion_encoder" in key:
# Remove unexpected blur kernel buffers to avoid strict load errors
state_dict.pop(key, None)
return
# Rename Sequential indices to named components in ConvLayer and ResBlock
if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
parts = key.split(".")
# Find the sequential index (digit) after convs or after conv1/conv2/skip
# Examples:
# - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
# - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
# - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
# - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
# - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
# - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
# - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
# - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
# - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
# - enc.net_app.convs.8 -> conv_out (final conv layer)
convs_idx = parts.index("convs") if "convs" in parts else -1
if convs_idx >= 0 and len(parts) - convs_idx >= 2:
bias = False
# The nn.Sequential index will always follow convs
sequential_idx = int(parts[convs_idx + 1])
if sequential_idx == 0:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_in.weight"
elif key.endswith(".bias"):
new_key = "motion_encoder.conv_in.act_fn.bias"
bias = True
elif sequential_idx == final_conv_idx:
if key.endswith(".weight"):
new_key = "motion_encoder.conv_out.weight"
else:
# Intermediate .convs. layers, which get mapped to .res_blocks.
prefix = "motion_encoder.res_blocks."
layer_name = parts[convs_idx + 2]
if layer_name == "skip":
layer_name = "conv_skip"
if key.endswith(".weight"):
param_name = "weight"
elif key.endswith(".bias"):
param_name = "act_fn.bias"
bias = True
suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
param = state_dict.pop(key)
if bias:
param = param.squeeze()
state_dict[new_key] = param
return
return
return
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
"""
Convert face adapter weights for the Animate model.
The original model uses a fused KV projection but the diffusers model uses separate K and V projections.
"""
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
prefix = "face_adapter."
if ".fuser_blocks." in key:
parts = key.split(".")
module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
block_idx = parts[module_list_idx + 1]
layer_name = parts[module_list_idx + 2]
param_name = parts[module_list_idx + 3]
if layer_name == "linear1_kv":
layer_name_k = "to_k"
layer_name_v = "to_v"
suffix_k = ".".join([block_idx, layer_name_k, param_name])
suffix_v = ".".join([block_idx, layer_name_v, param_name])
new_key_k = prefix + suffix_k
new_key_v = prefix + suffix_v
kv_proj = state_dict.pop(key)
k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
state_dict[new_key_k] = k_proj
state_dict[new_key_v] = v_proj
return
else:
if layer_name == "q_norm":
new_layer_name = "norm_q"
elif layer_name == "k_norm":
new_layer_name = "norm_k"
elif layer_name == "linear1_q":
new_layer_name = "to_q"
elif layer_name == "linear2":
new_layer_name = "to_out"
suffix_parts = [block_idx, new_layer_name, param_name]
suffix = ".".join(suffix_parts)
new_key = prefix + suffix
state_dict[new_key] = state_dict.pop(key)
return
return
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"motion_encoder": convert_animate_motion_encoder_weights,
"face_adapter": convert_animate_face_adapter_weights,
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
......@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-Animate-14B":
config = {
"model_id": "Wan-AI/Wan2.2-Animate-14B",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": (1, 2, 2),
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"rope_max_seq_len": 1024,
"pos_embed_seq_len": None,
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
"motion_style_dim": 512,
"motion_dim": 20,
"motion_encoder_dim": 512,
"face_encoder_hidden_dim": 1024,
"face_encoder_num_heads": 4,
"inject_face_latents_blocks": 5,
},
}
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
......@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
if "Animate" in model_type:
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
elif "VACE" in model_type:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
else:
transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
......@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
continue
handler_fn_inplace(key, original_state_dict)
# Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
# Move to CPU to ensure all tensors are materialized
transformer = transformer.to("cpu")
return transformer
......@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
......@@ -942,7 +1184,7 @@ if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
elif "TI2V" in args.model_type or "Animate" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
......@@ -954,6 +1196,8 @@ if __name__ == "__main__":
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if transformer_2 is not None:
transformer_2.to(dtype)
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
......@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
vae=vae,
scheduler=scheduler,
)
elif "Animate" in args.model_type:
image_encoder = CLIPVisionModel.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanAnimatePipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,
......
......@@ -268,6 +268,7 @@ else:
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"attention_backend",
......@@ -636,6 +637,7 @@ else:
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
......@@ -977,6 +979,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
attention_backend,
......@@ -1315,6 +1318,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
......
......@@ -409,7 +409,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
......@@ -460,7 +460,7 @@ class VaeImageProcessor(ConfigMixin):
src_w = width if ratio > src_ratio else image.width * height // image.height
src_h = height if ratio <= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
return res
......
......@@ -108,6 +108,7 @@ if is_torch_available():
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
......@@ -214,6 +215,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
......
......@@ -42,4 +42,5 @@ if is_torch_available():
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
......@@ -188,6 +188,11 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
......@@ -213,11 +218,7 @@ class WanRotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
......
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
"4": 512,
"8": 512,
"16": 512,
"32": 512,
"64": 256,
"128": 128,
"256": 64,
"512": 32,
"1024": 16,
}
# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
# encoder_hidden_states is only passed for cross-attention
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.fused_projections:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
# In cross-attention layers, we can only fuse the KV projections into a single linear
query = attn.to_q(hidden_states)
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
return query, key, value
# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
if attn.fused_projections:
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
else:
key_img = attn.add_k_proj(encoder_hidden_states_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
return key_img, value_img
class FusedLeakyReLU(nn.Module):
"""
Fused LeakyReLU with a scale factor and a channel-wise bias.
"""
def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_channels: Optional[int] = None):
super().__init__()
self.negative_slope = negative_slope
self.scale = scale
self.channels = bias_channels
if self.channels is not None:
self.bias = nn.Parameter(
torch.zeros(
self.channels,
)
)
else:
self.bias = None
def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
if self.bias is not None:
# Expand self.bias to have all singleton dims except at self.channel_dim
expanded_shape = [1] * x.ndim
expanded_shape[channel_dim] = self.bias.shape[0]
bias = self.bias.reshape(*expanded_shape)
x = x + bias
return F.leaky_relu(x, self.negative_slope) * self.scale
class MotionConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = True,
blur_kernel: Optional[Tuple[int, ...]] = None,
blur_upsample_factor: int = 1,
use_activation: bool = True,
):
super().__init__()
self.use_activation = use_activation
self.in_channels = in_channels
# Handle blurring (applying a FIR filter with the given kernel) if available
self.blur = False
if blur_kernel is not None:
p = (len(blur_kernel) - stride) + (kernel_size - 1)
self.blur_padding = ((p + 1) // 2, p // 2)
kernel = torch.tensor(blur_kernel)
# Convert kernel to 2D if necessary
if kernel.ndim == 1:
kernel = kernel[None, :] * kernel[:, None]
# Normalize kernel
kernel = kernel / kernel.sum()
if blur_upsample_factor > 1:
kernel = kernel * (blur_upsample_factor**2)
self.register_buffer("blur_kernel", kernel, persistent=False)
self.blur = True
# Main Conv2d parameters (with scale factor)
self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
self.stride = stride
self.padding = padding
# If using an activation function, the bias will be fused into the activation
if bias and not self.use_activation:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None
if self.use_activation:
self.act_fn = FusedLeakyReLU(bias_channels=out_channels)
else:
self.act_fn = None
def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
# Apply blur if using
if self.blur:
# NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
# set to 1, which should be equivalent to a 2D convolution
expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
# Main Conv2D with scaling
x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
# Activation with fused bias, if using
if self.use_activation:
x = self.act_fn(x, channel_dim=channel_dim)
return x
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
f" kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
)
class MotionLinear(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
use_activation: bool = False,
):
super().__init__()
self.use_activation = use_activation
# Linear weight with scale factor
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
self.scale = 1 / math.sqrt(in_dim)
# If an activation is present, the bias will be fused to it
if bias and not self.use_activation:
self.bias = nn.Parameter(torch.zeros(out_dim))
else:
self.bias = None
if self.use_activation:
self.act_fn = FusedLeakyReLU(bias_channels=out_dim)
else:
self.act_fn = None
def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
out = F.linear(input, self.weight * self.scale, bias=self.bias)
if self.use_activation:
out = self.act_fn(out, channel_dim=channel_dim)
return out
def __repr__(self):
return (
f"{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},"
f" bias={self.bias is not None})"
)
class MotionEncoderResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
kernel_size_skip: int = 1,
blur_kernel: Tuple[int, ...] = (1, 3, 3, 1),
downsample_factor: int = 2,
):
super().__init__()
self.downsample_factor = downsample_factor
# 3 x 3 Conv + fused leaky ReLU
self.conv1 = MotionConv2d(
in_channels,
in_channels,
kernel_size,
stride=1,
padding=kernel_size // 2,
use_activation=True,
)
# 3 x 3 Conv that downsamples 2x + fused leaky ReLU
self.conv2 = MotionConv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=self.downsample_factor,
padding=0,
blur_kernel=blur_kernel,
use_activation=True,
)
# 1 x 1 Conv that downsamples 2x in skip connection
self.conv_skip = MotionConv2d(
in_channels,
out_channels,
kernel_size=kernel_size_skip,
stride=self.downsample_factor,
padding=0,
bias=False,
blur_kernel=blur_kernel,
use_activation=False,
)
def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
x_out = self.conv1(x, channel_dim)
x_out = self.conv2(x_out, channel_dim)
x_skip = self.conv_skip(x, channel_dim)
x_out = (x_out + x_skip) / math.sqrt(2)
return x_out
class WanAnimateMotionEncoder(nn.Module):
def __init__(
self,
size: int = 512,
style_dim: int = 512,
motion_dim: int = 20,
out_dim: int = 512,
motion_blocks: int = 5,
channels: Optional[Dict[str, int]] = None,
):
super().__init__()
self.size = size
# Appearance encoder: conv layers
if channels is None:
channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES
self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True)
self.res_blocks = nn.ModuleList()
in_channels = channels[str(size)]
log_size = int(math.log(size, 2))
for i in range(log_size, 2, -1):
out_channels = channels[str(2 ** (i - 1))]
self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels))
in_channels = out_channels
self.conv_out = MotionConv2d(in_channels, style_dim, 4, padding=0, bias=False, use_activation=False)
# Motion encoder: linear layers
# NOTE: there are no activations in between the linear layers here; this is unusual but matches the original
# implementation.
linears = [MotionLinear(style_dim, style_dim) for _ in range(motion_blocks - 1)]
linears.append(MotionLinear(style_dim, motion_dim))
self.motion_network = nn.ModuleList(linears)
self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim))
def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size):
raise ValueError(
f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected"
f" to have resolution ({self.size}, {self.size})"
)
# Appearance encoding through convs
face_image = self.conv_in(face_image, channel_dim)
for block in self.res_blocks:
face_image = block(face_image, channel_dim)
face_image = self.conv_out(face_image, channel_dim)
motion_feat = face_image.squeeze(-1).squeeze(-1)
# Motion feature extraction
for linear_layer in self.motion_network:
motion_feat = linear_layer(motion_feat, channel_dim=channel_dim)
# Motion synthesis via Linear Motion Decomposition
weight = self.motion_synthesis_weight + 1e-8
# Upcast the QR orthogonalization operation to FP32
original_motion_dtype = motion_feat.dtype
motion_feat = motion_feat.to(torch.float32)
weight = weight.to(torch.float32)
Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
motion_feat_diag = torch.diag_embed(motion_feat) # Alpha, diagonal matrix
motion_decomposition = torch.matmul(motion_feat_diag, Q.T)
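# Summing over the motion dimension gives motion_vec = sum_i alpha_i * q_i, i.e. the motion code expressed
# in the orthonormal basis formed by the columns of Q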
motion_vec = torch.sum(motion_decomposition, dim=1)
motion_vec = motion_vec.to(dtype=original_motion_dtype)
return motion_vec
class WanAnimateFaceEncoder(nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
hidden_dim: int = 1024,
num_heads: int = 4,
kernel_size: int = 3,
eps: float = 1e-6,
pad_mode: str = "replicate",
):
super().__init__()
self.num_heads = num_heads
self.time_causal_padding = (kernel_size - 1, 0)
self.pad_mode = pad_mode
self.act = nn.SiLU()
self.conv1_local = nn.Conv1d(in_dim, hidden_dim * num_heads, kernel_size=kernel_size, stride=1)
self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2)
self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2)
self.norm1 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
self.norm2 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
self.norm3 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
self.out_proj = nn.Linear(hidden_dim, out_dim)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.shape[0]
# Reshape to channels-first to apply causal Conv1d over frame dim
x = x.permute(0, 2, 1)
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
x = self.conv1_local(x) # [B, C, T_padded] --> [B, N * C, T]
x = x.unflatten(1, (self.num_heads, -1)).flatten(0, 1) # [B, N * C, T] --> [B * N, C, T]
# Reshape back to channels-last to apply LayerNorm over channel dim
x = x.permute(0, 2, 1)
x = self.norm1(x)
x = self.act(x)
x = x.permute(0, 2, 1)
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
x = self.conv2(x)
x = x.permute(0, 2, 1)
x = self.norm2(x)
x = self.act(x)
x = x.permute(0, 2, 1)
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
x = self.conv3(x)
x = x.permute(0, 2, 1)
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # [B * N, T, C_out] --> [B, T, N, C_out]
padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device)
x = torch.cat([x, padding], dim=-2) # [B, T, N, C_out] --> [B, T, N + 1, C_out]
return x
class WanAnimateFaceBlockAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or"
f" higher."
)
def __call__(
self,
attn: "WanAnimateFaceBlockCrossAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# encoder_hidden_states corresponds to the motion vec
# attention_mask corresponds to the motion mask (if any)
hidden_states = attn.pre_norm_q(hidden_states)
encoder_hidden_states = attn.pre_norm_kv(encoder_hidden_states)
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
B, T, N, C = encoder_hidden_states.shape
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
value = value.view(B, T, N, attn.heads, -1)
query = attn.norm_q(query)
key = attn.norm_k(key)
# NOTE: the below line (which follows the official code) means that in practice, the number of frames T in
# encoder_hidden_states (the motion vector after applying the face encoder) must evenly divide the
# post-patchify sequence length S of the transformer hidden_states. Is it possible to remove this dependency?
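# e.g. if S == T * L, each group of L consecutive video tokens attends only to the N motion tokens of the
# temporally corresponding frame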
query = query.unflatten(1, (T, -1)).flatten(0, 1) # [B, S, H, D] --> [B * T, S / T, H, D]
key = key.flatten(0, 1) # [B, T, N, H, D_kv] --> [B * T, N, H, D_kv]
value = value.flatten(0, 1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
hidden_states = hidden_states.unflatten(0, (B, T)).flatten(1, 2)
hidden_states = attn.to_out(hidden_states)
if attention_mask is not None:
# NOTE: attention_mask is assumed to be a multiplicative mask
attention_mask = attention_mask.flatten(start_dim=1)
hidden_states = hidden_states * attention_mask
return hidden_states
class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
"""
Temporally-aligned cross attention with the face motion signal in the Wan Animate Face Blocks.
"""
_default_processor_cls = WanAnimateFaceBlockAttnProcessor
_available_processors = [WanAnimateFaceBlockAttnProcessor]
def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
eps: float = 1e-6,
cross_attention_dim_head: Optional[int] = None,
processor=None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.cross_attention_head_dim = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
# NOTE: this is not used in "vanilla" WanAttention
self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False)
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
# 2. QKV and Output Projections
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
# 3. QK Norm
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
# 4. Set attention processor
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask)
# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
class WanAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
def __call__(
self,
attn: "WanAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if rotary_emb is not None:
def apply_rotary_emb(
hidden_states: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out.type_as(hidden_states)
query = apply_rotary_emb(query, *rotary_emb)
key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
key_img = key_img.unflatten(2, (attn.heads, -1))
value_img = value_img.unflatten(2, (attn.heads, -1))
hidden_states_img = dispatch_attention_fn(
query,
key_img,
value_img,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
# Copied from diffusers.models.transformers.transformer_wan.WanAttention
class WanAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = WanAttnProcessor
_available_processors = [WanAttnProcessor]
def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
eps: float = 1e-5,
dropout: float = 0.0,
added_kv_proj_dim: Optional[int] = None,
cross_attention_dim_head: Optional[int] = None,
processor=None,
is_cross_attention=None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.cross_attention_dim_head = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.ModuleList(
[
torch.nn.Linear(self.inner_dim, dim, bias=True),
torch.nn.Dropout(dropout),
]
)
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.add_k_proj = self.add_v_proj = None
if added_kv_proj_dim is not None:
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
def fuse_projections(self):
if getattr(self, "fused_projections", False):
return
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_qkv = nn.Linear(in_features, out_features, bias=True)
self.to_qkv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_kv = nn.Linear(in_features, out_features, bias=True)
self.to_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
if self.added_kv_proj_dim is not None:
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
self.to_added_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
self.fused_projections = True
@torch.no_grad()
def unfuse_projections(self):
if not getattr(self, "fused_projections", False):
return
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_kv"):
delattr(self, "to_added_kv")
self.fused_projections = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
timestep_seq_len: Optional[int] = None,
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
timestep = timestep.unflatten(0, (-1, timestep_seq_len))
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
class WanRotaryPosEmbed(nn.Module):
def __init__(
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
for dim in [t_dim, h_dim, w_dim]:
freq_cos, freq_sin = get_1d_rotary_pos_embed(
dim,
max_seq_len,
theta,
use_real=True,
repeat_interleave_real=True,
freqs_dtype=freqs_dtype,
)
freqs_cos.append(freq_cos)
freqs_sin.append(freq_sin)
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
return freqs_cos, freqs_sin
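# Illustrative arithmetic for the rotary embedding above: the per-head channels are split across
# the (t, h, w) axes as h_dim = w_dim = 2 * (head_dim // 6), with the remainder going to the
# temporal axis. For the 14B config (attention_head_dim = 128) this gives 44 / 42 / 42.
attention_head_dim = 128
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
assert (t_dim, h_dim, w_dim) == (44, 42, 42)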
# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock
class WanTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = WanAttention(
dim=dim,
heads=num_heads,
dim_head=dim // num_heads,
eps=eps,
cross_attention_dim_head=None,
processor=WanAttnProcessor(),
)
# 2. Cross-attention
self.attn2 = WanAttention(
dim=dim,
heads=num_heads,
dim_head=dim // num_heads,
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
cross_attention_dim_head=dim // num_heads,
processor=WanAttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
# batch_size, seq_len, 1, inner_dim
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
gate_msa = gate_msa.squeeze(2)
c_shift_msa = c_shift_msa.squeeze(2)
c_scale_msa = c_scale_msa.squeeze(2)
c_gate_msa = c_gate_msa.squeeze(2)
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
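# A small shape sketch of the adaLN-style modulation used in `WanTransformerBlock.forward`: a
# learned (1, 6, dim) table plus a per-sample (batch, 6, dim) time embedding is chunked into the
# six shift/scale/gate tensors. Values and sizes are toy placeholders, not the real config.
import torch

batch_size, dim = 2, 16
scale_shift_table = torch.randn(1, 6, dim) / dim**0.5
temb = torch.randn(batch_size, 6, dim)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
    scale_shift_table + temb
).chunk(6, dim=1)
assert shift_msa.shape == (batch_size, 1, dim)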
class WanAnimateTransformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
r"""
A Transformer model for video-like data used in the WanAnimate model.
Args:
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `40`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, *optional*, defaults to `36`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `4096`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `13824`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `40`):
The number of layers of transformer blocks to use.
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
Window size for local attention (-1 indicates global attention).
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
Enable query/key normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
image_dim (`int`, *optional*, defaults to `1280`):
The number of channels to use for the image embedding. If `None`, no projection is used.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock", "MotionEncoderResBlock"]
_keep_in_fp32_modules = [
"time_embedder",
"scale_shift_table",
"norm1",
"norm2",
"norm3",
"motion_synthesis_weight",
]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: Optional[int] = 36,
latent_channels: Optional[int] = 16,
out_channels: Optional[int] = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
num_layers: int = 40,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = 1280,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, # Start of Wan Animate-specific args
motion_encoder_size: int = 512,
motion_style_dim: int = 512,
motion_dim: int = 20,
motion_encoder_dim: int = 512,
face_encoder_hidden_dim: int = 1024,
face_encoder_num_heads: int = 4,
inject_face_latents_blocks: int = 5,
motion_encoder_batch_size: int = 8,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
# Allow either only in_channels or only latent_channels to be set for convenience
if in_channels is None and latent_channels is not None:
in_channels = 2 * latent_channels + 4
elif in_channels is not None and latent_channels is None:
latent_channels = (in_channels - 4) // 2
elif in_channels is not None and latent_channels is not None:
# TODO: should this always be true?
assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4"
else:
raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.")
out_channels = out_channels or latent_channels
# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.pose_patch_embedding = nn.Conv3d(latent_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# Motion encoder
self.motion_encoder = WanAnimateMotionEncoder(
size=motion_encoder_size,
style_dim=motion_style_dim,
motion_dim=motion_dim,
out_dim=motion_encoder_dim,
channels=motion_encoder_channel_sizes,
)
# Face encoder
self.face_encoder = WanAnimateFaceEncoder(
in_dim=motion_encoder_dim,
out_dim=inner_dim,
hidden_dim=face_encoder_hidden_dim,
num_heads=face_encoder_num_heads,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
WanTransformerBlock(
dim=inner_dim,
ffn_dim=ffn_dim,
num_heads=num_attention_heads,
qk_norm=qk_norm,
cross_attn_norm=cross_attn_norm,
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
)
for _ in range(num_layers)
]
)
self.face_adapter = nn.ModuleList(
[
WanAnimateFaceBlockCrossAttention(
dim=inner_dim,
heads=num_attention_heads,
dim_head=inner_dim // num_attention_heads,
eps=eps,
cross_attention_dim_head=inner_dim // num_attention_heads,
processor=WanAnimateFaceBlockAttnProcessor(),
)
for _ in range(num_layers // inject_face_latents_blocks)
]
)
# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
pose_hidden_states: Optional[torch.Tensor] = None,
face_pixel_values: Optional[torch.Tensor] = None,
motion_encode_batch_size: Optional[int] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass of Wan2.2-Animate transformer model.
Args:
hidden_states (`torch.Tensor` of shape `(B, 2C + 4, T + 1, H, W)`):
Input noisy video latents of shape `(B, 2C + 4, T + 1, H, W)`, where B is the batch size, C is the
number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment, H
is the latent height, and W is the latent width.
timestep (`torch.LongTensor`):
The current timestep in the denoising loop.
encoder_hidden_states (`torch.Tensor`):
Text embeddings from the text encoder (umT5 for Wan Animate).
encoder_hidden_states_image (`torch.Tensor`):
CLIP visual features of the reference (character) image.
pose_hidden_states (`torch.Tensor` of shape `(B, C, T, H, W)`):
Pose video latents that condition the generation on the driving motion. They are patch-embedded and added to the video tokens of every frame after the reference frame.
face_pixel_values (`torch.Tensor` of shape `(B, C', S, H', W')`):
Face video in pixel space (not latent space). Typically C' = 3 and H' and W' are the height/width of
the face video in pixels. Here S is the inference segment length, usually set to 77.
motion_encode_batch_size (`int`, *optional*):
The batch size for batched encoding of the face video via the motion encoder. Will default to
`self.config.motion_encoder_batch_size` if not set.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return the output as a dict or tuple.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Check that shapes match up
if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
raise ValueError(
f"pose_hidden_states frame dim (dim 2) is {pose_hidden_states.shape[2]} but must be one less than the"
f" hidden_states's corresponding frame dim: {hidden_states.shape[2]}"
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
# 1. Rotary position embedding
rotary_emb = self.rope(hidden_states)
# 2. Patch embedding
hidden_states = self.patch_embedding(hidden_states)
pose_hidden_states = self.pose_patch_embedding(pose_hidden_states)
# Add pose embeddings to hidden states
hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states
# Calling contiguous() here is important so that we don't recompile when performing regional compilation
hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous()
# 3. Condition embeddings (time, text, image)
# Wan Animate is based on Wan 2.1 and thus uses Wan 2.1's timestep logic
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=None
)
# batch_size, 6, inner_dim
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
# 4. Get motion features from the face video
# Motion vector computation from face pixel values
batch_size, channels, num_face_frames, height, width = face_pixel_values.shape
# Rearrange from (B, C, T, H, W) to (B*T, C, H, W)
face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
# Extract motion features using motion encoder
# Perform batched motion encoder inference to allow trading off inference speed for memory usage
motion_encode_batch_size = motion_encode_batch_size or self.config.motion_encoder_batch_size
face_batches = torch.split(face_pixel_values, motion_encode_batch_size)
motion_vec_batches = []
for face_batch in face_batches:
motion_vec_batch = self.motion_encoder(face_batch)
motion_vec_batches.append(motion_vec_batch)
motion_vec = torch.cat(motion_vec_batches)
motion_vec = motion_vec.view(batch_size, num_face_frames, -1)
# Now get face features from the motion vector
motion_vec = self.face_encoder(motion_vec)
# Add padding at the beginning (prepend zeros)
pad_face = torch.zeros_like(motion_vec[:, :1])
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
# 5. Transformer blocks with face adapter integration
for block_idx, block in enumerate(self.blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
else:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
# Face adapter integration: applied after every `inject_face_latents_blocks`-th block (0, 5, 10, ... with the default of 5)
if block_idx % self.config.inject_face_latents_blocks == 0:
face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks
face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
# In case the face adapter and main transformer blocks are on different devices, which can happen when
# using model parallelism
face_adapter_output = face_adapter_output.to(device=hidden_states.device)
hidden_states = face_adapter_output + hidden_states
# 6. Output norm, projection & unpatchify
# batch_size, inner_dim
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
hidden_states_original_dtype = hidden_states.dtype
hidden_states = self.norm_out(hidden_states.float())
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
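# A toy-sized sketch of the unpatchify step at the end of `forward` above: the projected tokens of
# shape (B, F_p * H_p * W_p, C_out * p_t * p_h * p_w) are rearranged back into (B, C_out, F, H, W)
# latents. The sizes below are illustrative, not the real model configuration.
import torch

batch_size, out_channels = 1, 16
p_t, p_h, p_w = 1, 2, 2
post_frames, post_height, post_width = 21, 4, 6
tokens = torch.randn(batch_size, post_frames * post_height * post_width, out_channels * p_t * p_h * p_w)
x = tokens.reshape(batch_size, post_frames, post_height, post_width, p_t, p_h, p_w, -1)
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
video_latents = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)
assert video_latents.shape == (batch_size, out_channels, post_frames * p_t, post_height * p_h, post_width * p_w)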
......@@ -385,7 +385,13 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["wan"] = [
"WanPipeline",
"WanImageToVideoPipeline",
"WanVideoToVideoPipeline",
"WanVACEPipeline",
"WanAnimatePipeline",
]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
......@@ -803,7 +809,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserTextDecoder,
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
from .wan import (
WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
)
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
......
......@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_wan"] = ["WanPipeline"]
_import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"]
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
_import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
......@@ -35,10 +36,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_wan import WanPipeline
from .pipeline_wan_animate import WanAnimatePipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
from .pipeline_wan_vace import WanVACEPipeline
from .pipeline_wan_video2video import WanVideoToVideoPipeline
else:
import sys
......
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from ...configuration_utils import register_to_config
from ...image_processor import VaeImageProcessor
from ...utils import PIL_INTERPOLATION
class WanAnimateImageProcessor(VaeImageProcessor):
r"""
Image processor to preprocess the reference (character) image for the Wan Animate model.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
this factor.
vae_latent_channels (`int`, *optional*, defaults to `16`):
VAE latent channels.
spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`):
The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2).
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to `False`):
Whether to convert the images to grayscale format.
fill_color (`str` or `float` or `Tuple[float, ...]`, *optional*, defaults to `None`):
An optional fill color when `resize_mode` is set to `"fill"`. This will fill the empty space with that
color instead of filling with data from the image. Any valid `color` argument to `PIL.Image.new` is valid;
if `None`, will default to filling with data from `image`.
"""
@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
vae_latent_channels: int = 16,
spatial_patch_size: Tuple[int, int] = (2, 2),
resample: str = "lanczos",
reducing_gap: Optional[int] = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
do_convert_grayscale: bool = False,
fill_color: Optional[Union[str, float, Tuple[float, ...]]] = 0,
):
super().__init__()
if do_convert_rgb and do_convert_grayscale:
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
def _resize_and_fill(
self,
image: PIL.Image.Image,
width: int,
height: int,
) -> PIL.Image.Image:
r"""
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
the image within the dimensions, filling the empty space with data from the image (or with `fill_color`, if set).
Args:
image (`PIL.Image.Image`):
The image to resize and fill.
width (`int`):
The width to resize the image to.
height (`int`):
The height to resize the image to.
Returns:
`PIL.Image.Image`:
The resized and filled image.
"""
ratio = width / height
src_ratio = image.width / image.height
fill_with_image_data = self.config.fill_color is None
fill_color = self.config.fill_color or 0
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = PIL.Image.new("RGB", (width, height), color=fill_color)
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
if fill_with_image_data:
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
if fill_height > 0:
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
box=(0, fill_height + src_h),
)
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
if fill_width > 0:
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
box=(fill_width + src_w, 0),
)
return res
def get_default_height_width(
self,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> Tuple[int, int]:
r"""
Returns the height and width of the image, adjusted to the nearest integer multiple of `vae_scale_factor * spatial_patch_size` while approximately preserving the image area and aspect ratio.
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor * spatial_patch_size`.
"""
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]
max_area = width * height
aspect_ratio = height / width
mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0]
mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1]
# Try to preserve the aspect ratio
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w
return height, width
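# A worked example of `get_default_height_width`: with vae_scale_factor = 8 and
# spatial_patch_size = (2, 2), the dimensions are snapped to multiples of 16 while roughly
# preserving the input area and aspect ratio. The 1080x1920 input size is illustrative only.
import numpy as np

vae_scale_factor, spatial_patch_size = 8, (2, 2)
height, width = 1080, 1920
max_area, aspect_ratio = width * height, height / width
mod_h = vae_scale_factor * spatial_patch_size[0]
mod_w = vae_scale_factor * spatial_patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_h * mod_h
width = round(np.sqrt(max_area / aspect_ratio)) // mod_w * mod_w
assert (height, width) == (1072, 1920)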
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import PIL
import regex as re
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import WanAnimateImageProcessor
from .pipeline_output import WanPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> import numpy as np
>>> from diffusers import WanAnimatePipeline
>>> from diffusers.utils import export_to_video, load_image, load_video
>>> model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
>>> pipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> # Optionally upcast the Wan VAE to FP32
>>> pipe.vae.to(torch.float32)
>>> pipe.to("cuda")
>>> # Load the reference character image
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
... )
>>> # Load pose and face videos (preprocessed from reference video)
>>> # Note: Videos should be preprocessed to extract pose keypoints and face features
>>> # Refer to the Wan-Animate preprocessing documentation for details
>>> pose_video = load_video("path/to/pose_video.mp4")
>>> face_video = load_video("path/to/face_video.mp4")
>>> # CFG is generally not used for Wan Animate
>>> prompt = (
... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
... )
>>> # Animation mode: Animate the character with the motion from pose/face videos
>>> output = pipe(
... image=image,
... pose_video=pose_video,
... face_video=face_video,
... prompt=prompt,
...     height=720,
...     width=1280,
... segment_frame_length=77, # Frame length of each inference segment
... guidance_scale=1.0,
... num_inference_steps=20,
... mode="animate",
... ).frames[0]
>>> export_to_video(output, "output_animation.mp4", fps=30)
>>> # Replacement mode: Replace a character in the background video
>>> # Requires additional background_video and mask_video inputs
>>> background_video = load_video("path/to/background_video.mp4")
>>> mask_video = load_video("path/to/mask_video.mp4") # Black areas preserved, white areas generated
>>> output = pipe(
... image=image,
... pose_video=pose_video,
... face_video=face_video,
... background_video=background_video,
... mask_video=mask_video,
... prompt=prompt,
...     height=720,
...     width=1280,
... segment_frame_length=77, # Frame length of each inference segment
... guidance_scale=1.0,
... num_inference_steps=20,
... mode="replace",
... ).frames[0]
>>> export_to_video(output, "output_replacement.mp4", fps=30)
```
"""
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class WanAnimatePipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Pipeline for unified character animation and replacement using Wan-Animate.
WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in two
modes:
1. **Animation mode**: The model generates a video of the character image that mimics the human motion in the input
pose and face videos. The character is animated based on the provided motion controls, creating a new animated
video of the character.
2. **Replacement mode**: The model replaces a character in a background video with the provided character image,
using the pose and face videos for motion control. This mode requires additional `background_video` and
`mask_video` inputs. The mask video should have black regions where the original content should be preserved and
white regions where the new character should be generated.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.WanLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
image_encoder ([`CLIPVisionModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
the
[clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
variant.
transformer ([`WanAnimateTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
image_processor ([`CLIPImageProcessor`]):
Image processor for preprocessing images before encoding.
"""
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
image_processor: CLIPImageProcessor,
image_encoder: CLIPVisionModel,
transformer: WanAnimateTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
image_encoder=image_encoder,
transformer=transformer,
scheduler=scheduler,
image_processor=image_processor,
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.video_processor_for_mask = VideoProcessor(
vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True
)
# In case self.transformer is None (e.g. for some pipeline tests)
spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2)
self.vae_image_processor = WanAnimateImageProcessor(
vae_scale_factor=self.vae_scale_factor_spatial,
spatial_patch_size=spatial_patch_size,
resample="bilinear",
fill_color=0,
)
self.image_processor = image_processor
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
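# A toy shape sketch of the padding above: each sequence is trimmed to its true token length and
# then zero-padded back to `max_sequence_length`, so padding positions hold exact zeros rather
# than encoder outputs for pad tokens. Sizes are placeholders.
import torch

max_sequence_length, hidden_dim = 8, 4
embeds = torch.randn(2, max_sequence_length, hidden_dim)
seq_lens = torch.tensor([3, 6])
trimmed = [u[:v] for u, v in zip(embeds, seq_lens)]
padded = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in trimmed], dim=0
)
assert padded.shape == (2, max_sequence_length, hidden_dim)
assert padded[0, 3:].abs().sum() == 0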
# Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
def encode_image(
self,
image: PipelineImageInput,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
image = self.image_processor(images=image, return_tensors="pt").to(device)
image_embeds = self.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
prompt,
negative_prompt,
image,
pose_video,
face_video,
background_video,
mask_video,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
image_embeds=None,
callback_on_step_end_tensor_inputs=None,
mode=None,
prev_segment_conditioning_frames=None,
):
if image is not None and image_embeds is not None:
raise ValueError(
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
" only forward one of the two."
)
if image is None and image_embeds is None:
raise ValueError(
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
)
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
if pose_video is None:
raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.")
if face_video is None:
raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.")
if not isinstance(pose_video, list) or not isinstance(face_video, list):
raise ValueError("`pose_video` and `face_video` must be lists of PIL images.")
if len(pose_video) == 0 or len(face_video) == 0:
raise ValueError("`pose_video` and `face_video` must contain at least one frame.")
if mode == "replace" and (background_video is None or mask_video is None):
raise ValueError(
"Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`"
" undefined when mode is `replace`."
)
if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)):
raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.")
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")):
raise ValueError(
f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}"
)
if prev_segment_conditioning_frames is not None and (
not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5)
):
raise ValueError(
f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is"
f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}"
)
def get_i2v_mask(
self,
batch_size: int,
latent_t: int,
latent_h: int,
latent_w: int,
mask_len: int = 1,
mask_pixel_values: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
# mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w]
if mask_pixel_values is None:
mask_lat_size = torch.zeros(
batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device
)
else:
mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype)
mask_lat_size[:, :, :mask_len] = 1
first_frame_mask = mask_lat_size[:, :, 0:1]
# Repeat first frame mask self.vae_scale_factor_temporal (= 4) times in the frame dimension
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w
).transpose(1, 2) # [B, C = 1, 4 * T_lat, H_lat, W_lat] --> [B, C = 4, T_lat, H_lat, W_lat]
return mask_lat_size
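# A toy shape sketch of `get_i2v_mask`: a mask over (T_lat - 1) * 4 + 1 pixel-space frames
# (the Wan VAE temporal layout) has its first frame repeated 4 times and is then regrouped into
# 4 mask channels per latent frame. Sizes below are illustrative placeholders.
import torch

batch_size, latent_t, latent_h, latent_w = 1, 5, 8, 8
mask = torch.zeros(batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w)
mask[:, :, :1] = 1  # condition on the first (reference / previous-segment) frame
first_frame = mask[:, :, 0:1].repeat_interleave(4, dim=2)
mask = torch.cat([first_frame, mask[:, :, 1:]], dim=2)  # (B, 1, 4 * T_lat, H_lat, W_lat)
mask = mask.view(batch_size, -1, 4, latent_h, latent_w).transpose(1, 2)
assert mask.shape == (batch_size, 4, latent_t, latent_h, latent_w)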
def prepare_reference_image_latents(
self,
image: torch.Tensor,
batch_size: int = 1,
sample_mode: int = "argmax",
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
# image shape: (B, C, H, W) or (B, C, T, H, W)
dtype = dtype or self.vae.dtype
if image.ndim == 4:
# Add a singleton frame dimension after the channels dimension
image = image.unsqueeze(2)
_, _, _, height, width = image.shape
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
# Encode image to latents using VAE
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list):
# Like in prepare_latents, assume len(generator) == batch_size
ref_image_latents = [
retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator
]
ref_image_latents = torch.cat(ref_image_latents)
else:
ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode)
# Normalize the encoded latents using the Wan VAE latent mean/std statistics
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(ref_image_latents.device, ref_image_latents.dtype)
)
latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
ref_image_latents.device, ref_image_latents.dtype
)
ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std
# Handle the case where we supply one image and one generator, but batch_size > 1 (e.g. generating multiple
# videos per prompt)
if ref_image_latents.shape[0] == 1 and batch_size > 1:
ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1)
# Prepare I2V mask in latent space and prepend to the reference image latents along channel dim
reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device)
reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1)
return reference_image_latents
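# A minimal sketch of the latent normalization used throughout this pipeline: encoded latents are
# shifted by the per-channel `latents_mean` and scaled by the reciprocal of `latents_std` from the
# VAE config. The statistics below are random placeholders, not the real Wan VAE values.
import torch

z_dim = 16
latents = torch.randn(1, z_dim, 21, 90, 160)
latents_mean = torch.randn(1, z_dim, 1, 1, 1)
latents_recip_std = 1.0 / (torch.rand(1, z_dim, 1, 1, 1) + 0.5)
normalized = (latents - latents_mean) * latents_recip_std
assert normalized.shape == latents.shape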
def prepare_prev_segment_cond_latents(
self,
prev_segment_cond_video: Optional[torch.Tensor] = None,
background_video: Optional[torch.Tensor] = None,
mask_video: Optional[torch.Tensor] = None,
batch_size: int = 1,
segment_frame_length: int = 77,
start_frame: int = 0,
height: int = 720,
width: int = 1280,
prev_segment_cond_frames: int = 1,
task: str = "animate",
interpolation_mode: str = "bicubic",
sample_mode: str = "argmax",
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
# prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied
# background_video shape: (B, C, T, H, W) (same as prev_segment_cond_video shape)
# mask_video shape: (B, 1, T, H, W) (same as prev_segment_cond_video, but with only 1 channel)
dtype = dtype or self.vae.dtype
if prev_segment_cond_video is None:
if task == "replace":
prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype)
else:
cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width) # In pixel space
prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device)
data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape
num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
if segment_height != height or segment_width != width:
logger.info(
f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})"
)
# Perform a 4D (spatial) rather than a 5D (spatiotemporal) reshape, following the original code
prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1) # [B * T, C, H, W]
prev_segment_cond_video = F.interpolate(
prev_segment_cond_video, size=(height, width), mode=interpolation_mode
)
prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2)
# Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if
# replacing).
if task == "replace":
remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
else:
remaining_segment_frames = segment_frame_length - prev_segment_cond_frames
remaining_segment = torch.zeros(
batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device
)
# Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim
prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype)
full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2)
if isinstance(generator, list):
if data_batch_size == len(generator):
prev_segment_cond_latents = [
retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode)
for i, g in enumerate(generator)
]
elif data_batch_size == 1:
# Like prepare_latents, assume len(generator) == batch_size
prev_segment_cond_latents = [
retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator
]
else:
raise ValueError(
f"The batch size of the prev segment video should be either {len(generator)} or 1 but is"
f" {data_batch_size}"
)
prev_segment_cond_latents = torch.cat(prev_segment_cond_latents)
else:
prev_segment_cond_latents = retrieve_latents(
self.vae.encode(full_segment_cond_video), generator, sample_mode
)
# Normalize the encoded latents using the Wan VAE latent mean/std statistics
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype)
)
latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
prev_segment_cond_latents.device, prev_segment_cond_latents.dtype
)
prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std
# Prepare I2V mask
if task == "replace":
mask_video = 1 - mask_video
mask_video = mask_video.permute(0, 2, 1, 3, 4)
mask_video = mask_video.flatten(0, 1)
mask_video = F.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest")
mask_pixel_values = mask_video.unflatten(0, (batch_size, -1))
mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4) # output shape: [B, C = 1, T, H_lat, W_lat]
else:
mask_pixel_values = None
prev_segment_cond_mask = self.get_i2v_mask(
batch_size,
num_latent_frames,
latent_height,
latent_width,
mask_len=prev_segment_cond_frames if start_frame > 0 else 0,
mask_pixel_values=mask_pixel_values,
dtype=dtype,
device=device,
)
# Prepend cond I2V mask to prev segment cond latents along channel dimension
prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1)
return prev_segment_cond_latents
def prepare_pose_latents(
self,
pose_video: torch.Tensor,
batch_size: int = 1,
sample_mode: int = "argmax",
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
# pose_video shape: (B, C, T, H, W)
pose_video = pose_video.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype)
if isinstance(generator, list):
pose_latents = [
retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator
]
pose_latents = torch.cat(pose_latents)
else:
pose_latents = retrieve_latents(self.vae.encode(pose_video), generator, sample_mode)
# Normalize the encoded latents using the Wan VAE latent mean/std statistics
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(pose_latents.device, pose_latents.dtype)
)
latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
pose_latents.device, pose_latents.dtype
)
pose_latents = (pose_latents - latents_mean) * latents_recip_std
if pose_latents.shape[0] == 1 and batch_size > 1:
pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1)
return pose_latents
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 720,
width: int = 1280,
num_frames: int = 77,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
return latents
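# A quick check of the latent shape allocated by `prepare_latents` for the default 720x1280,
# 77-frame segment: (77 - 1) // 4 + 1 = 20 latent frames, plus one extra latent frame for the
# reference image, at 720 // 8 = 90 by 1280 // 8 = 160.
num_frames, vae_scale_factor_temporal, vae_scale_factor_spatial = 77, 4, 8
height, width = 720, 1280
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
shape = (1, 16, num_latent_frames + 1, height // vae_scale_factor_spatial, width // vae_scale_factor_spatial)
assert shape == (1, 16, 21, 90, 160)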
def pad_video_frames(self, frames: List[Any], num_target_frames: int) -> List[Any]:
"""
Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy. The frame dimension
is assumed to be the first dimension. In the 1D case, we can visualize this strategy as follows:
pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2]
"""
idx = 0
flip = False
target_frames = []
while len(target_frames) < num_target_frames:
target_frames.append(deepcopy(frames[idx]))
if flip:
idx -= 1
else:
idx += 1
if idx == 0 or idx == len(frames) - 1:
flip = not flip
return target_frames
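# A standalone version of the reflect-style padding above, on plain integers, so the strategy can
# be checked without instantiating the pipeline. The helper name is local to this sketch.
def _reflect_pad(frames, num_target_frames):
    idx, flip, out = 0, False, []
    while len(out) < num_target_frames:
        out.append(frames[idx])
        idx = idx - 1 if flip else idx + 1
        if idx == 0 or idx == len(frames) - 1:
            flip = not flip
    return out

assert _reflect_pad([1, 2, 3, 4, 5], 10) == [1, 2, 3, 4, 5, 4, 3, 2, 1, 2]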
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput,
pose_video: List[PIL.Image.Image],
face_video: List[PIL.Image.Image],
background_video: Optional[List[PIL.Image.Image]] = None,
mask_video: Optional[List[PIL.Image.Image]] = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 720,
width: int = 1280,
segment_frame_length: int = 77,
num_inference_steps: int = 20,
mode: str = "animate",
prev_segment_conditioning_frames: int = 1,
motion_encode_batch_size: Optional[int] = None,
guidance_scale: float = 1.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
image (`PipelineImageInput`):
The input character image to condition the generation on. Must be an image, a list of images or a
`torch.Tensor`.
pose_video (`List[PIL.Image.Image]`):
The input pose video to condition the generation on. Must be a list of PIL images.
face_video (`List[PIL.Image.Image]`):
The input face video to condition the generation on. Must be a list of PIL images.
background_video (`List[PIL.Image.Image]`, *optional*):
When mode is `"replace"`, the input background video to condition the generation on. Must be a list of
PIL images.
mask_video (`List[PIL.Image.Image]`, *optional*):
When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL
images.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
`1` or less).
mode (`str`, defaults to `"animation"`):
The mode of the generation. Choose between `"animate"` and `"replace"`.
prev_segment_conditioning_frames (`int`, defaults to `1`):
The number of frames from the previous video segment to be used for temporal guidance. Recommended to
be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer.
motion_encode_batch_size (`int`, *optional*):
The batch size for batched encoding of the face video via the motion encoder. This allows trading off
inference speed for lower memory usage by setting a smaller batch size. Will default to
`self.transformer.config.motion_encoder_batch_size` if not set.
height (`int`, defaults to `720`):
The height of the generated video.
width (`int`, defaults to `1280`):
The width of the generated video.
segment_frame_length (`int`, defaults to `77`):
The number of frames in each generated video segment. The total number of frames generated will equal
the number of frames in `pose_video`; the video is generated segment by segment until this length is
reached. In general, should be 4N + 1, where N is a non-negative integer.
num_inference_steps (`int`, defaults to `20`):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
guidance_scale (`float`, defaults to `1.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. A higher guidance scale encourages the model to generate videos that are closely
linked to the text `prompt`, usually at the expense of lower video quality. By default, classifier-free
guidance is not used in Wan Animate inference.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `negative_prompt` input argument.
image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
image embeddings are generated from the `image` input argument.
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated video. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
truncated. If the prompt is shorter, it will be padded to this length.
Examples:
Returns:
[`~WanPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated videos.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
image,
pose_video,
face_video,
background_video,
mask_video,
height,
width,
prompt_embeds,
negative_prompt_embeds,
image_embeds,
callback_on_step_end_tensor_inputs,
mode,
prev_segment_conditioning_frames,
)
if segment_frame_length % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the"
f" nearest number."
)
segment_frame_length = (
segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
)
segment_frame_length = max(segment_frame_length, 1)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# As we generate in overlapping segments of `segment_frame_length` frames, pad the conditioning videos so
# that, after accounting for the `prev_segment_conditioning_frames` overlap, the total length splits evenly
# into whole segments (see the worked example below).
cond_video_frames = len(pose_video)
effective_segment_length = segment_frame_length - prev_segment_conditioning_frames
last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length
if last_segment_frames == 0:
num_padding_frames = 0
else:
num_padding_frames = effective_segment_length - last_segment_frames
num_target_frames = cond_video_frames + num_padding_frames
num_segments = num_target_frames // effective_segment_length
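# Worked example (illustrative numbers): with len(pose_video) == 100, segment_frame_length == 77 and
# prev_segment_conditioning_frames == 1, effective_segment_length == 76, last_segment_frames == (100 - 1) % 76 == 23,
# so num_padding_frames == 53, num_target_frames == 153 and num_segments == 153 // 76 == 2 (the second segment
# reuses the last conditioning frame of the first).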
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Preprocess and encode the reference (character) image
image_height, image_width = self.video_processor.get_default_height_width(image)
if image_height != height or image_width != width:
logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})")
image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to(
device, dtype=torch.float32
)
# Get CLIP features from the reference image
if image_embeds is None:
image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
# 5. Encode conditioning videos (pose, face)
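# Reflect-pad the pose and face videos up to `num_target_frames` so that the final inference segment spans a
# full `segment_frame_length` frames.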
pose_video = self.pad_video_frames(pose_video, num_target_frames)
face_video = self.pad_video_frames(face_video, num_target_frames)
# TODO: also support np.ndarray input (e.g. from decord like the original implementation?)
pose_video_width, pose_video_height = pose_video[0].size
if pose_video_height != height or pose_video_width != width:
logger.warning(
f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})"
)
pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to(
device, dtype=torch.float32
)
face_video_width, face_video_height = face_video[0].size
expected_face_size = self.transformer.config.motion_encoder_size
if face_video_width != expected_face_size or face_video_height != expected_face_size:
logger.warning(
f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size},"
f" {expected_face_size})"
)
face_video = self.video_processor.preprocess_video(
face_video, height=expected_face_size, width=expected_face_size
).to(device, dtype=torch.float32)
if mode == "replace":
background_video = self.pad_video_frames(background_video, num_target_frames)
mask_video = self.pad_video_frames(mask_video, num_target_frames)
background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to(
device, dtype=torch.float32
)
mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to(
device, dtype=torch.float32
)
# 6. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 7. Prepare latent variables which stay constant for all inference segments
num_channels_latents = self.vae.config.z_dim
# Get VAE-encoded latents of the reference (character) image
reference_image_latents = self.prepare_reference_image_latents(
image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device
)
# 8. Loop over video inference segments
start = 0
end = segment_frame_length # Data space frames, not latent frames
all_out_frames = []
out_frames = None
for _ in range(num_segments):
assert start + prev_segment_conditioning_frames < cond_video_frames
# Sample noisy latents from prior for the current inference segment
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames=segment_frame_length,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents if start == 0 else None, # Only use pre-calculated latents for first segment
)
pose_video_segment = pose_video[:, :, start:end]
face_video_segment = face_video[:, :, start:end]
face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
face_video_segment = face_video_segment.to(dtype=transformer_dtype)
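# For segments after the first, reuse the last `prev_segment_conditioning_frames` decoded frames of the
# previous segment as temporal conditioning so that consecutive segments remain consistent.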
if start > 0:
prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach()
else:
prev_segment_cond_video = None
if mode == "replace":
background_video_segment = background_video[:, :, start:end]
mask_video_segment = mask_video[:, :, start:end]
background_video_segment = background_video_segment.expand(
batch_size * num_videos_per_prompt, -1, -1, -1, -1
)
mask_video_segment = mask_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
else:
background_video_segment = None
mask_video_segment = None
pose_latents = self.prepare_pose_latents(
pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device
)
pose_latents = pose_latents.to(dtype=transformer_dtype)
prev_segment_cond_latents = self.prepare_prev_segment_cond_latents(
prev_segment_cond_video,
background_video=background_video_segment,
mask_video=mask_video_segment,
batch_size=batch_size * num_videos_per_prompt,
segment_frame_length=segment_frame_length,
start_frame=start,
height=height,
width=width,
prev_segment_cond_frames=prev_segment_conditioning_frames,
task=mode,
generator=generator,
device=device,
)
# Concatenate the reference latents in the frame dimension
reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2)
# 8.1 Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# Concatenate the reference image + prev segment conditioning in the channel dim
latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype)
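# `latent_model_input` has 2 * z_dim + 4 channels: the noisy latents plus the mask-augmented reference and
# previous-segment conditioning latents (the extra 4 channels appear to come from the I2V conditioning mask;
# cf. `in_channels = 2 * C + 4` in the transformer tests).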
timestep = t.expand(latents.shape[0])
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds,
pose_hidden_states=pose_latents,
face_pixel_values=face_video_segment,
motion_encode_batch_size=motion_encode_batch_size,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
# Blank out face for unconditional guidance (set all pixels to -1)
face_pixel_values_uncond = face_video_segment * 0 - 1
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states_image=image_embeds,
pose_hidden_states=pose_latents,
face_pixel_values=face_pixel_values_uncond,
motion_encode_batch_size=motion_encode_batch_size,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
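# Standard classifier-free guidance combination: noise_uncond + guidance_scale * (noise_cond - noise_uncond).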
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
latents = latents.to(self.vae.dtype)
# Destandardize latents in preparation for Wan VAE decoding
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
1, self.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_recip_std + latents_mean
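# Equivalent to latents * std + mean, inverting the standardization applied to the VAE-encoded latents.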
# Skip the first latent frame (used for conditioning)
out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0]
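# For segments after the first, drop the leading frames that overlap with the previous segment so the
# decoded segments can be concatenated without duplicated frames.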
if start > 0:
out_frames = out_frames[:, :, prev_segment_conditioning_frames:]
all_out_frames.append(out_frames)
start += effective_segment_length
end += effective_segment_length
# Reset scheduler timesteps / state for next denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._current_timestep = None
assert start + prev_segment_conditioning_frames >= cond_video_frames
if not output_type == "latent":
video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return WanPipelineOutput(frames=video)
......@@ -1623,6 +1623,21 @@ class VQModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class WanAnimateTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class WanTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -3512,6 +3512,21 @@ class VQDiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class WanAnimatePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import WanAnimateTransformer3DModel
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size` below
face_width = 16
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
torch_device
)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_image": clip_ref_features,
"pose_hidden_states": pose_latents,
"face_pixel_values": face_pixel_values,
}
@property
def input_shape(self):
return (12, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
# contain the vast majority of the parameters in the test model
channel_sizes = {"4": 16, "8": 16, "16": 16}
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 12, # 2 * C + 4 = 2 * 4 + 4 = 12
"latent_channels": 4,
"out_channels": 4,
"text_dim": 16,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 2,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"image_dim": 16,
"rope_max_seq_len": 32,
"motion_encoder_channel_sizes": channel_sizes, # Start of Wan Animate-specific config
"motion_encoder_size": 16, # Ensures that there will be 2 motion encoder resblocks
"motion_style_dim": 8,
"motion_dim": 4,
"motion_encoder_dim": 16,
"face_encoder_hidden_dim": 16,
"face_encoder_num_heads": 2,
"inject_face_latents_blocks": 2,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanAnimateTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Override test_output because the transformer output is expected to have fewer channels than the main
# transformer input.
def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
T5EncoderModel,
)
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
WanAnimatePipeline,
WanAnimateTransformer3DModel,
)
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WanAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanAnimatePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
channel_sizes = {"4": 16, "8": 16, "16": 16}
transformer = WanAnimateTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
latent_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
image_dim=4,
rope_max_seq_len=32,
motion_encoder_channel_sizes=channel_sizes,
motion_encoder_size=16,
motion_style_dim=8,
motion_dim=4,
motion_encoder_dim=16,
face_encoder_hidden_dim=16,
face_encoder_num_heads=2,
inject_face_latents_blocks=2,
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
projection_dim=4,
num_hidden_layers=2,
num_attention_heads=2,
image_size=4,
intermediate_size=16,
patch_size=1,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
torch.manual_seed(0)
image_processor = CLIPImageProcessor(crop_size=4, size=4)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
num_frames = 17
height = 16
width = 16
face_height = 16
face_width = 16
image = Image.new("RGB", (height, width))
pose_video = [Image.new("RGB", (height, width))] * num_frames
face_video = [Image.new("RGB", (face_height, face_width))] * num_frames
inputs = {
"image": image,
"pose_video": pose_video,
"face_video": face_video,
"prompt": "dance monkey",
"negative_prompt": "negative",
"height": height,
"width": width,
"segment_frame_length": 77, # TODO: can we set this to num_frames?
"num_inference_steps": 2,
"mode": "animate",
"prev_segment_conditioning_frames": 1,
"generator": generator,
"guidance_scale": 1.0,
"output_type": "pt",
"max_sequence_length": 16,
}
return inputs
def test_inference(self):
"""Test basic inference in animation mode."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
expected_video = torch.randn(17, 3, 16, 16)
max_diff = np.abs(video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_inference_replacement(self):
"""Test the pipeline in replacement mode with background and mask videos."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["mode"] = "replace"
num_frames = 17
height = 16
width = 16
inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
inputs["mask_video"] = [Image.new("L", (height, width))] * num_frames
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip(
"Setting the Wan Animate latents to zero at the last denoising step does not guarantee that the output will be"
" zero. I believe this is because the latents are further processed in the outer loop where we loop over"
" inference segments."
)
def test_callback_inputs(self):
pass
@slow
@require_torch_accelerator
class WanAnimatePipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@unittest.skip("TODO: test needs to be implemented")
def test_wan_animate(self):
pass
......@@ -16,6 +16,7 @@ from diffusers import (
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
......@@ -721,6 +722,33 @@ class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
}
class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf"
torch_dtype = torch.bfloat16
model_cls = WanAnimateTransformer3DModel
expected_memory_use_in_gb = 9
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"control_hidden_states": torch.randn(
(1, 96, 2, 64, 64),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"control_hidden_states_scale": torch.randn(
(8,),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
torch_dtype = torch.bfloat16
......