renzhc / diffusers_dcu · Commits

Commit 1357931d (unverified), authored Mar 07, 2025 by Dhruv Nair, committed by GitHub on Mar 07, 2025

[Single File] Add single file support for Wan T2V/I2V (#10991)

* update * update * update * update * update * update * update

Parent: a2d3d6af
Showing 8 changed files with 518 additions and 50 deletions (+518, -50):

- docs/source/en/api/pipelines/wan.md (+16, -0)
- src/diffusers/loaders/single_file_model.py (+10, -0)
- src/diffusers/loaders/single_file_utils.py (+330, -45)
- src/diffusers/models/attention_processor.py (+3, -2)
- src/diffusers/models/autoencoders/autoencoder_kl_wan.py (+2, -1)
- src/diffusers/models/transformers/transformer_wan.py (+3, -2)
- tests/single_file/test_model_wan_autoencoder_single_file.py (+61, -0)
- tests/single_file/test_model_wan_transformer3d_single_file.py (+93, -0)
docs/source/en/api/pipelines/wan.md

````diff
@@ -45,6 +45,22 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
 pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
 ```
 
+### Using single file loading with Wan
+
+The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading method.
+
+```python
+import torch
+from diffusers import WanPipeline, WanTransformer3DModel
+
+ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
+transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+
+pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
+```
+
 ## WanPipeline
 
 [[autodoc]] WanPipeline
````
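Alongside the documented transformer example, here is a minimal sketch (not part of the committed docs) of loading the Wan VAE from its original checkpoint as well. The checkpoint URL is the one used in the new autoencoder test further down; passing `vae=` to `WanPipeline.from_pretrained` is assumed to mirror the `transformer=` pattern shown above.

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline

# VAE checkpoint in its original single-file layout (same URL as the new test below)
vae_ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
vae = AutoencoderKLWan.from_single_file(vae_ckpt_path, torch_dtype=torch.float32)

# Assumption: the single-file VAE can be swapped into the pretrained pipeline
# the same way the transformer is in the documented example.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", vae=vae)
```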
src/diffusers/loaders/single_file_model.py

```diff
@@ -39,6 +39,8 @@ from .single_file_utils import (
     convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
+    convert_wan_transformer_to_diffusers,
+    convert_wan_vae_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
     create_vae_diffusers_config_from_ldm,
@@ -117,6 +119,14 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "WanTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
+    "AutoencoderKLWan": {
+        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
+        "default_subfolder": "vae",
+    },
 }
```
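For orientation, the sketch below illustrates, in simplified form, how a registry of this shape can be consumed: the entry for a model class supplies the checkpoint-remapping function plus the default subfolder of the hosted Diffusers repo from which to fetch the config. This is an illustrative sketch only, not the actual diffusers internals; the helper `resolve_single_file_entry` and the stand-in mapping function are hypothetical.

```python
# Illustrative sketch only: simplified view of how entries in a registry like
# SINGLE_FILE_LOADABLE_CLASSES can be used. Not the actual diffusers implementation.

def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):  # stand-in for the real function
    return dict(checkpoint)

SINGLE_FILE_LOADABLE_CLASSES = {
    "WanTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
}

def resolve_single_file_entry(class_name, original_state_dict):
    # Hypothetical helper: look up the entry for the class being loaded, remap the
    # original checkpoint keys, and report which subfolder holds the config.
    entry = SINGLE_FILE_LOADABLE_CLASSES[class_name]
    diffusers_state_dict = entry["checkpoint_mapping_fn"](original_state_dict)
    return diffusers_state_dict, entry["default_subfolder"]

state_dict, subfolder = resolve_single_file_entry("WanTransformer3DModel", {"head.modulation": 0})
print(subfolder)  # transformer
```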
src/diffusers/loaders/single_file_utils.py

```diff
@@ -117,6 +117,8 @@ CHECKPOINT_KEY_NAMES = {
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
     "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+    "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
+    "wan_vae": "decoder.middle.0.residual.0.gamma",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
```
```diff
@@ -176,6 +178,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
     "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
+    "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
+    "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
+    "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
 }
 
 # Use to configure model sample size when original config is provided
```
```diff
@@ -664,6 +669,21 @@ def infer_diffusers_model_type(checkpoint):
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
         model_type = "lumina2"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
+        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
+            target_key = "model.diffusion_model.patch_embedding.weight"
+        else:
+            target_key = "patch_embedding.weight"
+
+        if checkpoint[target_key].shape[0] == 1536:
+            model_type = "wan-t2v-1.3B"
+        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
+            model_type = "wan-t2v-14B"
+        else:
+            model_type = "wan-i2v-14B"
+
+    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
+        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
+        model_type = "wan-t2v-14B"
+
     else:
         model_type = "v1"
```
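The detection above keys off the shape of the patch-embedding weight: 1536 output channels indicates the 1.3B T2V model, 5120 output channels with 16 input channels the 14B T2V model, and anything else falls through to the 14B I2V model. Below is a minimal sketch of the same branching using dummy tensors rather than a real checkpoint; the 36-channel input used for the I2V case, and the (1, 2, 2) kernel shape, are assumptions for illustration only.

```python
import torch

def classify_wan_checkpoint(checkpoint):
    # Mirrors the branch added to infer_diffusers_model_type above.
    target_key = (
        "model.diffusion_model.patch_embedding.weight"
        if "model.diffusion_model.patch_embedding.weight" in checkpoint
        else "patch_embedding.weight"
    )
    weight = checkpoint[target_key]
    if weight.shape[0] == 1536:
        return "wan-t2v-1.3B"
    elif weight.shape[0] == 5120 and weight.shape[1] == 16:
        return "wan-t2v-14B"
    return "wan-i2v-14B"

# Dummy patch-embedding weights standing in for real checkpoints.
print(classify_wan_checkpoint({"patch_embedding.weight": torch.zeros(1536, 16, 1, 2, 2)}))  # wan-t2v-1.3B
print(classify_wan_checkpoint({"patch_embedding.weight": torch.zeros(5120, 16, 1, 2, 2)}))  # wan-t2v-14B
print(classify_wan_checkpoint({"patch_embedding.weight": torch.zeros(5120, 36, 1, 2, 2)}))  # wan-i2v-14B
```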
```diff
@@ -2470,7 +2490,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
 
 
 def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
-    new_state_dict = {}
+    converted_state_dict = {}
 
     # Comfy checkpoints add this prefix
     keys = list(checkpoint.keys())
```
```diff
@@ -2479,22 +2499,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     # Convert patch_embed
-    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
-    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+    converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
 
     # Convert time_embed
-    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
-    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
-    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
-    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
-    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
-    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
-    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
-    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
-    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
-    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
-    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
-    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
 
     # Convert transformer blocks
     num_layers = 48
```
```diff
@@ -2503,68 +2523,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         old_prefix = f"blocks.{i}."
 
         # norm1
-        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
-        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
-            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+            converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
         else:
-            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+            converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
                 old_prefix + "mod_y.weight"
             )
-            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+            converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
 
         # Visual attention
         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)
 
-        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
-        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
-        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
-        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
-        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
-        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+        converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
 
         # Context attention
         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)
 
-        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
-        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
-        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+        converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
             old_prefix + "attn.q_norm_y.weight"
         )
-        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+        converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
             old_prefix + "attn.k_norm_y.weight"
         )
 
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+            converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
                 old_prefix + "attn.proj_y.weight"
             )
-            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+            converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
 
         # MLP
-        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+        converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
             checkpoint.pop(old_prefix + "mlp_x.w1.weight")
         )
-        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+            converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
                 checkpoint.pop(old_prefix + "mlp_y.w1.weight")
             )
-            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+            converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
 
     # Output layers
-    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
-    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
-    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
-    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
 
-    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+    converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
 
-    return new_state_dict
+    return converted_state_dict
 
 
 def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
```
```diff
@@ -2859,3 +2895,252 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[diffusers_key] = checkpoint.pop(key)
 
     return converted_state_dict
```

The remainder of this hunk appends two new conversion helpers, `convert_wan_transformer_to_diffusers` and `convert_wan_vae_to_diffusers`:

```python
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    keys = list(checkpoint.keys())
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    TRANSFORMER_KEYS_RENAME_DICT = {
        "time_embedding.0": "condition_embedder.time_embedder.linear_1",
        "time_embedding.2": "condition_embedder.time_embedder.linear_2",
        "text_embedding.0": "condition_embedder.text_embedder.linear_1",
        "text_embedding.2": "condition_embedder.text_embedder.linear_2",
        "time_projection.1": "condition_embedder.time_proj",
        "cross_attn": "attn2",
        "self_attn": "attn1",
        ".o.": ".to_out.0.",
        ".q.": ".to_q.",
        ".k.": ".to_k.",
        ".v.": ".to_v.",
        ".k_img.": ".add_k_proj.",
        ".v_img.": ".add_v_proj.",
        ".norm_k_img.": ".norm_added_k.",
        "head.modulation": "scale_shift_table",
        "head.head": "proj_out",
        "modulation": "scale_shift_table",
        "ffn.0": "ffn.net.0.proj",
        "ffn.2": "ffn.net.2",
        # Hack to swap the layer names
        # The original model calls the norms in following order: norm1, norm3, norm2
        # We convert it to: norm1, norm2, norm3
        "norm2": "norm__placeholder",
        "norm3": "norm2",
        "norm__placeholder": "norm3",
        # For the I2V model
        "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
        "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
        "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
        "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
    }

    for key in list(checkpoint.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)

        converted_state_dict[new_key] = checkpoint.pop(key)

    return converted_state_dict


def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in checkpoint.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            converted_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            converted_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            converted_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            converted_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            converted_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            converted_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            converted_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            converted_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            converted_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    converted_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                converted_state_dict[new_key] = value

            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                converted_state_dict[new_key] = value

            # Handle upsamplers
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                converted_state_dict[new_key] = value

            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                converted_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            converted_state_dict[key] = value

    return converted_state_dict
```
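The transformer conversion works by substring replacement: every original Wan key is rewritten by applying each entry of `TRANSFORMER_KEYS_RENAME_DICT` in insertion order, which is also why the `norm2`/`norm3` swap needs the `norm__placeholder` intermediate. A small self-contained sketch of that idea on a few representative keys (a simplified subset of the real rename table):

```python
# Minimal sketch of the substring-rename approach used by convert_wan_transformer_to_diffusers.
RENAMES = {
    "cross_attn": "attn2",
    ".q.": ".to_q.",
    "ffn.0": "ffn.net.0.proj",
    # The placeholder trick avoids collisions when swapping norm2 and norm3.
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
}

def rename(key):
    new_key = key
    for old, new in RENAMES.items():  # dicts preserve insertion order in Python 3.7+
        new_key = new_key.replace(old, new)
    return new_key

print(rename("blocks.0.cross_attn.q.weight"))  # blocks.0.attn2.to_q.weight
print(rename("blocks.0.norm3.weight"))         # blocks.0.norm2.weight
print(rename("blocks.0.norm2.weight"))         # blocks.0.norm3.weight
```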
src/diffusers/models/attention_processor.py

```diff
@@ -284,8 +284,9 @@ class Attention(nn.Module):
             self.norm_added_q = RMSNorm(dim_head, eps=eps)
             self.norm_added_k = RMSNorm(dim_head, eps=eps)
         elif qk_norm == "rms_norm_across_heads":
-            # Wanx applies qk norm across all heads
-            self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
+            # Wan applies qk norm across all heads
+            # Wan also doesn't apply a q norm
+            self.norm_added_q = None
             self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
         else:
             raise ValueError(
```
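The `rms_norm_across_heads` path normalizes the added key projection over the full concatenated head dimension (`dim_head * kv_heads`) rather than per head, and the Wan variant skips the added q norm entirely. A minimal, framework-free sketch of what "norm across all heads" means, using plain torch and assumed shapes for illustration:

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # RMS-normalize over the last dimension, then scale by a learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

dim_head, heads = 64, 12
# Key projection output before the heads are split apart (assumed shapes).
k = torch.randn(2, 77, dim_head * heads)
weight = torch.ones(dim_head * heads)

# Wan-style: a single RMSNorm applied across all heads at once.
k_normed = rms_norm(k, weight)
print(k_normed.shape)  # torch.Size([2, 77, 768])
```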
src/diffusers/models/autoencoders/autoencoder_kl_wan.py

```diff
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
@@ -655,7 +656,7 @@ class WanDecoder3d(nn.Module):
         return x
 
 
-class AutoencoderKLWan(ModelMixin, ConfigMixin):
+class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
     Introduced in [Wan 2.1].
```
src/diffusers/models/transformers/transformer_wan.py

```diff
@@ -20,7 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -288,7 +288,7 @@ class WanTransformerBlock(nn.Module):
         return hidden_states
 
 
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data used in the Wan model.
@@ -329,6 +329,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
 
     @register_to_config
     def __init__(
```
tests/single_file/test_model_wan_autoencoder_single_file.py (new file, 0 → 100644)

```python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

from diffusers import (
    AutoencoderKLWan,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)


enable_full_determinism()


@require_torch_accelerator
class AutoencoderKLWanSingleFileTests(unittest.TestCase):
    model_class = AutoencoderKLWan
    ckpt_path = (
        "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
    )
    repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"
```
tests/single_file/test_model_wan_transformer3d_single_file.py (new file, 0 → 100644)

```python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import (
    WanTransformer3DModel,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_big_gpu_with_torch_cuda,
    require_torch_accelerator,
    torch_device,
)


enable_full_determinism()


@require_torch_accelerator
class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
    model_class = WanTransformer3DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
    repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"


@require_big_gpu_with_torch_cuda
@require_torch_accelerator
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
    model_class = WanTransformer3DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
    repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    torch_dtype = torch.float8_e4m3fn

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"
```