Unverified Commit 2d8a41ca authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[Alibaba Wan Team] continue on #10921 Wan2.1 (#10922)

* Add wanx pipeline, model and example

* wanx_merged_v1

* change WanX into Wan

* fix i2v fp32 oom error

Link: https://code.alibaba-inc.com/open_wanx2/diffusers/codereview/20607813



* support t2v load fp32 ckpt

* add example

* final merge v1

* Update autoencoder_kl_wan.py

* up

* update middle, test up_block

* up up

* one less nn.sequential

* up more

* up

* more

* [refactor] [wip] Wan transformer/pipeline (#10926)

* update

* update

* refactor rope

* refactor pipeline

* make fix-copies

* add transformer test

* update

* update

* make style

* update tests

* tests

* conversion script

* conversion script

* update

* docs

* remove unused code

* fix _toctree.yml

* update dtype

* fix test

* fix tests: scale

* up

* more

* Apply suggestions from code review

* Apply suggestions from code review

* style

* Update scripts/convert_wan_to_diffusers.py

* update docs

* fix

---------
Co-authored-by: default avatarYitong Huang <huangyitong.hyt@alibaba-inc.com>
Co-authored-by: default avatar亚森 <wangjiayu.wjy@alibaba-inc.com>
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent 7007feba
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLWan
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_wan_config(self):
return {
"base_dim": 3,
"z_dim": 16,
"dim_mult": [1, 1, 1, 1],
"num_res_blocks": 1,
"temperal_downsample": [False, True, True],
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_wan_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
@unittest.skip("Test not supported")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_inference(self):
pass
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_training(self):
pass
......@@ -739,8 +739,14 @@ class ModelTesterMixin:
model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
if (
hasattr(self.model_class, "_keep_in_fp32_modules")
and self.model_class._keep_in_fp32_modules is None
):
new_model = self.model_class.from_pretrained(
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
)
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
if self.forward_requires_fresh_args:
......
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 4,
"out_channels": 4,
"text_dim": 16,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 2,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_accelerator,
slow,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
)
enable_full_determinism()
class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
# TODO: impl FlowDPMSolverMultistepScheduler
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
transformer = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "negative", # TODO
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"height": 16,
"width": 16,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
@slow
@require_torch_accelerator
class WanPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@unittest.skip("TODO: test needs to be implemented")
def test_Wanx(self):
pass
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
AutoTokenizer,
CLIPImageProcessor,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
T5EncoderModel,
)
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
# TODO: impl FlowDPMSolverMultistepScheduler
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
transformer = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
image_dim=4,
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
projection_dim=4,
num_hidden_layers=2,
num_attention_heads=2,
image_size=32,
intermediate_size=16,
patch_size=1,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
torch.manual_seed(0)
image_processor = CLIPImageProcessor(crop_size=32, size=32)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 16
image_width = 16
image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"negative_prompt": "negative", # TODO
"max_area": 1024,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
expected_video = torch.randn(9, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
def test_inference_batch_single_identical(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment