"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "282ed27ac7ff341a749d94a61013f62f5ab41975"
Unverified commit 0e95aa85, authored by Ju Hoon Park and committed by GitHub

[From Single File] support `from_single_file` method for `WanVACETransformer3DModel` (#11807)

* add `WanVACETransformer3DModel` in `SINGLE_FILE_LOADABLE_CLASSES`

* add rename keys for `VACE`

* fix typo

Sincere thanks to @nitinmukesh 🙇‍♂️

* support for `1.3B VACE` model

Sincere thanks to @nitinmukesh again 🙇‍♂️

* update

* update

* Apply style fixes

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 5ef74fd5
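For context, a minimal usage sketch of what this change enables: loading the VACE transformer directly from an original single-file checkpoint and dropping it into the converted pipeline. The local checkpoint path below is a placeholder, not a file shipped with this PR.

import torch
from diffusers import WanVACEPipeline, WanVACETransformer3DModel

# Placeholder path to an original (non-Diffusers-layout) Wan VACE checkpoint.
ckpt_path = "./Wan2.1-VACE-1.3B/diffusion_pytorch_model.safetensors"

transformer = WanVACETransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

# The remaining components (VAE, text encoder, scheduler) still come from the converted repo.
pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-1.3B-diffusers", transformer=transformer, torch_dtype=torch.bfloat16
)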
@@ -136,6 +136,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
...
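The new registry entry above is what routes `WanVACETransformer3DModel.from_single_file` through the shared Wan converter. A simplified sketch of how the registry is consulted (illustrative only, not the actual loader code, which also resolves the config repo, dtype, and quantization):

# Illustrative only: roughly what the single-file loader does with this entry.
mapping = SINGLE_FILE_LOADABLE_CLASSES["WanVACETransformer3DModel"]
# Remap the original state dict into Diffusers key names.
diffusers_state_dict = mapping["checkpoint_mapping_fn"](checkpoint=original_state_dict)
# The model config is then read from mapping["default_subfolder"] ("transformer")
# of the matching Diffusers repository.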
@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
    ],
    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
    "wan_vae": "decoder.middle.0.residual.0.gamma",
    "wan_vace": "vace_blocks.0.after_proj.bias",
    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
    "cosmos-1.0": [
        "net.x_embedder.proj.1.weight",
@@ -202,6 +203,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
    "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
    "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
    "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
    "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
@@ -716,7 +719,13 @@ def infer_diffusers_model_type(checkpoint):
        else:
            target_key = "patch_embedding.weight"

        if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
            if checkpoint[target_key].shape[0] == 1536:
                model_type = "wan-vace-1.3B"
            elif checkpoint[target_key].shape[0] == 5120:
                model_type = "wan-vace-14B"

        elif checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
            model_type = "wan-t2v-14B"
@@ -3132,6 +3141,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
        "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
        "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
        "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
        # For the VACE model
        "before_proj": "proj_in",
        "after_proj": "proj_out",
    }

    for key in list(checkpoint.keys()):
...
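The two new entries map the VACE block projections to their Diffusers names; the conversion loop below them applies the table as substring replacements over every checkpoint key. A small sketch of the effect (the dict here is a local stand-in for the converter's full rename table):

# Local stand-in for the converter's rename table; illustrates the substring replacement.
VACE_RENAMES = {"before_proj": "proj_in", "after_proj": "proj_out"}

def rename_key(key: str) -> str:
    for old, new in VACE_RENAMES.items():
        key = key.replace(old, new)
    return key

assert rename_key("vace_blocks.0.before_proj.weight") == "vace_blocks.0.proj_in.weight"
assert rename_key("vace_blocks.0.after_proj.bias") == "vace_blocks.0.proj_out.bias"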
@@ -15,6 +15,8 @@ from diffusers import (
    HiDreamImageTransformer2DModel,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
    WanTransformer3DModel,
    WanVACETransformer3DModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
@@ -577,3 +579,71 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
            ).to(torch_device, self.torch_dtype),
            "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }


class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
    torch_dtype = torch.bfloat16
    model_cls = WanTransformer3DModel
    expected_memory_use_in_gb = 9

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }


class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
    torch_dtype = torch.bfloat16
    model_cls = WanTransformer3DModel
    expected_memory_use_in_gb = 9

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "encoder_hidden_states_image": torch.randn(
                (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }


class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
    torch_dtype = torch.bfloat16
    model_cls = WanVACETransformer3DModel
    expected_memory_use_in_gb = 9

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "control_hidden_states": torch.randn(
                (1, 96, 2, 64, 64),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "control_hidden_states_scale": torch.randn(
                (8,),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }
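The new test classes exercise single-file loading of GGUF-quantized Wan and VACE weights. A minimal sketch of the equivalent user-facing call; the QuantStack URL is the one used in the test above and is an example source rather than an official artifact:

import torch
from diffusers import GGUFQuantizationConfig, WanVACETransformer3DModel

# GGUF checkpoint URL taken from the test above.
ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"

transformer = WanVACETransformer3DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)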