chenpangpang / ComfyUI · Commits

Commit 3914d5a2, authored Jun 19, 2024 by comfyanonymous

Support full SD3 loras.

parent 55f0dc12
Showing 2 changed files with 79 additions and 10 deletions:

comfy/lora.py   +9 -10
comfy/utils.py  +70 -0
comfy/lora.py
```diff
@@ -252,15 +252,14 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map[diffusers_lora_key] = unet_key
 
     if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
-        for i in range(model.model_config.unet_config.get("depth", 0)):
-            k = "transformer.transformer_blocks.{}.attn.".format(i)
-            qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
-            proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
-            if qkv in sd:
-                offset = sd[qkv].shape[0] // 3
-                key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
-                key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
-                key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
-                key_map["{}to_out.0".format(k)] = proj
+        diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
+                key_map[key_lora] = to
+
+                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
+                key_map[key_lora] = to
 
     return key_map
```
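The tuple values attached to the qkv entries encode a slice, (dim, start, size), of the fused qkv weight, so each of to_q/to_k/to_v resolves to one third of the same tensor. A minimal sketch of how such an entry could be resolved against a state dict (the resolve_mapped_weight helper and the toy shapes are hypothetical, not part of ComfyUI's API):

```python
import torch

def resolve_mapped_weight(sd, entry):
    # A plain string entry is a simple rename (e.g. the proj weight).
    if isinstance(entry, str):
        return sd[entry]
    # Otherwise the entry is (weight_key, (dim, start, size)): slice the
    # fused tensor along `dim`, starting at `start`, taking `size` rows.
    weight_key, (dim, start, size) = entry
    return sd[weight_key].narrow(dim, start, size)

# Toy fused qkv weight: 3 * 8 rows for q, k and v stacked along dim 0.
qkv_key = "diffusion_model.joint_blocks.0.x_block.attn.qkv.weight"
sd = {qkv_key: torch.randn(24, 8)}
offset = sd[qkv_key].shape[0] // 3

q = resolve_mapped_weight(sd, (qkv_key, (0, 0, offset)))           # rows 0..7
k = resolve_mapped_weight(sd, (qkv_key, (0, offset, offset)))      # rows 8..15
v = resolve_mapped_weight(sd, (qkv_key, (0, offset * 2, offset)))  # rows 16..23
print(q.shape, k.shape, v.shape)  # torch.Size([8, 8]) three times
```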
comfy/utils.py
```diff
@@ -249,6 +249,76 @@ def unet_to_diffusers(unet_config):
 
     return diffusers_unet_map
 
+MMDIT_MAP_BASIC = {
+    ("context_embedder.bias", "context_embedder.bias"),
+    ("context_embedder.weight", "context_embedder.weight"),
+    ("t_embedder.mlp.0.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
+    ("t_embedder.mlp.0.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
+    ("t_embedder.mlp.2.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
+    ("t_embedder.mlp.2.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
+    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
+    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
+    ("y_embedder.mlp.0.bias", "time_text_embed.text_embedder.linear_1.bias"),
+    ("y_embedder.mlp.0.weight", "time_text_embed.text_embedder.linear_1.weight"),
+    ("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
+    ("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
+    ("pos_embed", "pos_embed.pos_embed"),
+    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
+    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
+    ("final_layer.linear.bias", "proj_out.bias"),
+    ("final_layer.linear.weight", "proj_out.weight"),
+}
+
+MMDIT_MAP_BLOCK = {
+    ("context_block.adaLN_modulation.1.bias", "norm1_context.linear.bias"),
+    ("context_block.adaLN_modulation.1.weight", "norm1_context.linear.weight"),
+    ("context_block.attn.proj.bias", "attn.to_add_out.bias"),
+    ("context_block.attn.proj.weight", "attn.to_add_out.weight"),
+    ("context_block.mlp.fc1.bias", "ff_context.net.0.proj.bias"),
+    ("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
+    ("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
+    ("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
+    ("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
+    ("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
+    ("x_block.attn.proj.bias", "attn.to_out.0.bias"),
+    ("x_block.attn.proj.weight", "attn.to_out.0.weight"),
+    ("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
+    ("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
+    ("x_block.mlp.fc2.bias", "ff.net.2.bias"),
+    ("x_block.mlp.fc2.weight", "ff.net.2.weight"),
+    ("", ""),
+}
+
+def mmdit_to_diffusers(mmdit_config, output_prefix=""):
+    key_map = {}
+
+    depth = mmdit_config.get("depth", 0)
+    for i in range(depth):
+        block_from = "transformer_blocks.{}".format(i)
+        block_to = "{}joint_blocks.{}".format(output_prefix, i)
+
+        offset = depth * 64
+
+        for end in ("weight", "bias"):
+            k = "{}.attn.".format(block_from)
+            qkv = "{}.x_block.attn.qkv.{}".format(block_to, end)
+            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
+            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
+            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
+
+            qkv = "{}.context_block.attn.qkv.{}".format(block_to, end)
+            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, offset))
+            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
+            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
+
+        for k in MMDIT_MAP_BLOCK:
+            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
+
+    for k in MMDIT_MAP_BASIC:
+        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
         return tensor.narrow(dim, 0, batch_size)
```
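As a quick usage sketch of the new helper (the depth of 2 is a toy value chosen so the slice offset works out to 2 * 64 = 128; real SD3 configs are deeper, and this assumes comfy.utils is importable from a ComfyUI checkout):

```python
import comfy.utils

key_map = comfy.utils.mmdit_to_diffusers({"depth": 2}, output_prefix="diffusion_model.")

# Per-block q/k/v projections map to slices of the fused qkv weight:
print(key_map["transformer_blocks.0.attn.to_q.weight"])
# ('diffusion_model.joint_blocks.0.x_block.attn.qkv.weight', (0, 0, 128))

# Basic (non-block) entries are plain renames:
print(key_map["pos_embed.pos_embed"])
# 'diffusion_model.pos_embed'
```

This is what model_lora_keys_unet in comfy/lora.py now consumes: every ".weight" entry becomes two LoRA key aliases, one with the "transformer." prefix for the regular diffusers SD3 format and one with "base_model.model." for flash-sd3 style files.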