renzhc / diffusers_dcu / Commits / 0d9d98fe

Unverified commit 0d9d98fe, authored Oct 22, 2024 by Dhruv Nair, committed by GitHub on Oct 22, 2024
Parent: 60ffa842

Fix typos (#9739)

* update * update * update * update * update * update

Showing 2 changed files with 95 additions and 2 deletions (+95 −2):

- docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md (+20 −0)
- src/diffusers/loaders/single_file_utils.py (+75 −2)
docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md (+20 −0)

Hunk @@ -313,6 +313,26 @@: a new documentation section is inserted after the existing single-file example, whose final lines are:

```python
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file-t5-fp8.png')
```

The added section:

### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model

```python
import torch
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe("a cat holding a sign that says hello world").images[0]
image.save("sd35.png")
```

It is followed by the existing, unchanged section:

## StableDiffusion3Pipeline

[[autodoc]] StableDiffusion3Pipeline
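A usage note on the example above: `from_single_file` also accepts a local filesystem path, so the same loading step can run against a pre-downloaded checkpoint. A minimal variation, assuming the file has already been fetched (the local path below is hypothetical, not part of the commit):

```python
# Same loading step as in the docs example, but from a hypothetical local copy.
transformer = SD3Transformer2DModel.from_single_file(
    "/path/to/sd3.5_large.safetensors",
    torch_dtype=torch.bfloat16,
)
```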
src/diffusers/loaders/single_file_utils.py (+75 −2)

```diff
@@ -75,6 +75,7 @@ CHECKPOINT_KEY_NAMES = {
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
```
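The new `"sd35_large"` entry drives detection purely by key presence. A minimal sketch of the idea (not diffusers code; the local filename is hypothetical, and the premise that only the deeper SD 3.5 Large checkpoint contains a `joint_blocks.37` key is inferred from the chosen key name):

```python
from safetensors.torch import load_file

state_dict = load_file("sd3.5_large.safetensors")  # hypothetical local file
# Presumably only SD 3.5 Large has enough joint blocks for index 37 to exist:
is_sd35_large = "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight" in state_dict
```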
```diff
@@ -113,6 +114,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
```
```diff
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
```
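Why the added `shape[-1] == 9216` guard: both SD3 medium and SD 3.5 Large checkpoints contain the `"sd3"` key, so key presence alone became ambiguous. My reading, hedged: the adaLN modulation bias packs six modulation chunks of the transformer width, and SD3 medium's width is 1536, so its bias has length 6 × 1536 = 9216; the wider SD 3.5 Large fails the check and falls through to the dedicated `"sd35_large"` branch.

```python
# Sketch of the arithmetic behind the guard; the 6x adaLN factor and the 1536
# width are assumptions from the MMDiT layout, not stated in the diff.
SD3_MEDIUM_WIDTH = 1536
assert 6 * SD3_MEDIUM_WIDTH == 9216
```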
```diff
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
```
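A hypothetical round trip through the two new helpers on a toy state dict (dummy tensor shapes; the 2432 width is illustrative of a larger model, not taken from the commit):

```python
import torch

toy = {
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(3, 1),
    "joint_blocks.12.x_block.attn2.qkv.weight": torch.zeros(3, 1),
    "context_embedder.weight": torch.zeros(2432, 4096),
}
print(get_attn2_layers(toy))            # (0, 12)
print(get_caption_projection_dim(toy))  # 2432
```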
```diff
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+
+    dual_attention_layers = get_attn2_layers(checkpoint)
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
```
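The hardcoded `caption_projection_dim = 1536` (SD3 medium's width) is replaced by values inferred from the checkpoint itself, which is what lets one converter serve both SD3 and SD 3.5 single-file checkpoints. A sketch of the qk-norm detection on toy keys (my assumption that SD 3.5 checkpoints carry `ln_q`/`ln_k` weights while SD3 medium does not):

```python
# Toy key sets; not real checkpoint contents.
sd3_keys = ["joint_blocks.0.x_block.attn.qkv.weight"]
sd35_keys = ["joint_blocks.0.x_block.attn.ln_q.weight"]
assert not any("ln_q" in k for k in sd3_keys)   # has_qk_norm would be False
assert any("ln_q" in k for k in sd35_keys)      # has_qk_norm would be True
```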
```diff
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
```
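The qk-norm handling above is a pure rename pass: sample-stream (`x_block`) norms map to `attn.norm_q`/`attn.norm_k`, and context-stream norms to the `norm_added_*` names diffusers uses for the encoder-hidden-states projections. Summarized as a hypothetical lookup table per block index `i` (my summary, not a table that exists in the codebase):

```python
QK_NORM_RENAMES = {
    "joint_blocks.{i}.x_block.attn.ln_q.weight": "transformer_blocks.{i}.attn.norm_q.weight",
    "joint_blocks.{i}.x_block.attn.ln_k.weight": "transformer_blocks.{i}.attn.norm_k.weight",
    "joint_blocks.{i}.context_block.attn.ln_q.weight": "transformer_blocks.{i}.attn.norm_added_q.weight",
    "joint_blocks.{i}.context_block.attn.ln_k.weight": "transformer_blocks.{i}.attn.norm_added_k.weight",
}
```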
```diff
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
```
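The dual-attention branch splits a fused `qkv` projection into separate q/k/v tensors with `torch.chunk` along dim 0, the same way the primary attention path does elsewhere in this converter (the `torch.cat([...])` around a single tensor is effectively a no-op, kept for symmetry with the multi-tensor cases). A self-contained sketch with dummy sizes:

```python
import torch

hidden = 8  # toy width; real checkpoints use e.g. 1536 or larger
qkv_weight = torch.randn(3 * hidden, hidden)  # rows are [q; k; v] stacked
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
assert q.shape == k.shape == v.shape == (hidden, hidden)
```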