Unverified Commit 73a9d585 authored by Aryan, committed by GitHub

Wan VACE (#11582)

* initial support

* make fix-copies

* fix no split modules

* add conversion script

* refactor

* add pipeline test

* refactor

* fix bug with mask

* fix for reference images

* remove print

* update docs

* update slices

* update

* update

* update example
parent 16c955c5
@@ -22,17 +22,30 @@
# Wan2.1
[Wan2.1](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf) is a series of large diffusion transformers available in two versions, a high-performance 14B parameter model and a more accessible 1.3B version. Trained on billions of images and videos, it supports tasks like text-to-video (T2V) and image-to-video (I2V) while enabling features such as camera control and stylistic diversity. The Wan-VAE features better image data compression and a feature cache mechanism that encodes and decodes a video in chunks. To maintain continuity, features from previous chunks are cached and reused for processing subsequent chunks. This improves inference efficiency by reducing memory usage. Wan2.1 also uses a multilingual text encoder, and the diffusion transformer models spatial and temporal relationships together with text conditions at each time step to capture more complex video dynamics.
[Wan-2.1](https://huggingface.co/papers/2503.20314) by the Wan Team.
*This report presents Wan, a comprehensive and open suite of video foundation models designed to push the boundaries of video generation. Built upon the mainstream diffusion transformer paradigm, Wan achieves significant advancements in generative capabilities through a series of innovations, including our novel VAE, scalable pre-training strategies, large-scale data curation, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. Specifically, Wan is characterized by four key features: Leading Performance: The 14B model of Wan, trained on a vast dataset comprising billions of images and videos, demonstrates the scaling laws of video generation with respect to both data and model size. It consistently outperforms the existing open-source models as well as state-of-the-art commercial solutions across multiple internal and external benchmarks, demonstrating a clear and significant performance superiority. Comprehensiveness: Wan offers two capable models, i.e., 1.3B and 14B parameters, for efficiency and effectiveness respectively. It also covers multiple downstream applications, including image-to-video, instruction-guided video editing, and personal video generation, encompassing up to eight tasks. Consumer-Grade Efficiency: The 1.3B model demonstrates exceptional resource efficiency, requiring only 8.19 GB VRAM, making it compatible with a wide range of consumer-grade GPUs. Openness: We open-source the entire series of Wan, including source code and all models, with the goal of fostering the growth of the video generation community. This openness seeks to significantly expand the creative possibilities of video production in the industry and provide academia with high-quality video foundation models. All the code and models are available at [this https URL](https://github.com/Wan-Video/Wan2.1).*
You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.
The following Wan models are supported in Diffusers:
- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
- [Wan 2.1 I2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)
- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)
> [!TIP]
> Click on the Wan2.1 models in the right sidebar for more examples of video generation.
### Text-to-Video Generation
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
<hfoptions id="T2V usage">
<hfoption id="T2V memory">
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
@@ -100,7 +113,7 @@ export_to_video(output, "output.mp4", fps=16)
```
</hfoption>
<hfoption id="T2V inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
@@ -157,6 +170,81 @@ export_to_video(output, "output.mp4", fps=16)
</hfoption>
</hfoptions>
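A minimal text-to-video sketch, assuming one of the T2V checkpoints listed above; the resolution, frame count, and sampling values below are illustrative rather than prescriptive:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

# Illustrative sketch: checkpoint taken from the list above, VAE kept in float32
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat walks on the grass, realistic style."
negative_prompt = "blurry, low quality, static, worst quality"
output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```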
### First-Last-Frame-to-Video Generation
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
<hfoptions id="FLF2V usage">
<hfoption id="usage">
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")

def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    # Resize so the area is at most `max_area` while keeping the aspect ratio and
    # staying divisible by the VAE/patch spatial factor
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    return image, height, width

def center_crop_resize(image, height, width):
    # Scale the image so it covers the first frame's dimensions, then center crop
    # it to exactly (height, width)
    resize_ratio = max(width / image.width, height / image.height)
    image = image.resize((round(image.width * resize_ratio), round(image.height * resize_ratio)))
    image = TF.center_crop(image, [height, width])
    return image, height, width

first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
if last_frame.size != first_frame.size:
    last_frame, _, _ = center_crop_resize(last_frame, height, width)

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."

output = pipe(
    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
</hfoption>
</hfoptions>
### Any-to-Video Controllable Generation
Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Bounding Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux](https://github.com/huggingface/controlnet_aux)
- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips)
- Inpainting and Outpainting
- Subject to Video (faces, object, characters, etc.)
- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.)
The code snippets available in [this pull request](https://github.com/huggingface/diffusers/pull/11582) demonstrate some examples of how videos can be generated with controllability signals.
The general rule of thumb when preparing inputs for the VACE pipeline is that any input image, or any frame of a video that you want to use for conditioning, should have a corresponding black mask. The black mask tells the model not to generate new content in that region and to only use it to condition the generation. Any region or frame that the model should generate must instead have a white mask. A minimal sketch of this convention follows.
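The sketch below extends a single known first frame into a video with the VACE pipeline, assuming the 1.3B VACE checkpoint listed above; the resolution, frame count, placeholder frame color, and sampling values are illustrative. The conditioning frame keeps a black mask while the frames to be generated use white masks.

```python
import torch
import PIL.Image
from diffusers import AutoencoderKLWan, WanVACEPipeline
from diffusers.utils import export_to_video, load_image

model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

height, width, num_frames = 480, 832, 81

# The first frame is a known conditioning image (black mask); the remaining frames
# should be generated by the model (white mask) and are passed as gray placeholders.
first_frame = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
).resize((width, height))
video = [first_frame] + [PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 1)
mask = [PIL.Image.new("L", (width, height), 0)] + [PIL.Image.new("L", (width, height), 255)] * (num_frames - 1)

prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings."
output = pipe(
    video=video,
    mask=mask,
    prompt=prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=30,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```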
## Notes
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
@@ -251,6 +339,18 @@ export_to_video(output, "output.mp4", fps=16)
- all
- __call__
## WanVACEPipeline
[[autodoc]] WanVACEPipeline
- all
- __call__
## WanVideoToVideoPipeline
[[autodoc]] WanVideoToVideoPipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
\ No newline at end of file
import argparse
import pathlib
from typing import Any, Dict, Tuple
import torch
from accelerate import init_empty_weights
@@ -14,6 +14,8 @@ from diffusers import (
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
WanVACEPipeline,
WanVACETransformer3DModel,
)
@@ -59,7 +61,52 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"attn2.norm_k_img": "attn2.norm_added_k",
}
VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# # For the I2V model
# "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
# "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
# "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
# "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# # for the FLF2V model
# "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
"before_proj": "proj_in",
"after_proj": "proj_out",
}
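# Hypothetical illustration (not part of the conversion script) of why the
# "norm__placeholder" entries above are needed: the renames are applied with
# sequential `str.replace` calls, so swapping "norm2" and "norm3" directly
# would collide. For example:
#     key = "blocks.0.norm3.weight"
#     for old, new in VACE_TRANSFORMER_KEYS_RENAME_DICT.items():
#         key = key.replace(old, new)
#     # -> "blocks.0.norm2.weight"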
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
return state_dict
def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "Wan-T2V-1.3B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-T2V-14B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-I2V-14B-480p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-I2V-14B-720p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-FLF2V-14B-720P":
config = {
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
@@ -175,11 +230,60 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"pos_embed_seq_len": 257 * 2,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-VACE-1.3B":
config = {
"model_id": "Wan-AI/Wan2.1-VACE-1.3B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 12,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-VACE-14B":
config = {
"model_id": "Wan-AI/Wan2.1-VACE-14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
def convert_transformer(model_type: str):
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
@@ -187,16 +291,19 @@ def convert_transformer(model_type: str):
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
@@ -412,7 +519,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
return parser.parse_args()
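# Hypothetical invocation (the script's file name/path is not shown in this diff):
#
#   python convert_wan_to_diffusers.py --model_type Wan-VACE-1.3B --output_path ./Wan2.1-VACE-1.3B-Diffusers --dtype none
#
# With --dtype none the converted weights keep their original dtypes; fp32/fp16/bf16
# cast the transformer before the pipeline is assembled and written to --output_path.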
@@ -426,18 +533,20 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
transformer = convert_transformer(args.model_type)
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
# If user has specified "none", we keep the original dtypes of the state dict without any conversion
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if "I2V" in args.model_type or "FLF2V" in args.model_type: if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained( image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
...@@ -452,6 +561,14 @@ if __name__ == "__main__": ...@@ -452,6 +561,14 @@ if __name__ == "__main__":
image_encoder=image_encoder, image_encoder=image_encoder,
image_processor=image_processor, image_processor=image_processor,
) )
elif "VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
else:
pipe = WanPipeline(
transformer=transformer,
...
@@ -215,6 +215,7 @@ else:
"UVit2DModel",
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
]
)
_import_structure["optimization"] = [
@@ -527,6 +528,7 @@ else:
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
@@ -821,6 +823,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UVit2DModel,
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
from .optimization import (
get_constant_schedule,
@@ -1112,6 +1115,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
...
@@ -58,6 +58,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
}
...
@@ -89,6 +89,7 @@ if is_torch_available():
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -178,6 +179,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Transformer2DModel,
TransformerTemporalModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
from .unets import (
I2VGenXLUNet,
...
@@ -32,3 +32,4 @@ if is_torch_available():
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanVACETransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
apply_input_projection: bool = False,
apply_output_projection: bool = False,
):
super().__init__()
# 1. Input projection
self.proj_in = None
if apply_input_projection:
self.proj_in = nn.Linear(dim, dim)
# 2. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=WanAttnProcessor2_0(),
)
# 3. Cross-attention
self.attn2 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=WanAttnProcessor2_0(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 4. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
# 5. Output projection
self.proj_out = None
if apply_output_projection:
self.proj_out = nn.Linear(dim, dim)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
control_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
if self.proj_in is not None:
control_hidden_states = self.proj_in(control_hidden_states)
control_hidden_states = control_hidden_states + hidden_states
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
control_hidden_states
)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
control_hidden_states = control_hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
control_hidden_states
)
ff_output = self.ffn(norm_hidden_states)
control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
control_hidden_states
)
conditioning_states = None
if self.proj_out is not None:
conditioning_states = self.proj_out(control_hidden_states)
return conditioning_states, control_hidden_states
class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in the Wan VACE model.
Args:
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `40`):
The number of attention heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `4096`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `13824`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `40`):
The number of layers of transformer blocks to use.
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
The query/key normalization to apply.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
vace_layers (`List[int]`, defaults to `[0, 5, 10, 15, 20, 25, 30, 35]`):
Indices of the transformer blocks that receive the VACE control hints.
vace_in_channels (`int`, defaults to `96`):
The number of channels in the VACE control (conditioning) input.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
num_layers: int = 40,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
pos_embed_seq_len: Optional[int] = None,
vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35],
vace_in_channels: int = 96,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
if max(vace_layers) >= num_layers:
raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
if 0 not in vace_layers:
raise ValueError("VACE layers must include layer 0.")
# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
# image_embedding_dim=1280 for I2V model
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
WanTransformerBlock(
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
)
for _ in range(num_layers)
]
)
self.vace_blocks = nn.ModuleList(
[
WanVACETransformerBlock(
inner_dim,
ffn_dim,
num_attention_heads,
qk_norm,
cross_attn_norm,
eps,
added_kv_proj_dim,
apply_input_projection=i == 0, # Layer 0 always has input projection and is in vace_layers
apply_output_projection=True,
)
for i in range(len(vace_layers))
]
)
# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
control_hidden_states: torch.Tensor = None,
control_hidden_states_scale: torch.Tensor = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
if control_hidden_states_scale is None:
control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
if len(control_hidden_states_scale) != len(self.config.vace_layers):
raise ValueError(
f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
f"equal to {len(self.config.vace_layers)}."
)
# 1. Rotary position embedding
rotary_emb = self.rope(hidden_states)
# 2. Patch embedding
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
control_hidden_states_padding = control_hidden_states.new_zeros(
batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
)
control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)
# 3. Time embedding
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.unflatten(1, (6, -1))
# 4. Image embedding
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
# 5. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
# Prepare VACE hints
control_hidden_states_list = []
for i, block in enumerate(self.vace_blocks):
conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
)
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
control_hidden_states_list = control_hidden_states_list[::-1]
for i, block in enumerate(self.blocks):
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
if i in self.config.vace_layers:
control_hint, scale = control_hidden_states_list.pop()
hidden_states = hidden_states + control_hint * scale
else:
# Prepare VACE hints
control_hidden_states_list = []
for i, block in enumerate(self.vace_blocks):
conditioning_states, control_hidden_states = block(
hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
)
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
control_hidden_states_list = control_hidden_states_list[::-1]
for i, block in enumerate(self.blocks):
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
if i in self.config.vace_layers:
control_hint, scale = control_hidden_states_list.pop()
hidden_states = hidden_states + control_hint * scale
# 6. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
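# Hypothetical smoke test (not part of this file), mirroring the tiny configuration
# used in the fast pipeline test further below, to illustrate the forward signature
# and expected output shape of the VACE transformer.
if __name__ == "__main__":
    model = WanVACETransformer3DModel(
        patch_size=(1, 2, 2),
        num_attention_heads=2,
        attention_head_dim=12,
        in_channels=16,
        out_channels=16,
        text_dim=32,
        freq_dim=256,
        ffn_dim=32,
        num_layers=3,
        cross_attn_norm=True,
        qk_norm="rms_norm_across_heads",
        rope_max_seq_len=32,
        vace_layers=[0, 2],
        vace_in_channels=96,
    )
    hidden_states = torch.randn(1, 16, 2, 16, 16)  # (batch, channels, frames, height, width) latents
    control_hidden_states = torch.randn(1, 96, 2, 16, 16)  # VACE control input (conditioning latents and mask channels)
    encoder_hidden_states = torch.randn(1, 8, 32)  # text embeddings (batch, seq_len, text_dim)
    timestep = torch.tensor([1])
    sample = model(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        control_hidden_states=control_hidden_states,
        return_dict=False,
    )[0]
    print(sample.shape)  # torch.Size([1, 16, 2, 16, 16])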
@@ -371,7 +371,7 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -739,7 +739,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserTextDecoder,
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
...
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_wan"] = ["WanPipeline"]
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
_import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_wan import WanPipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
from .pipeline_wan_vace import WanVACEPipeline
from .pipeline_wan_video2video import WanVideoToVideoPipeline
else:
...
This diff is collapsed.
@@ -1150,6 +1150,21 @@ class WanTransformer3DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class WanVACETransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
...
@@ -2897,6 +2897,21 @@ class WanPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class WanVACEPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanVACEPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
transformer = WanVACETransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=3,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
vace_layers=[0, 2],
vace_in_channels=96,
)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
num_frames = 17
height = 16
width = 16
video = [Image.new("RGB", (height, width))] * num_frames
mask = [Image.new("L", (height, width), 0)] * num_frames
inputs = {
"video": video,
"mask": mask,
"prompt": "dance monkey",
"negative_prompt": "negative",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"height": 16,
"width": 16,
"num_frames": num_frames,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
# fmt: off
expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402]
# fmt: on
video_slice = video.flatten()
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
video_slice = [round(x, 5) for x in video_slice.tolist()]
self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
def test_inference_with_single_reference_image(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["reference_images"] = Image.new("RGB", (16, 16))
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
# fmt: off
expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342]
# fmt: on
video_slice = video.flatten()
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
video_slice = [round(x, 5) for x in video_slice.tolist()]
self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
def test_inference_with_multiple_reference_image(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2]
video = pipe(**inputs).frames[0]
self.assertEqual(video.shape, (17, 3, 16, 16))
# fmt: off
expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983]
# fmt: on
video_slice = video.flatten()
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
video_slice = [round(x, 5) for x in video_slice.tolist()]
self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip("Batching is not yet supported with this pipeline")
def test_inference_batch_consistent(self):
pass
@unittest.skip("Batching is not yet supported with this pipeline")
def test_inference_batch_single_identical(self):
return super().test_inference_batch_single_identical()