Unverified Commit df1d7b01 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub

[WIP] Add Wan Video2Video (#11053)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
parent 5a6edac0
......@@ -133,6 +133,46 @@ output = pipe(
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Video to Video Generation
```python
import torch
from diffusers.utils import load_video, export_to_video
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler

# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(
    model_id, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanVideoToVideoPipeline.from_pretrained(
    model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)
# change to pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()

prompt = "A robot standing on a mountain top. The sun is setting in the background"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = load_video(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)

output = pipe(
    video=video,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=512,
    guidance_scale=7.0,
    strength=0.7,  # lower values preserve more of the input video; higher values allow larger changes
).frames[0]

export_to_video(output, "wan-v2v.mp4", fps=16)
```
## Memory Optimizations for Wan 2.1
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
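The exact recipes are shown in the (partially collapsed) docs sections below. As a quick illustration of the group offloading idea used there, here is a minimal sketch assuming the 14B T2V checkpoint; the device choices and `num_blocks_per_group` values are illustrative, not prescribed:

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
from diffusers.hooks.group_offloading import apply_group_offloading
from transformers import UMT5EncoderModel

# Illustrative checkpoint; the same pattern applies to the I2V pipelines shown below.
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# Keep only small groups of layers on the GPU at a time; the rest are parked in CPU RAM
# and moved over on demand.
apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)
apply_group_offloading(
    transformer,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=1,
)

pipe = WanPipeline.from_pretrained(
    model_id, vae=vae, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
)
# The large modules are already offloaded, so only the small remaining components move to the GPU.
pipe.to("cuda")
```

With this setup, only a few blocks of the text encoder and transformer occupy GPU memory at any moment, trading some speed for a much lower VRAM peak.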
......@@ -323,7 +363,7 @@ import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
-from transformers import UMT5EncoderModel, CLIPVisionMode
+from transformers import UMT5EncoderModel, CLIPVisionModel
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
......@@ -356,7 +396,7 @@ prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
......@@ -372,7 +412,7 @@ output = pipe(
export_to_video(output, "wan-i2v.mp4", fps=16)
```
-### Using a Custom Scheduler
+## Using a Custom Scheduler
Wan can be used with many different schedulers, each with its own trade-offs between speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can swap in a different scheduler as follows:
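The concrete snippet from the docs is collapsed in this diff, but the pattern is a one-line swap via `from_config`; a minimal sketch (the checkpoint and `shift` value are illustrative):

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# Swap the default UniPCMultistepScheduler for a flow-match Euler scheduler,
# reusing the existing scheduler config and overriding the shift value.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=5.0)
```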
......@@ -403,7 +443,7 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
-## Recommendations for Inference:
+## Recommendations for Inference
- Keep `AutoencoderKLWan` in `torch.float32` for better decoding quality.
- `num_frames` should satisfy the constraint `(num_frames - 1) % 4 == 0`, e.g. 17, 33, or 81 frames.
- For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan; a sketch of these two checks follows this list.
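To make the last two recommendations concrete, a small sketch; `is_valid_num_frames` and `pick_flow_shift` are illustrative helper names, not diffusers APIs, and the thresholds simply mirror the ranges above:

```python
def is_valid_num_frames(num_frames: int) -> bool:
    # (num_frames - 1) must be divisible by 4, e.g. 17, 33, 49, 81.
    return (num_frames - 1) % 4 == 0


def pick_flow_shift(height: int) -> float:
    # Lower shift for smaller resolutions, higher shift for larger ones;
    # 3.0 is the Wan default, and values up to ~12.0 can be tried at high resolutions.
    return 3.0 if height <= 480 else 7.0


assert is_valid_num_frames(81)  # 80 % 4 == 0
print(pick_flow_shift(480))  # 3.0 for 480P-class videos
print(pick_flow_shift(720))  # 7.0 for 720P-class videos
```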
......
......@@ -509,6 +509,7 @@ else:
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
......@@ -1062,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
......
......@@ -356,7 +356,7 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
......@@ -709,7 +709,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
-from .wan import WanImageToVideoPipeline, WanPipeline
+from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
......
......@@ -24,7 +24,7 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["pipeline_wan"] = ["WanPipeline"]
    _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
    _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
......@@ -35,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    else:
        from .pipeline_wan import WanPipeline
        from .pipeline_wan_i2v import WanImageToVideoPipeline
        from .pipeline_wan_video2video import WanVideoToVideoPipeline

else:
    import sys
......
This diff is collapsed.
......@@ -2762,6 +2762,21 @@ class WanPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class WanVideoToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class WuerstchenCombinedPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
......
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
from diffusers.utils.testing_utils import (
    enable_full_determinism,
)

from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    PipelineTesterMixin,
)


enable_full_determinism()


class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = WanVideoToVideoPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = frozenset(["video", "prompt", "negative_prompt"])
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        scheduler = UniPCMultistepScheduler(flow_shift=3.0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
        )

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        video = [Image.new("RGB", (16, 16))] * 17
        inputs = {
            "video": video,
            "prompt": "dance monkey",
            "negative_prompt": "negative",  # TODO
            "generator": generator,
            "num_inference_steps": 4,
            "guidance_scale": 6.0,
            "height": 16,
            "width": 16,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
        expected_video = torch.randn(17, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass

    @unittest.skip(
        "WanVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
    )
    def test_float16_inference(self):
        pass

    @unittest.skip(
        "WanVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
    )
    def test_save_load_float16(self):
        pass