Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
b1a2c0d5
Unverified
Commit
b1a2c0d5
authored
Jun 13, 2024
by
Dhruv Nair
Committed by
GitHub
Jun 13, 2024
Browse files
Expand Single File support in SD3 Pipeline (#8517)
* update * update
parent
06ee907b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
66 additions
and
17 deletions
+66
-17
docs/source/en/api/loaders/single_file.md
docs/source/en/api/loaders/single_file.md
+2
-0
docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
...e/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
+29
-8
src/diffusers/loaders/single_file.py
src/diffusers/loaders/single_file.py
+12
-0
src/diffusers/loaders/single_file_utils.py
src/diffusers/loaders/single_file_utils.py
+23
-9
No files found.
docs/source/en/api/loaders/single_file.md
View file @
b1a2c0d5
...
...
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
-
[
`StableDiffusionXLInstructPix2PixPipeline`
]
-
[
`StableDiffusionXLControlNetPipeline`
]
-
[
`StableDiffusionXLKDiffusionPipeline`
]
-
[
`StableDiffusion3Pipeline`
]
-
[
`LatentConsistencyModelPipeline`
]
-
[
`LatentConsistencyModelImg2ImgPipeline`
]
-
[
`StableDiffusionControlNetXSPipeline`
]
...
...
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
-
[
`StableCascadeUNet`
]
-
[
`AutoencoderKL`
]
-
[
`ControlNetModel`
]
-
[
`SD3Transformer2DModel`
]
## FromSingleFileMixin
...
...
docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
View file @
b1a2c0d5
...
...
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
### Loading the single file checkpoint without T5
```
python
import
torch
from
diffusers
import
StableDiffusion3Pipeline
from
transformers
import
T5EncoderModel
text_encoder_3
=
T5EncoderModel
.
from_pretrained
(
"stabilityai/stable-diffusion-3-medium-diffusers"
,
subfolder
=
"text_encoder_3"
,
torch_dtype
=
torch
.
float16
)
pipe
=
StableDiffusion3Pipeline
.
from_single_file
(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors"
,
torch_dtype
=
torch
.
float16
,
text_encoder_3
=
text_encoder_3
)
pipe
=
StableDiffusion3Pipeline
.
from_single_file
(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors"
,
torch_dtype
=
torch
.
float16
,
text_encoder_3
=
None
)
pipe
.
enable_model_cpu_offload
()
image
=
pipe
(
"a picture of a cat holding a sign that says hello world"
).
images
[
0
]
image
.
save
(
'sd3-single-file.png'
)
```
<Tip>
`from_single_file`
support for the
`fp8`
version of the checkpoints is coming soon. Watch this space.
</Tip>
### Loading the single file checkpoint with T5
```
python
import
torch
from
diffusers
import
StableDiffusion3Pipeline
pipe
=
StableDiffusion3Pipeline
.
from_single_file
(
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors"
,
torch_dtype
=
torch
.
float16
,
)
pipe
.
enable_model_cpu_offload
()
image
=
pipe
(
"a picture of a cat holding a sign that says hello world"
).
images
[
0
]
image
.
save
(
'sd3-single-file-t5-fp8.png'
)
```
## StableDiffusion3Pipeline
...
...
src/diffusers/loaders/single_file.py
View file @
b1a2c0d5
...
...
@@ -28,9 +28,11 @@ from .single_file_utils import (
_legacy_load_safety_checker
,
_legacy_load_scheduler
,
create_diffusers_clip_model_from_ldm
,
create_diffusers_t5_model_from_checkpoint
,
fetch_diffusers_config
,
fetch_original_config
,
is_clip_model_in_single_file
,
is_t5_in_single_file
,
load_single_file_checkpoint
,
)
...
...
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
is_legacy_loading
=
is_legacy_loading
,
)
elif
is_transformers_model
and
is_t5_in_single_file
(
checkpoint
):
loaded_sub_model
=
create_diffusers_t5_model_from_checkpoint
(
class_obj
,
checkpoint
=
checkpoint
,
config
=
cached_model_config_path
,
subfolder
=
name
,
torch_dtype
=
torch_dtype
,
local_files_only
=
local_files_only
,
)
elif
is_tokenizer
and
is_legacy_loading
:
loaded_sub_model
=
_legacy_load_clip_tokenizer
(
class_obj
,
checkpoint
=
checkpoint
,
config
=
cached_model_config_path
,
local_files_only
=
local_files_only
...
...
src/diffusers/loaders/single_file_utils.py
View file @
b1a2c0d5
...
...
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE
=
[
"cond_stage_model.transformer."
,
"conditioner.embedders.0.transformer."
,
"text_encoders.clip_l.transformer."
,
]
OPEN_CLIP_PREFIX
=
"conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
=
1024
...
...
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
def
is_open_clip_sd3_model
(
checkpoint
):
is_open_clip_sdxl_refiner_model
(
checkpoint
)
if
CHECKPOINT_KEY_NAMES
[
"open_clip_sd3"
]
in
checkpoint
:
return
True
return
False
def
is_open_clip_sdxl_refiner_model
(
checkpoint
):
if
CHECKPOINT_KEY_NAMES
[
"open_clip_sd
3
"
]
in
checkpoint
:
if
CHECKPOINT_KEY_NAMES
[
"open_clip_sd
xl_refiner
"
]
in
checkpoint
:
return
True
return
False
...
...
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return
new_checkpoint
def
convert_ldm_clip_checkpoint
(
checkpoint
):
def
convert_ldm_clip_checkpoint
(
checkpoint
,
remove_prefix
=
None
):
keys
=
list
(
checkpoint
.
keys
())
text_model_dict
=
{}
remove_prefixes
=
LDM_CLIP_PREFIX_TO_REMOVE
remove_prefixes
=
[]
remove_prefixes
.
extend
(
LDM_CLIP_PREFIX_TO_REMOVE
)
if
remove_prefix
:
remove_prefixes
.
append
(
remove_prefix
)
for
key
in
keys
:
for
prefix
in
remove_prefixes
:
...
...
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
):
diffusers_format_checkpoint
=
convert_ldm_clip_checkpoint
(
checkpoint
)
elif
(
is_clip_sd3_model
(
checkpoint
)
and
checkpoint
[
CHECKPOINT_KEY_NAMES
[
"clip_sd3"
]].
shape
[
-
1
]
==
position_embedding_dim
):
diffusers_format_checkpoint
=
convert_ldm_clip_checkpoint
(
checkpoint
,
"text_encoders.clip_l.transformer."
)
diffusers_format_checkpoint
[
"text_projection.weight"
]
=
torch
.
eye
(
position_embedding_dim
)
elif
is_open_clip_model
(
checkpoint
):
prefix
=
"cond_stage_model.model."
diffusers_format_checkpoint
=
convert_open_clip_checkpoint
(
model
,
checkpoint
,
prefix
=
prefix
)
...
...
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
prefix
=
"conditioner.embedders.0.model."
diffusers_format_checkpoint
=
convert_open_clip_checkpoint
(
model
,
checkpoint
,
prefix
=
prefix
)
elif
is_open_clip_sd3_model
(
checkpoint
):
prefix
=
"text_encoders.clip_g.transformer."
diffusers_format_checkpoint
=
convert_open_clip_checkpoint
(
model
,
checkpoint
,
prefix
=
prefix
)
elif
(
is_open_clip_sd3_model
(
checkpoint
)
and
checkpoint
[
CHECKPOINT_KEY_NAMES
[
"open_clip_sd3"
]].
shape
[
-
1
]
==
position_embedding_dim
):
diffusers_format_checkpoint
=
convert_ldm_clip_checkpoint
(
checkpoint
,
"text_encoders.clip_g.transformer."
)
else
:
raise
ValueError
(
"The provided checkpoint does not seem to contain a valid CLIP model."
)
...
...
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
keys
=
list
(
checkpoint
.
keys
())
text_model_dict
=
{}
remove_prefixes
=
[
"text_encoders.t5xxl.transformer.
encoder.
"
]
remove_prefixes
=
[
"text_encoders.t5xxl.transformer."
]
for
key
in
keys
:
for
prefix
in
remove_prefixes
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment