Unverified Commit b1a2c0d5 authored by Dhruv Nair, committed by GitHub

Expand Single File support in SD3 Pipeline (#8517)

* update

* update
parent 06ee907b
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableDiffusionXLInstructPix2PixPipeline`]
 - [`StableDiffusionXLControlNetPipeline`]
 - [`StableDiffusionXLKDiffusionPipeline`]
+- [`StableDiffusion3Pipeline`]
 - [`LatentConsistencyModelPipeline`]
 - [`LatentConsistencyModelImg2ImgPipeline`]
 - [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableCascadeUNet`]
 - [`AutoencoderKL`]
 - [`ControlNetModel`]
+- [`SD3Transformer2DModel`]

 ## FromSingleFileMixin

...
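For orientation, loading the newly listed `SD3Transformer2DModel` from a single-file checkpoint looks roughly like the sketch below. The checkpoint URL is an assumption (the docs hunk further down truncates it after `https://huggingface.co/stability...`), so substitute the file you actually want to load.

```python
import torch
from diffusers import SD3Transformer2DModel

# Hypothetical checkpoint URL for illustration; the docs hunk below cuts it off.
model = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors",
    torch_dtype=torch.float16,
)
```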
@@ -21,9 +21,9 @@ The abstract from the paper is:

 ## Usage Example

 _As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._

 Use the command below to log in:

 ```bash
 huggingface-cli login
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability

 ## Loading the single checkpoint for the `StableDiffusion3Pipeline`

+### Loading the single file checkpoint without T5
+
 ```python
+import torch
 from diffusers import StableDiffusion3Pipeline
-from transformers import T5EncoderModel

-text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
-pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
+    torch_dtype=torch.float16,
+    text_encoder_3=None
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file.png')
 ```

-<Tip>
-`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
-</Tip>
+### Loading the single file checkpoint with T5
+
+```python
+import torch
+from diffusers import StableDiffusion3Pipeline
+
+pipe = StableDiffusion3Pipeline.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
+    torch_dtype=torch.float16,
+)
+pipe.enable_model_cpu_offload()
+
+image = pipe("a picture of a cat holding a sign that says hello world").images[0]
+image.save('sd3-single-file-t5-fp8.png')
+```

 ## StableDiffusion3Pipeline

...
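Note that the pattern removed above, passing an externally loaded T5 encoder via `text_encoder_3`, remains a valid alternative to the bundled weights; a minimal sketch based on the removed lines:

```python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel

# Load the T5 encoder from the diffusers-format repo, then hand it to the
# single-file pipeline instead of relying on the checkpoint's own T5 weights.
text_encoder_3 = T5EncoderModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="text_encoder_3",
    torch_dtype=torch.float16,
)
pipe = StableDiffusion3Pipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
    torch_dtype=torch.float16,
    text_encoder_3=text_encoder_3,
)
```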
@@ -28,9 +28,11 @@ from .single_file_utils import (
     _legacy_load_safety_checker,
     _legacy_load_scheduler,
     create_diffusers_clip_model_from_ldm,
+    create_diffusers_t5_model_from_checkpoint,
     fetch_diffusers_config,
     fetch_original_config,
     is_clip_model_in_single_file,
+    is_t5_in_single_file,
     load_single_file_checkpoint,
 )
...@@ -118,6 +120,16 @@ def load_single_file_sub_model( ...@@ -118,6 +120,16 @@ def load_single_file_sub_model(
is_legacy_loading=is_legacy_loading, is_legacy_loading=is_legacy_loading,
) )
elif is_transformers_model and is_t5_in_single_file(checkpoint):
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
)
elif is_tokenizer and is_legacy_loading: elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer( loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
......
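`is_t5_in_single_file` itself is not shown in this commit view. Presumably it tests for T5-specific keys in the state dict; a hypothetical sketch of that check, using the `text_encoders.t5xxl.transformer.` prefix seen in the last hunk of this commit:

```python
def is_t5_in_single_file(checkpoint) -> bool:
    # Hypothetical re-implementation for illustration only; the real helper
    # lives in single_file_utils.py and may inspect different keys.
    return any(key.startswith("text_encoders.t5xxl.transformer.") for key in checkpoint)
```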
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."

 LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
-    "text_encoders.clip_l.transformer.",
 ]

 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):

 def is_open_clip_sd3_model(checkpoint):
-    is_open_clip_sdxl_refiner_model(checkpoint)
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+        return True
+
+    return False


 def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
         return True

     return False
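With the fix, each detector keys off its own entry in `CHECKPOINT_KEY_NAMES`: previously `is_open_clip_sd3_model` never returned its result (so it always returned `None`), and the refiner check matched the SD3 key. A small usage sketch, with hypothetical key values since the real `CHECKPOINT_KEY_NAMES` mapping is defined outside this hunk:

```python
# Hypothetical key strings for illustration; the real mapping lives elsewhere
# in single_file_utils.py.
CHECKPOINT_KEY_NAMES = {
    "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight",
}

sd3_checkpoint = {CHECKPOINT_KEY_NAMES["open_clip_sd3"]: "..."}
assert is_open_clip_sd3_model(sd3_checkpoint)
assert not is_open_clip_sdxl_refiner_model(sd3_checkpoint)
```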
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint


-def convert_ldm_clip_checkpoint(checkpoint):
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
     keys = list(checkpoint.keys())
     text_model_dict = {}

-    remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+    remove_prefixes = []
+    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+    if remove_prefix:
+        remove_prefixes.append(remove_prefix)

     for key in keys:
         for prefix in remove_prefixes:
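The hunk cuts off right where the stripping loop begins. For context, the loop presumably continues along these lines (a sketch, not the verbatim implementation):

```python
# Sketch of the remainder of convert_ldm_clip_checkpoint: strip the first
# matching prefix and collect the renamed tensors.
for key in keys:
    for prefix in remove_prefixes:
        if key.startswith(prefix):
            diffusers_key = key.replace(prefix, "")
            text_model_dict[diffusers_key] = checkpoint.get(key)
```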
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
     ):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

+    elif (
+        is_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
     elif is_open_clip_model(checkpoint):
         prefix = "cond_stage_model.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
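The `torch.eye` assignment gives SD3's `clip_l` encoder an identity `text_projection`, since the single-file checkpoint ships no projection weight for it. An identity projection leaves the pooled features unchanged, which a quick check illustrates (768 is assumed here as the clip_l hidden size):

```python
import torch

pooled = torch.randn(2, 768)   # stand-in for pooled CLIP-L features
projection = torch.eye(768)    # identity weight, as substituted above
assert torch.allclose(pooled @ projection.T, pooled)
```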
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

-    elif is_open_clip_sd3_model(checkpoint):
-        prefix = "text_encoders.clip_g.transformer."
-        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+    elif (
+        is_open_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")

     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
     keys = list(checkpoint.keys())
     text_model_dict = {}

-    remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+    remove_prefixes = ["text_encoders.t5xxl.transformer."]

     for key in keys:
         for prefix in remove_prefixes:
...
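The shortened prefix matters because a `T5EncoderModel` state dict in transformers names its weights `encoder.*` (plus the `shared.*` token embedding). Stripping the old, longer prefix dropped the leading `encoder.` segment, producing keys that no longer matched the model. A small illustration with a representative key name:

```python
old_prefix = "text_encoders.t5xxl.transformer.encoder."
new_prefix = "text_encoders.t5xxl.transformer."
key = "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight"

print(key[len(old_prefix):])  # "block.0...": missing the "encoder." segment
print(key[len(new_prefix):])  # "encoder.block.0...": matches T5EncoderModel's keys
```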