renzhc / diffusers_dcu · Commits · b1a2c0d5

Unverified commit b1a2c0d5, authored Jun 13, 2024 by Dhruv Nair, committed by GitHub on Jun 13, 2024

Expand Single File support in SD3 Pipeline (#8517)

* update

* update
parent 06ee907b
Showing 4 changed files with 66 additions and 17 deletions
- docs/source/en/api/loaders/single_file.md (+2, -0)
- docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md (+29, -8)
- src/diffusers/loaders/single_file.py (+12, -0)
- src/diffusers/loaders/single_file_utils.py (+23, -9)
docs/source/en/api/loaders/single_file.md

```diff
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableDiffusionXLInstructPix2PixPipeline`]
 - [`StableDiffusionXLControlNetPipeline`]
 - [`StableDiffusionXLKDiffusionPipeline`]
+- [`StableDiffusion3Pipeline`]
 - [`LatentConsistencyModelPipeline`]
 - [`LatentConsistencyModelImg2ImgPipeline`]
 - [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableCascadeUNet`]
 - [`AutoencoderKL`]
 - [`ControlNetModel`]
+- [`SD3Transformer2DModel`]
 
 ## FromSingleFileMixin
```
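Both additions are exercised later in this commit; for the model class, loading looks like the following minimal sketch (the `sd3_medium.safetensors` file name is an assumption about the checkpoint repository, not something shown in this diff):

```python
# Minimal sketch: single-file loading for the newly listed SD3Transformer2DModel.
# The exact checkpoint file name is an assumption, not taken from this diff.
import torch
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors",
    torch_dtype=torch.float16,
)
```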
docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

````diff
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
 ## Loading the single checkpoint for the `StableDiffusion3Pipeline`
 
+### Loading the single file checkpoint without T5
+
 ```python
 import torch
 from diffusers import StableDiffusion3Pipeline
-from transformers import T5EncoderModel
 
-text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
-pipe = StableDiffusion3Pipeline.from_single_file(
-    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
-    torch_dtype=torch.float16,
-    text_encoder_3=text_encoder_3
-)
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
+    torch_dtype=torch.float16,
+    text_encoder_3=None
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file.png')
 ```
 
 <Tip>
 
 `from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
 
 </Tip>
 
+### Loading the single file checkpoint with T5
+
+```python
+import torch
+from diffusers import StableDiffusion3Pipeline
+
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
+    torch_dtype=torch.float16,
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file-t5-fp8.png')
+```
+
 ## StableDiffusion3Pipeline
````
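The removed example is still a valid pattern: `from_single_file` accepts pipeline components as keyword arguments, so a T5 encoder loaded with `from_pretrained` can be paired with a single-file checkpoint that lacks T5 weights. A sketch reassembled from the lines deleted above:

```python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel

# Load the T5 encoder from the Diffusers-format repo ...
text_encoder_3 = T5EncoderModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="text_encoder_3",
    torch_dtype=torch.float16,
)

# ... and pass it into the single-file checkpoint that ships without T5.
pipe = StableDiffusion3Pipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
    torch_dtype=torch.float16,
    text_encoder_3=text_encoder_3,
)
```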
src/diffusers/loaders/single_file.py

```diff
@@ -28,9 +28,11 @@ from .single_file_utils import (
     _legacy_load_safety_checker,
     _legacy_load_scheduler,
     create_diffusers_clip_model_from_ldm,
+    create_diffusers_t5_model_from_checkpoint,
     fetch_diffusers_config,
     fetch_original_config,
     is_clip_model_in_single_file,
+    is_t5_in_single_file,
     load_single_file_checkpoint,
 )
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
             is_legacy_loading=is_legacy_loading,
         )
 
+    elif is_transformers_model and is_t5_in_single_file(checkpoint):
+        loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+            class_obj,
+            checkpoint=checkpoint,
+            config=cached_model_config_path,
+            subfolder=name,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+        )
+
     elif is_tokenizer and is_legacy_loading:
         loaded_sub_model = _legacy_load_clip_tokenizer(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
         )
```
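The new dispatch branch selects the T5 path from checkpoint contents rather than from the requested class alone. A minimal sketch of the detection idea, assuming content-based key matching (the prefix below mirrors the T5 prefix visible elsewhere in this diff; the exact constant used by `is_t5_in_single_file` is not shown here):

```python
# Sketch of content-based T5 detection. The real is_t5_in_single_file checks
# specific key names; the prefix test below is an illustrative assumption.
def is_t5_in_single_file_sketch(checkpoint: dict) -> bool:
    return any(key.startswith("text_encoders.t5xxl.transformer.") for key in checkpoint)

# Example: a bundled-T5 checkpoint exposes keys under that prefix.
ckpt = {"text_encoders.t5xxl.transformer.shared.weight": None}
assert is_t5_in_single_file_sketch(ckpt)
```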
src/diffusers/loaders/single_file_utils.py

```diff
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
 LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
-    "text_encoders.clip_l.transformer.",
 ]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
```
```diff
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
-def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
+def is_open_clip_sd3_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
         return True
 
     return False
 
 
+def is_open_clip_sdxl_refiner_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
+        return True
+
+    return False
```
```diff
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint
 
 
-def convert_ldm_clip_checkpoint(checkpoint):
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
     keys = list(checkpoint.keys())
     text_model_dict = {}
 
-    remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+    remove_prefixes = []
+    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+    if remove_prefix:
+        remove_prefixes.append(remove_prefix)
 
     for key in keys:
         for prefix in remove_prefixes:
```
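The new `remove_prefix` argument lets SD3 callers strip a family-specific prefix per call instead of keeping it in the global `LDM_CLIP_PREFIX_TO_REMOVE` list (which the first hunk correspondingly shrinks). A toy run of the stripping loop, with a made-up checkpoint key:

```python
# Toy illustration of the prefix stripping that convert_ldm_clip_checkpoint
# performs; the checkpoint key and value below are made up for demonstration.
remove_prefixes = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
remove_prefix = "text_encoders.clip_l.transformer."  # now passed explicitly for SD3
if remove_prefix:
    remove_prefixes.append(remove_prefix)

checkpoint = {"text_encoders.clip_l.transformer.text_model.final_layer_norm.weight": "w"}
text_model_dict = {}
for key in list(checkpoint.keys()):
    for prefix in remove_prefixes:
        if key.startswith(prefix):
            text_model_dict[key[len(prefix):]] = checkpoint[key]

print(list(text_model_dict))  # ['text_model.final_layer_norm.weight']
```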
```diff
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
     ):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
 
+    elif (
+        is_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
     elif is_open_clip_model(checkpoint):
         prefix = "cond_stage_model.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
```
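The `torch.eye` line fills in a `text_projection.weight` that the SD3 single-file checkpoint does not carry for CLIP-L; an identity matrix makes the projection a no-op, so pooled embeddings pass through unchanged. A quick check of that property (768 is an illustrative hidden size, not a value from this diff):

```python
import torch

dim = 768  # illustrative hidden size; the real code uses position_embedding_dim
weight = torch.eye(dim)
pooled = torch.randn(2, dim)

# An identity projection returns its input unchanged.
assert torch.equal(pooled @ weight, pooled)
```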
```diff
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 
-    elif is_open_clip_sd3_model(checkpoint):
-        prefix = "text_encoders.clip_g.transformer."
-        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+    elif (
+        is_open_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
 
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
```
```diff
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
     keys = list(checkpoint.keys())
     text_model_dict = {}
 
-    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+    remove_prefixes = ["text_encoders.t5xxl.transformer."]
 
     for key in keys:
         for prefix in remove_prefixes:
```
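This one-line prefix fix matters because `T5EncoderModel`'s state dict keys keep their `encoder.` segment (for example `encoder.block.0...`), and the model also has top-level keys such as `shared.weight`; stripping `encoder.` as part of the prefix produced names the model cannot match. Before and after on a sample key:

```python
# Effect of the prefix change on a sample T5 checkpoint key.
key = "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight"

old = key[len("text_encoders.t5xxl.transformer.encoder."):]
new = key[len("text_encoders.t5xxl.transformer."):]

print(old)  # block.0.layer.0.SelfAttention.q.weight  (missing 'encoder.' prefix)
print(new)  # encoder.block.0.layer.0.SelfAttention.q.weight  (matches T5EncoderModel)
```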