renzhc / diffusers_dcu · Commit 9764f229 (Unverified)
Authored Dec 19, 2024 by Dhruv Nair; committed by GitHub on Dec 19, 2024
Parent: 1826a1e7

[Single File] Add single file support for Mochi Transformer (#10268)

update

Showing 3 changed files with 116 additions and 1 deletion:

    src/diffusers/loaders/single_file_model.py              +5   -0
    src/diffusers/loaders/single_file_utils.py              +109 -0
    src/diffusers/models/transformers/transformer_mochi.py  +2   -1
src/diffusers/loaders/single_file_model.py

@@ -32,6 +32,7 @@ from .single_file_utils import (
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
     convert_ltx_vae_checkpoint_to_diffusers,
+    convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,

@@ -96,6 +97,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "default_subfolder": "vae",
     },
     "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
+    "MochiTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
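Registering the class in SINGLE_FILE_LOADABLE_CLASSES is what lets the generic single-file loader route a raw Mochi checkpoint through the new conversion function and pull the model config from the "transformer" subfolder of the inferred diffusers repo. A simplified, illustrative sketch of that dispatch (the variable names below are hypothetical, not the actual loader internals):

    # Illustrative dispatch based on the new registry entry; only
    # SINGLE_FILE_LOADABLE_CLASSES and the conversion function come from this commit.
    mapping = SINGLE_FILE_LOADABLE_CLASSES["MochiTransformer3DModel"]
    checkpoint_mapping_fn = mapping["checkpoint_mapping_fn"]  # convert_mochi_transformer_checkpoint_to_diffusers
    config_subfolder = mapping.get("default_subfolder")       # "transformer": where the diffusers config lives

    diffusers_state_dict = checkpoint_mapping_fn(checkpoint=original_state_dict)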
src/diffusers/loaders/single_file_utils.py

@@ -106,6 +106,7 @@ CHECKPOINT_KEY_NAMES = {
     ],
     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -159,6 +160,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
 }

 # Use to configure model sample size when original config is provided
@@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "autoencoder-dc-f128c512"

+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+        model_type = "mochi-1-preview"
+
     else:
         model_type = "v1"
@@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim):
     return new_weight


+def swap_proj_gate(weight):
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
 def get_attn2_layers(state_dict):
     attn2_layers = []
     for key in state_dict.keys():
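swap_proj_gate reorders the two halves of the fused MLP input projection: the original Mochi checkpoint stores them in one order, while the diffusers FeedForward fused weight (ff.net.0.proj.weight) expects the halves concatenated the other way round, so the helper chunks and re-concatenates. A toy illustration of the effect (values are illustrative):

    # Minimal sketch of what swap_proj_gate does to a fused weight tensor.
    import torch

    w1 = torch.arange(8.0).reshape(4, 2)    # rows 0-1 = first half, rows 2-3 = second half
    proj, gate = w1.chunk(2, dim=0)
    swapped = torch.cat([gate, proj], dim=0)  # second half now comes first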
@@ -2414,3 +2425,101 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
         handler_fn_inplace(key, converted_state_dict)

     return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    new_state_dict = {}
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+    # Convert transformer blocks
+    num_layers = 48
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        old_prefix = f"blocks.{i}."
+
+        # norm1
+        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+        else:
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+        # Visual attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+        # Context attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_y.weight")
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_y.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(old_prefix + "attn.proj_y.weight")
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+        # MLP
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+            )
+            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+    # Output layers
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+    return new_state_dict
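The converter can also be exercised on its own, independent of the single-file loader. A minimal sketch, assuming a locally downloaded original-format Mochi DiT checkpoint (the file path is illustrative):

    # Minimal sketch of running the converter by hand; the path is illustrative.
    from safetensors.torch import load_file
    from diffusers.loaders.single_file_utils import convert_mochi_transformer_checkpoint_to_diffusers

    original = load_file("mochi_preview_dit.safetensors")
    converted = convert_mochi_transformer_checkpoint_to_diffusers(original)
    # Keys now follow the diffusers MochiTransformer3DModel layout, e.g. "patch_embed.proj.weight".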
src/diffusers/models/transformers/transformer_mochi.py

@@ -20,6 +20,7 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward

@@ -304,7 +305,7 @@ class MochiRoPE(nn.Module):
 @maybe_allow_in_graph
-class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
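With FromOriginalModelMixin on the class and the registry and conversion pieces above in place, the transformer can be loaded directly from a single original-format checkpoint and dropped into MochiPipeline. A usage sketch; the checkpoint path and dtype choice are illustrative, not part of this commit:

    # Usage sketch (checkpoint path and dtype are illustrative).
    import torch
    from diffusers import MochiPipeline, MochiTransformer3DModel

    ckpt_path = "mochi_preview_bf16.safetensors"  # original-format Mochi DiT checkpoint
    transformer = MochiTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)

    pipe = MochiPipeline.from_pretrained(
        "genmo/mochi-1-preview", transformer=transformer, torch_dtype=torch.bfloat16
    )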