Unverified Commit cb0f3b49 authored by Dhruv Nair, committed by GitHub

[Refactor] Better align `from_single_file` logic with `from_pretrained` (#7496)



* refactor unet single file loading a bit.

* retrieve the unet from create_diffusers_unet_model_from_ldm

* tests

* Update docs/source/en/api/single_file.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>


* Update docs/source/en/api/loaders/single_file.md
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/loaders/single_file.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update docs/source/en/api/loaders/single_file.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>


---------
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent caf9e985
......@@ -124,7 +124,7 @@ jobs:
shell: bash
strategy:
matrix:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
......
......@@ -10,13 +10,124 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Loading Pipelines and Models via `from_single_file`
The `from_single_file` method allows you to load supported pipelines using a single checkpoint file, as opposed to the folder format used by Diffusers. This is useful if you are working with checkpoints from one of the many Stable Diffusion web UIs (such as A1111) that rely on a single file to distribute all the components of a diffusion model.
The `from_single_file` method also supports loading models in their originally distributed format. This means that supported checkpoints that have been fine-tuned with other tools or services can be loaded directly into supported Diffusers model objects and pipelines.
## Pipelines that currently support `from_single_file` loading
- [`StableDiffusionPipeline`]
- [`StableDiffusionImg2ImgPipeline`]
- [`StableDiffusionInpaintPipeline`]
- [`StableDiffusionControlNetPipeline`]
- [`StableDiffusionControlNetImg2ImgPipeline`]
- [`StableDiffusionControlNetInpaintPipeline`]
- [`StableDiffusionUpscalePipeline`]
- [`StableDiffusionXLPipeline`]
- [`StableDiffusionXLImg2ImgPipeline`]
- [`StableDiffusionXLInpaintPipeline`]
- [`StableDiffusionXLInstructPix2PixPipeline`]
- [`StableDiffusionXLControlNetPipeline`]
- [`StableDiffusionXLKDiffusionPipeline`]
- [`LatentConsistencyModelPipeline`]
- [`LatentConsistencyModelImg2ImgPipeline`]
- [`StableDiffusionControlNetXSPipeline`]
- [`StableDiffusionXLControlNetXSPipeline`]
- [`LEditsPPPipelineStableDiffusion`]
- [`LEditsPPPipelineStableDiffusionXL`]
- [`PIAPipeline`]
## Models that currently support `from_single_file` loading
- [`UNet2DConditionModel`]
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`ControlNetModel`]
## Usage Examples
## Loading a Pipeline using `from_single_file`
```python
from diffusers import StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path)
```
## Setting components in a Pipeline using `from_single_file`
Swap components of the pipeline by passing them directly to the `from_single_file` method, e.g. if you would like to use a different scheduler than the pipeline's default:
```python
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
scheduler = DDIMScheduler()
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, scheduler=scheduler)
```
Or, to use a custom ControlNet model with the pipeline:

```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"

controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = StableDiffusionControlNetPipeline.from_single_file(ckpt_path, controlnet=controlnet)
```
## Loading a Model using `from_single_file`
```python
from diffusers import StableCascadeUNet
ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
model = StableCascadeUNet.from_single_file(ckpt_path)
```
## Using a Diffusers model repository to configure single file loading
Under the hood, `from_single_file` will try to determine which Diffusers model repository to use when configuring the components of the pipeline. You can also pass a repository id to the `config` argument of the `from_single_file` method to set the repository explicitly:
```python
from diffusers import StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
repo_id = "segmind/SSD-1B"
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=repo_id)
```
## Override configuration options when using single file loading
Override the default model or pipeline configuration options when using `from_single_file` by passing in the relevant arguments directly to the `from_single_file` method. Any argument that is supported by the model or pipeline class can be configured in this way:
```python
from diffusers import StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
pipe = StableDiffusionXLInstructPix2PixPipeline.from_single_file(ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True)
```
```python
from diffusers import UNet2DConditionModel
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
```
In the SSD-1B example above, since we explicitly passed `config="segmind/SSD-1B"`, `from_single_file` uses this [configuration file](https://huggingface.co/segmind/SSD-1B/blob/main/unet/config.json) from the `unet` subfolder of `segmind/SSD-1B` to configure the `unet` component included in the checkpoint. Similarly, it uses the `config.json` file from the `vae` subfolder to configure the `vae` model, the `config.json` file from the `text_encoder` folder to configure the `text_encoder`, and so on.

Note that most of the time you do not need to explicitly pass a `config` argument; `from_single_file` will automatically map the checkpoint to the appropriate repo id (we discuss this in more detail in the next section). However, the argument is useful in cases where model components have been changed from what was originally distributed, or where a checkpoint file does not have the necessary metadata to correctly determine the configuration to use for the pipeline.
<Tip>
......@@ -24,14 +135,114 @@ To learn more about how to load single file weights, see the [Load different Sta
</Tip>
## Working with local files
As of `diffusers>=0.28.0`, the `from_single_file` method attempts to configure a pipeline or model by first inferring the model type from the checkpoint file, and then using the model type to determine the appropriate model repo configuration to use from the Hugging Face Hub. For example, any single file checkpoint based on the Stable Diffusion XL base model will use the [`stabilityai/stable-diffusion-xl-base-1.0`](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model repo to configure the pipeline.
If you are working in an environment with restricted internet access, it is recommended to download the config files and checkpoints for the model to your preferred directory and pass the local paths to the `pretrained_model_link_or_path` and `config` arguments of the `from_single_file` method.
```python
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, snapshot_download

my_local_checkpoint_path = hf_hub_download(
    repo_id="segmind/SSD-1B",
    filename="SSD-1B.safetensors"
)

my_local_config_path = snapshot_download(
    repo_id="segmind/SSD-1B",
    allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
)

pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
By default, this will download the checkpoints and config files to the [Hugging Face Hub cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). You can also specify a local directory to download the files to by passing the `local_dir` argument to the `hf_hub_download` and `snapshot_download` functions.
```python
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, snapshot_download

my_local_checkpoint_path = hf_hub_download(
    repo_id="segmind/SSD-1B",
    filename="SSD-1B.safetensors",
    local_dir="my_local_checkpoints"
)

my_local_config_path = snapshot_download(
    repo_id="segmind/SSD-1B",
    allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"],
    local_dir="my_local_config"
)

pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
## Working with local files on file systems that do not support symlinking
By default the `from_single_file` method relies on the `huggingface_hub` caching mechanism to fetch and store checkpoints and config files for models and pipelines. If you are working with a file system that does not support symlinking, it is recommended that you first download the checkpoint file to a local directory and disable symlinking by passing the `local_dir_use_symlinks=False` argument to the `hf_hub_download` and `snapshot_download` functions.
```python
from huggingface_hub import hf_hub_download, snapshot_download

my_local_checkpoint_path = hf_hub_download(
    repo_id="segmind/SSD-1B",
    filename="SSD-1B.safetensors",
    local_dir="my_local_checkpoints",
    local_dir_use_symlinks=False
)
print("My local checkpoint: ", my_local_checkpoint_path)

my_local_config_path = snapshot_download(
    repo_id="segmind/SSD-1B",
    allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"],
    local_dir="my_local_config",
    local_dir_use_symlinks=False,
)
print("My local config: ", my_local_config_path)
```
Then pass the local paths to the `pretrained_model_link_or_path` and `config` arguments of the `from_single_file` method.
```python
pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
<Tip>
Disabling symlinking means that the `huggingface_hub` caching mechanism has no way to determine whether a file has already been downloaded to the local directory. This means that the `hf_hub_download` and `snapshot_download` functions will download files to the local directory each time they are executed. If you are disabling symlinking, it is recommended that you separate the model download and loading steps to avoid downloading the same file multiple times.
</Tip>
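For example, a minimal sketch of separating the two steps (the `ensure_checkpoint` helper below is a hypothetical convenience, not part of `diffusers` or `huggingface_hub`):

```python
import os

from huggingface_hub import hf_hub_download


def ensure_checkpoint(repo_id, filename, local_dir):
    # Hypothetical helper: skip the network call entirely if the file is already present.
    local_path = os.path.join(local_dir, filename)
    if not os.path.exists(local_path):
        local_path = hf_hub_download(
            repo_id=repo_id, filename=filename, local_dir=local_dir, local_dir_use_symlinks=False
        )
    return local_path


my_local_checkpoint_path = ensure_checkpoint("segmind/SSD-1B", "SSD-1B.safetensors", "my_local_checkpoints")
```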
## Using the original configuration file of a model
If you would like to configure the parameters of the model components in the pipeline using the original YAML configuration file, you can pass a local path or URL to the original configuration file to the `original_config` argument of the `from_single_file` method.
```python
from diffusers import StableDiffusionXLPipeline

ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"

pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
```
In the example above, the `original_config` file is only used to configure the parameters of the individual model components of the pipeline. For example, it will be used to configure parameters such as the `in_channels` of the `vae` and `unet` models. It is not used to determine the type of the component objects in the pipeline.
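Continuing the example above, the component classes still come from the pipeline definition; only their parameters are read from the YAML file (the printed values below are illustrative):

```python
# The class of each component is fixed by the pipeline; its parameters come from the YAML file.
print(type(pipe.unet).__name__)  # UNet2DConditionModel
print(pipe.unet.config.in_channels)  # e.g. 4
```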
<Tip>
When using `original_config` with `local_files_only=True`, Diffusers will attempt to infer the pipeline components based on the type signatures of the pipeline class, rather than attempting to fetch the pipeline config from the Hugging Face Hub. This is to avoid breaking existing code that might not be able to connect to the internet to fetch the necessary pipeline config files.

This is not as reliable as providing a path to a local config repo, and might lead to errors when configuring the pipeline. To avoid this, please run the pipeline with `local_files_only=False` once to download the appropriate pipeline config files to the local cache.
</Tip>
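As a sketch of that workflow, you might prime the cache once while online and then load offline:

```python
from diffusers import StableDiffusionXLPipeline

ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"

# First run (online): downloads the checkpoint and pipeline config files into the local cache.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path)

# Later runs can resolve everything from the cache without network access.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, local_files_only=True)
```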
## FromSingleFileMixin
[[autodoc]] loaders.single_file.FromSingleFileMixin
## FromOriginalModelMixin
[[autodoc]] loaders.single_file_model.FromOriginalModelMixin
......@@ -27,6 +27,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"schedulers": [],
......
......@@ -340,6 +340,8 @@ class ConfigMixin:
"""
cache_dir = kwargs.pop("cache_dir", None)
local_dir = kwargs.pop("local_dir", None)
local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
......@@ -364,13 +366,13 @@ class ConfigMixin:
if os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load from a PyTorch checkpoint
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
else:
raise EnvironmentError(
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
......@@ -390,6 +392,8 @@ class ConfigMixin:
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
)
except RepositoryNotFoundError:
raise EnvironmentError(
......
......@@ -54,9 +54,7 @@ if is_transformers_available():
_import_structure = {}
if is_torch_available():
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
......@@ -70,8 +68,7 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .autoencoder import FromOriginalVAEMixin
from .single_file_model import FromOriginalModelMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
......
......@@ -11,144 +11,243 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
SingleFileComponentError,
_is_model_weights_in_cached_folder,
_legacy_load_clip_tokenizer,
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
def load_single_file_sub_model(
library_name,
class_name,
name,
checkpoint,
pipelines,
is_pipeline_module,
cached_model_config_path,
original_config=None,
local_files_only=False,
torch_dtype=None,
is_legacy_loading=False,
**kwargs,
):
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_tokenizer = (
is_transformers_available()
and issubclass(class_obj, PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
diffusers_module = importlib.import_module(__name__.split(".")[0])
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
if is_diffusers_single_file_model:
load_method = getattr(class_obj, "from_single_file")
# We cannot provide two different config options to the `from_single_file` method
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
if original_config:
cached_model_config_path = None
loaded_sub_model = load_method(
pretrained_model_link_or_path_or_dict=checkpoint,
original_config=original_config,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
**kwargs,
)
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
loaded_sub_model = create_diffusers_clip_model_from_ldm(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
)
elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
)
elif is_diffusers_scheduler and is_legacy_loading:
loaded_sub_model = _legacy_load_scheduler(
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
)
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
original_config,
checkpoint,
model_type=model_type,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
else:
if not hasattr(class_obj, "from_pretrained"):
raise ValueError(
(
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
" a supported loading method."
)
)
loading_kwargs = {}
loading_kwargs.update(
{
"pretrained_model_name_or_path": cached_model_config_path,
"subfolder": name,
"local_files_only": local_files_only,
}
)
# Schedulers and Tokenizers don't make use of torch_dtype
# Skip passing it to those objects
if issubclass(class_obj, torch.nn.Module):
loading_kwargs.update({"torch_dtype": torch_dtype})
if is_diffusers_model or is_transformers_model:
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
if component_name == "feature_extractor":
if load_safety_checker:
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
else:
feature_extractor = None
return {"feature_extractor": feature_extractor}
load_method = getattr(class_obj, "from_pretrained")
loaded_sub_model = load_method(**loading_kwargs)
return loaded_sub_model
def _map_component_types_to_config_dict(component_types):
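# Map each component type hint from a pipeline signature to a [library, class_name] pair,
# mimicking the layout of a `model_index.json` pipeline config.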
diffusers_module = importlib.import_module(__name__.split(".")[0])
config_dict = {}
component_types.pop("self", None)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
for component_name, component_value in component_types.items():
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_transformers_tokenizer = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif is_scheduler_enum or is_scheduler:
if is_scheduler_enum:
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
# if the type hint is a KarrasDiffusionSchedulers enum
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
elif is_scheduler:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif (
is_transformers_model or is_transformers_tokenizer
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["transformers", component_value[0].__name__]
else:
config_dict[component_name] = [None, None]
return config_dict
def _infer_pipeline_config_dict(pipeline_class):
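# Legacy fallback: build a pipeline config from the `__init__` signature when no
# Diffusers model repo config is available (e.g. `local_files_only=True` with an empty cache).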
parameters = inspect.signature(pipeline_class.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
component_types = pipeline_class._get_signature_types()
# Ignore parameters that are not required for the pipeline
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
config_dict = _map_component_types_to_config_dict(component_types)
return config_dict
def _download_diffusers_model_config_from_hub(
pretrained_model_name_or_path,
cache_dir,
revision,
proxies,
force_download=None,
resume_download=None,
local_files_only=None,
token=None,
):
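# Model weights come from the single file checkpoint, so only config and tokenizer text files are fetched.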
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
cached_model_path = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
allow_patterns=allow_patterns,
)
return cached_model_path
class FromSingleFileMixin:
......@@ -195,27 +294,12 @@ class FromSingleFileMixin:
original_config_file (`str`, *optional*):
The path to the original config file that was used to train the model. If not provided, the config file
will be inferred from the checkpoint file.
config (`str`, *optional*):
Can be either:
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
......@@ -233,7 +317,7 @@ class FromSingleFileMixin:
>>> # Download pipeline from local file
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_single_file(
......@@ -242,9 +326,21 @@ class FromSingleFileMixin:
... )
>>> pipeline.to("cuda")
```
"""
original_config_file = kwargs.pop("original_config_file", None)
resume_download = kwargs.pop("resume_download", None)
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if original_config_file is not None:
deprecation_message = (
"`original_config_file` argument is deprecated and will be removed in future versions."
"please use the `original_config` argument instead."
)
deprecate("original_config_file", "1.0.0", deprecation_message)
original_config = original_config_file
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
......@@ -253,68 +349,198 @@ class FromSingleFileMixin:
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
is_legacy_loading = False
# We shouldn't allow configuring individual models components through a Pipeline creation method
# These model kwargs should be deprecated
scaling_factor = kwargs.get("scaling_factor", None)
if scaling_factor is not None:
deprecation_message = (
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
"and will be ignored in future versions."
)
deprecate("scaling_factor", "1.0.0", deprecation_message)
if original_config is not None:
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
from ..pipelines.pipeline_utils import _get_pipeline_class
pipeline_class = _get_pipeline_class(cls, config=None)
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
else:
default_pretrained_model_config_name = config
if not os.path.isdir(default_pretrained_model_config_name):
# Provided config is a repo_id
if default_pretrained_model_config_name.count("/") > 1:
raise ValueError(
f'The provided config "{config}"'
" is neither a valid local path nor a valid repo id. Please check the parameter."
)
try:
# Attempt to download the config files for the pipeline
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
except LocalEntryNotFoundError:
# `local_files_only=True` but a local diffusers format model config is not available in the cache
# If `original_config` is not provided, we need override `local_files_only` to False
# to fetch the config files from the hub so that we have a way
# to configure the pipeline components.
if original_config is None:
logger.warning(
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
"Attempting to download the necessary config files for this pipeline.\n"
)
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
local_files_only=False,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
else:
# For backwards compatibility
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred. \n"
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
"the necessary config files.\n"
)
is_legacy_loading = True
cached_model_config_path = None
config_dict = _infer_pipeline_config_dict(pipeline_class)
config_dict["_class_name"] = pipeline_class.__name__
else:
# Provided config is a path to a local directory attempt to load directly.
cached_model_config_path = default_pretrained_model_config_name
config_dict = pipeline_class.load_config(cached_model_config_path)
# pop out "_ignore_files" as it is only needed for download
config_dict.pop("_ignore_files", None)
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
model_type = kwargs.pop("model_type", None)
image_size = kwargs.pop("image_size", None)
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
passed_class_obj.get("safety_checker", None) is not None
)
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
from diffusers import pipelines
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
for name, (library_name, class_name) in logging.tqdm(
sorted(init_dict.items()), desc="Loading pipeline components..."
):
loaded_sub_model = None
is_pipeline_module = hasattr(pipelines, library_name)
if name in passed_class_obj:
loaded_sub_model = passed_class_obj[name]
else:
try:
loaded_sub_model = load_single_file_sub_model(
library_name=library_name,
class_name=class_name,
name=name,
checkpoint=checkpoint,
is_pipeline_module=is_pipeline_module,
cached_model_config_path=cached_model_config_path,
pipelines=pipelines,
torch_dtype=torch_dtype,
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
**kwargs,
)
except SingleFileComponentError as e:
raise SingleFileComponentError(
(
f"{e.message}\n"
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
f"\n"
f"{name} = {class_name}.from_pretrained('...')\n"
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
f"\n"
)
)
init_kwargs[name] = loaded_sub_model
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# deprecated kwargs
load_safety_checker = kwargs.pop("load_safety_checker", None)
if load_safety_checker is not None:
deprecation_message = (
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
)
deprecate("load_safety_checker", "1.0.0", deprecation_message)
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
init_kwargs.update(safety_checker_components)
pipe = pipeline_class(**init_kwargs)
if torch_dtype is not None:
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from contextlib import nullcontext
from typing import Optional
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_controlnet_checkpoint,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
fetch_diffusers_config,
fetch_original_config,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
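
# Maps each supported model class to the functions used to convert its single file
# checkpoint (and, optionally, its original config) into the Diffusers format.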
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
},
"UNet2DConditionModel": {
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
"default_subfolder": "unet",
"legacy_kwargs": {
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
},
},
"AutoencoderKL": {
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
"default_subfolder": "vae",
},
"ControlNetModel": {
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
},
}
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
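# Keep only the kwargs that appear in `mapping_fn`'s signature so unrelated loading kwargs are not forwarded.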
parameters = inspect.signature(mapping_fn).parameters
mapping_kwargs = {}
for parameter in parameters:
if parameter in kwargs:
mapping_kwargs[parameter] = kwargs[parameter]
return mapping_kwargs
class FromOriginalModelMixin:
"""
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path_or_dict (`str`, *optional*):
Can be either:
- A link to the `.safetensors` or `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
- A path to a local *file* containing the weights of the component model.
- A state dict containing the component model weights.
config (`str`, *optional*):
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
configs in Diffusers format.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
original_config (`str` or `dict`, *optional*):
Dict or path to a YAML file containing the configuration for the model in its original format. If a
dict is provided, it will be used to initialize the model configuration.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the config arguments of the
specific model class). The overwritten arguments are passed directly to the model's `__init__`
method. See example below for more information.
```py
>>> from diffusers import StableCascadeUNet
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
```
"""
class_name = cls.__name__
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
if pretrained_model_link_or_path is not None:
deprecation_message = (
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
)
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if config is not None and original_config is not None:
raise ValueError(
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
else:
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path_or_dict,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config:
if "config_mapping_fn" in mapping_functions:
config_mapping_fn = mapping_functions["config_mapping_fn"]
else:
config_mapping_fn = None
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
)
if isinstance(original_config, str):
# If original_config is a URL or filepath fetch the original_config dict
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
diffusers_model_config = config_mapping_fn(
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
)
else:
if config:
if isinstance(config, str):
default_pretrained_model_config_name = config
else:
raise ValueError(
(
"Invalid `config` argument. Please provide a string representing a repo id"
"or path to a local Diffusers model repo."
)
)
else:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
if "default_subfolder" in mapping_functions:
subfolder = mapping_functions["default_subfolder"]
subfolder = subfolder or config.pop(
"subfolder", None
) # some configs contain a subfolder key, e.g. StableCascadeUNet
diffusers_model_config = cls.load_config(
pretrained_model_name_or_path=default_pretrained_model_config_name,
subfolder=subfolder,
local_files_only=local_files_only,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
# Map legacy kwargs to new kwargs
if "legacy_kwargs" in mapping_functions:
legacy_kwargs = mapping_functions["legacy_kwargs"]
for legacy_key, new_key in legacy_kwargs.items():
if legacy_key in kwargs:
kwargs[new_key] = kwargs.pop(legacy_key)
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
model.eval()
return model
......@@ -26,7 +26,6 @@ import yaml
from ..models.modeling_utils import load_state_dict
from ..schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EDMDPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
......@@ -35,17 +34,19 @@ from ..schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ..utils import (
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
deprecate,
is_accelerate_available,
is_transformers_available,
logging,
)
from ..utils.hub_utils import _get_model_file
if is_transformers_available():
from transformers import AutoImageProcessor
if is_accelerate_available():
from accelerate import init_empty_weights
......@@ -54,116 +55,64 @@ if is_accelerate_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CHECKPOINT_KEY_NAMES = {
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
"upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
"controlnet": "control_model.time_embed.0.weight",
"playground-v2-5": "edm_mean",
"inpainting": "model.diffusion_model.input_blocks.0.0.weight",
"clip": "cond_stage_model.transformer.text_model.embeddings.position_ids",
"clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
"open_clip": "cond_stage_model.model.token_embedding.weight",
"open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
"open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
"xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
"xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
"playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
"upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
"inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"},
"inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
"controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
"v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
"v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"},
"stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
"stable_cascade_stage_b_lite": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade",
"subfolder": "decoder_lite",
},
"stable_cascade_stage_c": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
"subfolder": "prior",
},
"stable_cascade_stage_c_lite": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
"subfolder": "prior_lite",
},
}
# Use to configure model sample size when original config is provided
DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
"xl_base": 1024,
"xl_refiner": 1024,
"xl_inpaint": 1024,
"playground-v2-5": 1024,
"upscale": 512,
"inpainting": 512,
"inpainting_v2": 512,
"controlnet": 512,
"v2": 768,
"v1": 512,
}
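
# Split the fused attention in_proj weights/biases of a Stable Cascade checkpoint into
# separate to_q/to_k/to_v projections, and rename legacy keys such as `clip_mapper`.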
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
is_stage_c = "clip_txt_mapper.weight" in original_state_dict
if is_stage_c:
state_dict = {}
for key in original_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = original_state_dict[key]
else:
state_dict = {}
for key in original_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = original_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = original_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = original_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = original_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = original_state_dict[key]
return state_dict
def infer_stable_cascade_single_file_config(checkpoint):
is_stage_c = "clip_txt_mapper.weight" in checkpoint
is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
config_type = "stage_c_lite"
elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
config_type = "stage_c"
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
config_type = "stage_b_lite"
elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
config_type = "stage_b"
return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
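# Illustrative usage sketch (not part of the library): the config type is inferred
# purely from tensor shapes, so any loaded state dict works. The file name below is
# a hypothetical placeholder.
#
#   >>> checkpoint = load_single_file_checkpoint("stable_cascade_prior.safetensors")
#   >>> infer_stable_cascade_single_file_config(checkpoint)
#   {'pretrained_model_name_or_path': 'diffusers/stable-cascade-configs', 'subfolder': 'prior'}
#
# A checkpoint matching none of the four shape checks would leave `config_type`
# unbound and raise a NameError.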
DIFFUSERS_TO_LDM_MAPPING = {
"unet": {
"layers": {
......@@ -257,14 +206,6 @@ DIFFUSERS_TO_LDM_MAPPING = {
},
}
LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
......@@ -281,11 +222,51 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
"cond_stage_model.model.text_projection",
]
# To support legacy scheduler_type argument
SCHEDULER_DEFAULT_CONFIG = {
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"beta_end": 0.012,
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": False,
"skip_prk_steps": True,
"steps_offset": 1,
"timestep_spacing": "leading",
}
LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
class SingleFileComponentError(Exception):
def __init__(self, message=None):
self.message = message
super().__init__(self.message)
def is_valid_url(url):
result = urlparse(url)
if result.scheme and result.netloc:
return True
return False
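# Doctest-style examples for `is_valid_url` (illustrative only):
#
#   >>> is_valid_url("https://huggingface.co/stabilityai/stable-cascade")
#   True
#   >>> is_valid_url("./checkpoints/model.safetensors")
#   False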
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
if not is_valid_url(pretrained_model_name_or_path):
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
weights_name = None
repo_id = None
......@@ -293,6 +274,7 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
match = re.match(pattern, pretrained_model_name_or_path)
if not match:
logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
return repo_id, weights_name
repo_id = f"{match.group(1)}/{match.group(2)}"
......@@ -301,34 +283,18 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
return repo_id, weights_name
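# Worked example of the pattern above (a real repo, but any Hub blob URL works):
# the `https://huggingface.co/` prefix is stripped first, the regex then captures the
# two repo path segments and skips an optional `blob/main/` before the file name.
#
#   >>> _extract_repo_id_and_weights_name(
#   ...     "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
#   ... )
#   ('stabilityai/stable-diffusion-2-1', 'v2-1_768-ema-pruned.safetensors')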
def fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path,
class_name,
original_config_file=None,
resume_download=None,
force_download=False,
proxies=None,
token=None,
cache_dir=None,
local_files_only=None,
revision=None,
):
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
def _is_model_weights_in_cached_folder(cached_folder, name):
pretrained_model_name_or_path = os.path.join(cached_folder, name)
weights_exist = False
for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
weights_exist = True
return original_config, checkpoint
return weights_exist
def load_single_file_model_checkpoint(
def load_single_file_checkpoint(
pretrained_model_link_or_path,
resume_download=False,
force_download=False,
......@@ -339,10 +305,11 @@ def load_single_file_model_checkpoint(
revision=None,
):
if os.path.isfile(pretrained_model_link_or_path):
checkpoint = load_state_dict(pretrained_model_link_or_path)
pretrained_model_link_or_path = pretrained_model_link_or_path
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
checkpoint_path = _get_model_file(
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
force_download=force_download,
......@@ -353,7 +320,8 @@ def load_single_file_model_checkpoint(
token=token,
revision=revision,
)
checkpoint = load_state_dict(checkpoint_path)
checkpoint = load_state_dict(pretrained_model_link_or_path)
# some checkpoints contain the model state dict under a "state_dict" key
while "state_dict" in checkpoint:
......@@ -362,120 +330,154 @@ def load_single_file_model_checkpoint(
return checkpoint
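# Minimal usage sketch, assuming network access where needed. Both a local path and
# a Hub URL are accepted, and the loop above unwraps checkpoints that nest their
# weights under one or more "state_dict" keys. The local path is a placeholder.
#
#   >>> sd = load_single_file_checkpoint("./v1-5-pruned-emaonly.safetensors")
#   >>> sd = load_single_file_checkpoint(
#   ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
#   ... )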
def infer_original_config_file(class_name, checkpoint):
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
config_url = CONFIG_URLS["v2"]
def fetch_original_config(original_config_file, local_files_only=False):
if os.path.isfile(original_config_file):
with open(original_config_file, "r") as fp:
original_config_file = fp.read()
elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
config_url = CONFIG_URLS["xl"]
elif is_valid_url(original_config_file):
if local_files_only:
raise ValueError(
"`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
"Please provide a valid local file path."
)
elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
config_url = CONFIG_URLS["xl_refiner"]
original_config_file = BytesIO(requests.get(original_config_file).content)
elif class_name == "StableDiffusionUpscalePipeline":
config_url = CONFIG_URLS["upscale"]
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
elif class_name == "ControlNetModel":
config_url = CONFIG_URLS["controlnet"]
original_config = yaml.safe_load(original_config_file)
else:
config_url = CONFIG_URLS["v1"]
return original_config
original_config_file = BytesIO(requests.get(config_url).content)
return original_config_file
def is_clip_model(checkpoint):
if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
return True
return False
def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None):
def is_valid_url(url):
result = urlparse(url)
if result.scheme and result.netloc:
return True
return False
def is_clip_sdxl_model(checkpoint):
if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
return True
if original_config_file is None:
original_config_file = infer_original_config_file(pipeline_class_name, checkpoint)
return False
elif os.path.isfile(original_config_file):
with open(original_config_file, "r") as fp:
original_config_file = fp.read()
elif is_valid_url(original_config_file):
original_config_file = BytesIO(requests.get(original_config_file).content)
def is_open_clip_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
return True
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
return False
original_config = yaml.safe_load(original_config_file)
return original_config
def is_open_clip_sdxl_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
return True
return False
def infer_model_type(original_config, checkpoint, model_type=None):
if model_type is not None:
return model_type
has_cond_stage_config = (
"cond_stage_config" in original_config["model"]["params"]
and original_config["model"]["params"]["cond_stage_config"] is not None
)
has_network_config = (
"network_config" in original_config["model"]["params"]
and original_config["model"]["params"]["network_config"] is not None
def is_open_clip_sdxl_refiner_model(checkpoint):
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
return True
return False
def is_clip_model_in_single_file(class_obj, checkpoint):
is_clip_in_checkpoint = any(
[
is_clip_model(checkpoint),
is_open_clip_model(checkpoint),
is_open_clip_sdxl_model(checkpoint),
is_open_clip_sdxl_refiner_model(checkpoint),
]
)
if (
class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
) and is_clip_in_checkpoint:
return True
return False
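# Illustrative sketch: the predicates above only inspect key names, so a bare dict
# is enough to exercise them.
#
#   >>> from transformers import CLIPTextModel
#   >>> fake_checkpoint = {CHECKPOINT_KEY_NAMES["clip"]: None}
#   >>> is_clip_model_in_single_file(CLIPTextModel, fake_checkpoint)
#   True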
if has_cond_stage_config:
model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
elif has_network_config:
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
model_type = "Playground"
elif context_dim == 2048:
model_type = "SDXL"
def infer_diffusers_model_type(checkpoint):
if (
CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
):
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
model_type = "inpainting_v2"
else:
model_type = "SDXL-Refiner"
else:
raise ValueError("Unable to infer model type from config")
model_type = "inpainting"
logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}")
elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
model_type = "v2"
return model_type
elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
model_type = "playground-v2-5"
elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
model_type = "xl_base"
def get_default_scheduler_config():
return SCHEDULER_DEFAULT_CONFIG
elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
model_type = "xl_refiner"
elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
model_type = "upscale"
def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None):
if image_size:
return image_size
elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
model_type = "controlnet"
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
model_type = infer_model_type(original_config, checkpoint, model_type)
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
):
model_type = "stable_cascade_stage_c_lite"
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
return image_size
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
):
model_type = "stable_cascade_stage_c"
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
image_size = 1024
return image_size
elif (
CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
):
model_type = "stable_cascade_stage_b_lite"
elif (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
):
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
return image_size
model_type = "stable_cascade_stage_b"
else:
image_size = 512
model_type = "v1"
return model_type
def fetch_diffusers_config(checkpoint):
model_type = infer_diffusers_model_type(checkpoint)
model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
return model_path
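# Usage sketch: shape-based detection maps a raw state dict directly to the Diffusers
# repo (and optional subfolder) holding the matching configs. The path below is a
# placeholder for any SD 1.x checkpoint.
#
#   >>> checkpoint = load_single_file_checkpoint("./v1-5-pruned-emaonly.safetensors")
#   >>> fetch_diffusers_config(checkpoint)
#   {'pretrained_model_name_or_path': 'runwayml/stable-diffusion-v1-5'}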
def set_image_size(checkpoint, image_size=None):
if image_size:
return image_size
model_type = infer_diffusers_model_type(checkpoint)
image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
return image_size
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
......@@ -490,10 +492,21 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config(original_config, image_size: int):
def create_unet_diffusers_config_from_ldm(
original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if image_size is not None:
deprecation_message = (
"Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
image_size = set_image_size(checkpoint, image_size=image_size)
if (
"unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
......@@ -502,6 +515,16 @@ def create_unet_diffusers_config(original_config, image_size: int):
else:
unet_params = original_config["model"]["params"]["network_config"]["params"]
if num_in_channels is not None:
deprecation_message = (
"Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
in_channels = num_in_channels
else:
in_channels = unet_params["in_channels"]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
......@@ -566,7 +589,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params["in_channels"],
"in_channels": in_channels,
"down_block_types": down_block_types,
"block_out_channels": block_out_channels,
"layers_per_block": unet_params["num_res_blocks"],
......@@ -580,6 +603,14 @@ def create_unet_diffusers_config(original_config, image_size: int):
"transformer_layers_per_block": transformer_layers_per_block,
}
if upcast_attention is not None:
deprecation_message = (
"Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
config["upcast_attention"] = upcast_attention
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params["disable_self_attentions"]
......@@ -592,9 +623,18 @@ def create_unet_diffusers_config(original_config, image_size: int):
return config
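# Worked example of the block math above for a Stable Diffusion v1-style config:
# model_channels=320 and channel_mult=[1, 2, 4, 4] give block_out_channels of
# [320, 640, 1280, 1280], and with image_size=512 and a VAE downscale factor of 8
# the resulting sample_size is 512 // 8 = 64.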
def create_controlnet_diffusers_config(original_config, image_size: int):
def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
if image_size is not None:
deprecation_message = (
"Configuring ControlNetModel with the `image_size` argument"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
image_size = set_image_size(checkpoint, image_size=image_size)
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, checkpoint, image_size=image_size)
controlnet_config = {
"conditioning_channels": unet_params["hint_channels"],
......@@ -615,15 +655,33 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
return controlnet_config
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if image_size is not None:
deprecation_message = (
"Configuring AutoencoderKL with the `image_size` argument"
"is deprecated and will be ignored in future versions."
)
deprecate("image_size", "1.0.0", deprecation_message)
image_size = set_image_size(checkpoint, image_size=image_size)
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
latents_mean = checkpoint["edm_mean"]
latents_std = checkpoint["edm_std"]
else:
latents_mean = None
latents_std = None
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
scaling_factor = original_config["model"]["params"]["scale_factor"]
elif scaling_factor is None:
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
......@@ -660,16 +718,104 @@ def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, ma
)
if mapping:
diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
for ldm_key in ldm_keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = (
ldm_key.replace(mapping["old"], mapping["new"])
.replace("norm.weight", "group_norm.weight")
.replace("norm.bias", "group_norm.bias")
.replace("q.weight", "to_q.weight")
.replace("q.bias", "to_q.bias")
.replace("k.weight", "to_k.weight")
.replace("k.bias", "to_k.bias")
.replace("v.weight", "to_v.weight")
.replace("v.bias", "to_v.bias")
.replace("proj_out.weight", "to_out.0.weight")
.replace("proj_out.bias", "to_out.0.bias")
)
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
# proj_attn.weight has to be converted from conv 1D to linear
shape = new_checkpoint[diffusers_key].shape
if len(shape) == 3:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
elif len(shape) == 4:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
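# Shape-handling sketch: LDM VAE attention projections are stored as 1x1 convolutions,
# e.g. a `q.weight` of shape (512, 512, 1, 1); the slicing above drops the trailing
# spatial dims so the Diffusers `to_q.weight` linear layer receives shape (512, 512).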
def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
is_stage_c = "clip_txt_mapper.weight" in checkpoint
if is_stage_c:
state_dict = {}
for key in checkpoint.keys():
if key.endswith("in_proj_weight"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = checkpoint[key]
else:
state_dict = {}
for key in checkpoint.keys():
if key.endswith("in_proj_weight"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = checkpoint[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = checkpoint[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = checkpoint[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = checkpoint[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = checkpoint[key]
return state_dict
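# Minimal sketch of the fused-attention split performed above (hypothetical shapes):
# an `attn.in_proj_weight` of shape (3 * dim, dim) is chunked along dim 0 into three
# (dim, dim) tensors for the separate q/k/v projections.
#
#   >>> import torch
#   >>> ckpt = {"blocks.0.attn.in_proj_weight": torch.randn(3 * 64, 64)}
#   >>> converted = convert_stable_cascade_unet_single_file_to_diffusers(ckpt)
#   >>> sorted(converted)
#   ['blocks.0.to_k.weight', 'blocks.0.to_q.weight', 'blocks.0.to_v.weight']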
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
......@@ -680,24 +826,24 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning("Checkpoint has both EMA and non-EMA weights.")
logger.warning(
logger.warninging("Checkpoint has both EMA and non-EMA weights.")
logger.warninging(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)
new_checkpoint = {}
ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
......@@ -758,10 +904,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
)
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
f"input_blocks.{i}.0.op.bias"
)
......@@ -775,19 +921,22 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
)
# Mid blocks
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
update_unet_resnet_ldm_to_diffusers(
resnet_0, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}
)
update_unet_resnet_ldm_to_diffusers(
resnet_1, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}
)
update_unet_attention_ldm_to_diffusers(
attentions, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}
)
for key in middle_blocks.keys():
diffusers_key = max(key - 1, 0)
if key % 2 == 0:
update_unet_resnet_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
unet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
)
else:
update_unet_attention_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
unet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
)
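# Index arithmetic sketch for the loop above: LDM middle blocks alternate
# resnet/attention/resnet, so with diffusers_key = max(key - 1, 0):
#   key 0 (even) -> mid_block.resnets.0
#   key 1 (odd)  -> mid_block.attentions.0
#   key 2 (even) -> mid_block.resnets.1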
# Up Blocks
for i in range(num_output_blocks):
......@@ -836,6 +985,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
def convert_controlnet_checkpoint(
checkpoint,
config,
**kwargs,
):
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
......@@ -848,7 +998,7 @@ def convert_controlnet_checkpoint(
controlnet_key = LDM_CONTROLNET_KEY
for key in keys:
if key.startswith(controlnet_key):
controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key)
controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
new_checkpoint = {}
ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
......@@ -882,10 +1032,10 @@ def convert_controlnet_checkpoint(
)
if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop(
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop(
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
f"input_blocks.{i}.0.op.bias"
)
......@@ -900,8 +1050,8 @@ def convert_controlnet_checkpoint(
# controlnet down blocks
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias")
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
# Retrieves the keys for the middle blocks only
num_middle_blocks = len(
......@@ -911,33 +1061,28 @@ def convert_controlnet_checkpoint(
layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
if middle_blocks:
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
update_unet_resnet_ldm_to_diffusers(
resnet_0,
new_checkpoint,
controlnet_state_dict,
mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"},
)
update_unet_resnet_ldm_to_diffusers(
resnet_1,
new_checkpoint,
controlnet_state_dict,
mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"},
)
update_unet_attention_ldm_to_diffusers(
attentions,
new_checkpoint,
controlnet_state_dict,
mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"},
)
# Mid blocks
for key in middle_blocks.keys():
diffusers_key = max(key - 1, 0)
if key % 2 == 0:
update_unet_resnet_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
controlnet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
)
else:
update_unet_attention_ldm_to_diffusers(
middle_blocks[key],
new_checkpoint,
controlnet_state_dict,
mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
)
# mid block
new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias")
new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
# controlnet cond embedding blocks
cond_embedding_blocks = {
......@@ -951,86 +1096,16 @@ def convert_controlnet_checkpoint(
diffusers_idx = idx - 1
cond_block_id = 2 * idx
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop(
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
f"input_hint_block.{cond_block_id}.weight"
)
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop(
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
f"input_hint_block.{cond_block_id}.bias"
)
return new_checkpoint
def create_diffusers_controlnet_model_from_ldm(
pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
):
# import here to avoid circular imports
from ..models import ControlNetModel
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
diffusers_config["upcast_attention"] = upcast_attention
diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
controlnet = ControlNetModel(**diffusers_config)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(
controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
)
if controlnet._keys_to_ignore_on_load_unexpected is not None:
for pat in controlnet._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)
return {"controlnet": controlnet}
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = (
ldm_key.replace(mapping["old"], mapping["new"])
.replace("norm.weight", "group_norm.weight")
.replace("norm.bias", "group_norm.bias")
.replace("q.weight", "to_q.weight")
.replace("q.bias", "to_q.bias")
.replace("k.weight", "to_k.weight")
.replace("k.bias", "to_k.bias")
.replace("v.weight", "to_v.weight")
.replace("v.bias", "to_v.bias")
.replace("proj_out.weight", "to_out.0.weight")
.replace("proj_out.bias", "to_out.0.bias")
)
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
# proj_attn.weight has to be converted from conv 1D to linear
shape = new_checkpoint[diffusers_key].shape
if len(shape) == 3:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
elif len(shape) == 4:
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
......@@ -1063,10 +1138,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
)
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
f"encoder.down.{i}.downsample.conv.bias"
)
......@@ -1131,18 +1206,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint
def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, torch_dtype=None):
try:
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
text_model = CLIPTextModel(config)
def convert_ldm_clip_checkpoint(checkpoint):
keys = list(checkpoint.keys())
text_model_dict = {}
......@@ -1152,55 +1216,26 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
for prefix in remove_prefixes:
if key.startswith(prefix):
diffusers_key = key.replace(prefix, "")
text_model_dict[diffusers_key] = checkpoint[key]
text_model_dict[diffusers_key] = checkpoint.get(key)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
if text_model._keys_to_ignore_on_load_unexpected is not None:
for pat in text_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
return text_model_dict
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
if torch_dtype is not None:
text_model = text_model.to(torch_dtype)
return text_model
def create_text_encoder_from_open_clip_checkpoint(
config_name,
def convert_open_clip_checkpoint(
text_model,
checkpoint,
prefix="cond_stage_model.model.",
has_projection=False,
local_files_only=False,
torch_dtype=None,
**config_kwargs,
):
try:
config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
text_model_dict = {}
text_proj_key = prefix + "text_projection"
text_proj_dim = (
int(checkpoint[text_proj_key].shape[0]) if text_proj_key in checkpoint else LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
)
if text_proj_key in checkpoint:
text_proj_dim = int(checkpoint[text_proj_key].shape[0])
elif hasattr(text_model.config, "projection_dim"):
text_proj_dim = text_model.config.projection_dim
else:
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
keys = list(checkpoint.keys())
......@@ -1233,303 +1268,165 @@ def create_text_encoder_from_open_clip_checkpoint(
)
if key.endswith(".in_proj_weight"):
weight_value = checkpoint[key]
weight_value = checkpoint.get(key)
text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :]
text_model_dict[diffusers_key + ".k_proj.weight"] = weight_value[text_proj_dim : text_proj_dim * 2, :]
text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :]
text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
text_model_dict[diffusers_key + ".k_proj.weight"] = (
weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
)
text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
elif key.endswith(".in_proj_bias"):
weight_value = checkpoint[key]
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
else:
text_model_dict[diffusers_key] = checkpoint[key]
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
if text_model._keys_to_ignore_on_load_unexpected is not None:
for pat in text_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
weight_value = checkpoint.get(key)
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
text_model_dict[diffusers_key + ".k_proj.bias"] = (
weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
)
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
else:
text_model_dict[diffusers_key] = checkpoint.get(key)
else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
text_model_dict.pop("text_model.embeddings.position_ids", None)
text_model.load_state_dict(text_model_dict)
if torch_dtype is not None:
text_model = text_model.to(torch_dtype)
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
text_model_dict.pop("text_model.embeddings.position_ids", None)
return text_model
return text_model_dict
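# Slicing sketch for the in_proj split above: with text_proj_dim = 1024, an OpenCLIP
# `in_proj_weight` of shape (3072, 1024) maps rows 0:1024 to q_proj, 1024:2048 to
# k_proj and 2048:3072 to v_proj. The `.clone().detach()` calls keep each slice from
# remaining a view into the original checkpoint tensor.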
def create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
def create_diffusers_clip_model_from_ldm(
cls,
checkpoint,
num_in_channels=None,
upcast_attention=None,
extract_ema=False,
image_size=None,
subfolder="",
config=None,
torch_dtype=None,
model_type=None,
local_files_only=None,
is_legacy_loading=False,
):
from ..models import UNet2DConditionModel
if config:
config = {"pretrained_model_name_or_path": config}
else:
config = fetch_diffusers_config(checkpoint)
if num_in_channels is None:
if pipeline_class_name in [
"StableDiffusionInpaintPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
]:
num_in_channels = 9
# For backwards compatibility
# Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
# in the cache_dir, rather than in a subfolder of the Diffusers model
if is_legacy_loading:
logger.warning(
(
"Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
"the local cache directory with the necessary CLIP model config files. "
"Attempting to load CLIP model from legacy cache directory."
)
)
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
num_in_channels = 7
if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
clip_config = "openai/clip-vit-large-patch14"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
else:
num_in_channels = 4
elif is_open_clip_model(checkpoint):
clip_config = "stabilityai/stable-diffusion-2"
config["pretrained_model_name_or_path"] = clip_config
subfolder = "text_encoder"
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
if upcast_attention is not None:
unet_config["upcast_attention"] = upcast_attention
else:
clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
unet = UNet2DConditionModel(**unet_config)
model = cls(model_config)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
if unet._keys_to_ignore_on_load_unexpected is not None:
for pat in unet._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
unet.load_state_dict(diffusers_format_unet_checkpoint)
if is_clip_model(checkpoint):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
if torch_dtype is not None:
unet = unet.to(torch_dtype)
return {"unet": unet}
elif (
is_clip_sdxl_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
):
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
elif is_open_clip_model(checkpoint):
prefix = "cond_stage_model.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
def create_diffusers_vae_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
image_size=None,
scaling_factor=None,
torch_dtype=None,
model_type=None,
):
# import here to avoid circular imports
from ..models import AutoencoderKL
elif (
is_open_clip_sdxl_model(checkpoint)
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
):
prefix = "conditioner.embedders.1.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
model_type = infer_model_type(original_config, checkpoint, model_type)
elif is_open_clip_sdxl_refiner_model(checkpoint):
prefix = "conditioner.embedders.0.model."
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
if model_type == "Playground":
edm_mean = (
checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
)
edm_std = (
checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
)
else:
edm_mean = None
edm_std = None
vae_config = create_vae_diffusers_config(
original_config,
image_size=image_size,
scaling_factor=scaling_factor,
latents_mean=edm_mean,
latents_std=edm_std,
)
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
vae = AutoencoderKL(**vae_config)
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
if vae._keys_to_ignore_on_load_unexpected is not None:
for pat in vae._keys_to_ignore_on_load_unexpected:
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
vae.load_state_dict(diffusers_format_vae_checkpoint)
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
vae = vae.to(torch_dtype)
model.to(torch_dtype)
return {"vae": vae}
model.eval()
return model
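# Usage sketch, assuming a checkpoint whose text encoder keys match one of the CLIP
# variants detected above (the local path is a placeholder):
#
#   >>> import torch
#   >>> from transformers import CLIPTextModel
#   >>> checkpoint = load_single_file_checkpoint("./v1-5-pruned-emaonly.safetensors")
#   >>> text_encoder = create_diffusers_clip_model_from_ldm(
#   ...     CLIPTextModel, checkpoint, torch_dtype=torch.float16
#   ... )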
def create_text_encoders_and_tokenizers_from_ldm(
original_config,
def _legacy_load_scheduler(
cls,
checkpoint,
model_type=None,
local_files_only=False,
torch_dtype=None,
component_name,
original_config=None,
**kwargs,
):
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}
try:
text_encoder = create_text_encoder_from_open_clip_checkpoint(
config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
)
tokenizer = CLIPTokenizer.from_pretrained(
config_name, subfolder="tokenizer", local_files_only=local_files_only
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'."
)
else:
return {"text_encoder": text_encoder, "tokenizer": tokenizer}
elif model_type == "FrozenCLIPEmbedder":
try:
config_name = "openai/clip-vit-large-patch14"
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
config_name,
checkpoint,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
)
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
)
else:
return {"text_encoder": text_encoder, "tokenizer": tokenizer}
elif model_type == "SDXL-Refiner":
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
prefix = "conditioner.embedders.0.model."
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
config_name,
checkpoint,
prefix=prefix,
has_projection=True,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
**config_kwargs,
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
)
scheduler_type = kwargs.get("scheduler_type", None)
prediction_type = kwargs.get("prediction_type", None)
else:
return {
"text_encoder": None,
"tokenizer": None,
"tokenizer_2": tokenizer_2,
"text_encoder_2": text_encoder_2,
}
elif model_type in ["SDXL", "Playground"]:
try:
config_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)
try:
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
prefix = "conditioner.embedders.1.model."
tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
config_name,
checkpoint,
prefix=prefix,
has_projection=True,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
**config_kwargs,
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
)
return {
"tokenizer": tokenizer,
"text_encoder": text_encoder,
"tokenizer_2": tokenizer_2,
"text_encoder_2": text_encoder_2,
}
return
if scheduler_type is not None:
deprecation_message = (
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
)
deprecate("scheduler_type", "1.0.0", deprecation_message)
if prediction_type is not None:
deprecation_message = (
"Please configure an instance of a Scheduler with the appropriate `prediction_type` "
"and pass the object directly to the `scheduler` argument in `from_single_file`."
)
deprecate("prediction_type", "1.0.0", deprecation_message)
def create_scheduler_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
prediction_type=None,
scheduler_type="ddim",
model_type=None,
):
scheduler_config = get_default_scheduler_config()
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
scheduler_config = SCHEDULER_DEFAULT_CONFIG
model_type = infer_diffusers_model_type(checkpoint=checkpoint)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
if original_config:
num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
else:
num_train_timesteps = 1000
scheduler_config["num_train_timesteps"] = num_train_timesteps
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
if model_type == "v2":
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
# as it relies on a brittle global step parameter here
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
else:
......@@ -1537,20 +1434,44 @@ def create_scheduler_from_ldm(
scheduler_config["prediction_type"] = prediction_type
if model_type in ["SDXL", "SDXL-Refiner"]:
if model_type in ["xl_base", "xl_refiner"]:
scheduler_type = "euler"
elif model_type == "Playground":
elif model_type == "playground":
scheduler_type = "edm_dpm_solver_multistep"
else:
beta_start = original_config["model"]["params"].get("linear_start", 0.02)
beta_end = original_config["model"]["params"].get("linear_end", 0.085)
if original_config:
beta_start = original_config["model"]["params"].get("linear_start")
beta_end = original_config["model"]["params"].get("linear_end")
else:
beta_start = 0.02
beta_end = 0.085
scheduler_config["beta_start"] = beta_start
scheduler_config["beta_end"] = beta_end
scheduler_config["beta_schedule"] = "scaled_linear"
scheduler_config["clip_sample"] = False
scheduler_config["set_alpha_to_one"] = False
if scheduler_type == "pndm":
# to deal with an edge case: the StableDiffusionUpscalePipeline has two schedulers
if component_name == "low_res_scheduler":
return cls.from_config(
{
"beta_end": 0.02,
"beta_schedule": "scaled_linear",
"beta_start": 0.0001,
"clip_sample": True,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"trained_betas": None,
"variance_type": "fixed_small",
}
)
if scheduler_type is None:
return cls.from_config(scheduler_config)
elif scheduler_type == "pndm":
scheduler_config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(scheduler_config)
......@@ -1595,15 +1516,46 @@ def create_scheduler_from_ldm(
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
if pipeline_class_name == "StableDiffusionUpscalePipeline":
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
low_res_scheduler = DDPMScheduler.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
)
return scheduler
return {
"scheduler": scheduler,
"low_res_scheduler": low_res_scheduler,
}
return {"scheduler": scheduler}
def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
if config:
config = {"pretrained_model_name_or_path": config}
else:
config = fetch_diffusers_config(checkpoint)
if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
clip_config = "openai/clip-vit-large-patch14"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
elif is_open_clip_model(checkpoint):
clip_config = "stabilityai/stable-diffusion-2"
config["pretrained_model_name_or_path"] = clip_config
subfolder = "tokenizer"
else:
clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config["pretrained_model_name_or_path"] = clip_config
subfolder = ""
tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
return tokenizer
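# Usage sketch (reusing a `checkpoint` dict loaded elsewhere): the tokenizer repo is
# picked with the same key-based detection used for the text encoders above.
#
#   >>> from transformers import CLIPTokenizer
#   >>> tokenizer = _legacy_load_clip_tokenizer(CLIPTokenizer, checkpoint)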
def _legacy_load_safety_checker(local_files_only, torch_dtype):
# Support for loading safety checker components using the deprecated
# `load_safety_checker` argument.
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
feature_extractor = AutoImageProcessor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
)
return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
......@@ -44,11 +44,6 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers
......@@ -1059,103 +1054,3 @@ class UNet2DConditionLoadersMixin:
}
)
return lora_dicts
class FromOriginalUNetMixin:
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config (`dict`, *optional*):
Dictionary containing the configuration of the model.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
of Diffusers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables of the model.
"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
return model
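
With this mixin removed, `StableCascadeUNet` loads through the generic `FromOriginalModelMixin` instead (see the Stable Cascade hunk below). A minimal sketch of the equivalent call; the checkpoint URL is a placeholder, not taken from this diff:

```python
import torch
from diffusers import StableCascadeUNet

# Placeholder URL: point this at a real Stable Cascade single-file checkpoint.
unet = StableCascadeUNet.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/<stable_cascade_unet>.safetensors",
    torch_dtype=torch.float16,
)
```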
......@@ -17,7 +17,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
......@@ -32,7 +32,7 @@ from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
......
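
The VAE keeps the same public entry point after the mixin swap; a quick sketch using the checkpoint URL exercised by the tests further down:

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
)
```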
......@@ -19,7 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..loaders.single_file_model import FromOriginalModelMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
......@@ -108,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.
......
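
Likewise for ControlNet; a sketch mirroring the tests below:

```python
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
```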
......@@ -963,6 +963,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
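
A toy demonstration of the signature-splitting logic added above, run on a dummy class rather than a real diffusers model:

```python
import inspect

class Toy:
    def __init__(self, unet, vae, scheduler=None):
        pass

params = inspect.signature(Toy.__init__).parameters
required = {k for k, v in params.items() if v.default is inspect._empty} - {"self"}
optional = {k for k, v in params.items() if v.default is not inspect._empty}
assert required == {"unet", "vae"}
assert optional == {"scheduler"}
```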
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
......
......@@ -20,6 +20,7 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
......@@ -66,7 +67,9 @@ class UNet2DConditionOutput(BaseOutput):
sample: torch.FloatTensor = None
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
class UNet2DConditionModel(
ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
......
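
Gaining `FromOriginalModelMixin` means the UNet itself can now be pulled out of a full single-file checkpoint; a sketch, assuming the mixin resolves the UNet sub-component from a pipeline-level checkpoint:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
)
```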
......@@ -21,7 +21,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.unet import FromOriginalUNetMixin
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin
......@@ -134,7 +134,7 @@ class StableCascadeUNetOutput(BaseOutput):
sample: torch.FloatTensor = None
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
@register_to_config
......
......@@ -609,6 +609,7 @@ def load_sub_model(
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name,
class_name,
......
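
`get_class_obj_and_candidates` resolves the class to instantiate from the `(library_name, class_name)` pair. A deliberately simplified, hypothetical version of that resolution step:

```python
import importlib

def resolve_class(library_name: str, class_name: str):
    # e.g. resolve_class("diffusers", "UNet2DConditionModel")
    module = importlib.import_module(library_name)
    return getattr(module, class_name)
```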
......@@ -791,62 +791,6 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
def test_stable_diffusion_model_local(self):
model_id = "stabilityai/sd-vae-ft-mse"
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
image = self.get_sd_image(33)
with torch.no_grad():
sample_1 = model_1(image).sample
sample_2 = model_2(image).sample
assert sample_1.shape == sample_2.shape
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
def test_single_file_component_configs(self):
vae_single_file = AutoencoderKL.from_single_file(
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
for param_name, param_value in vae_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_arguments(self):
vae_default = AutoencoderKL.from_single_file(
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
)
assert vae_default.config.scaling_factor == 0.18215
assert vae_default.config.sample_size == 512
assert vae_default.dtype == torch.float32
scaling_factor = 2.0
image_size = 256
torch_dtype = torch.float16
vae = AutoencoderKL.from_single_file(
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
image_size=image_size,
scaling_factor=scaling_factor,
torch_dtype=torch_dtype,
)
assert vae.config.scaling_factor == scaling_factor
assert vae.config.sample_size == image_size
assert vae.dtype == torch_dtype
@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
......
......@@ -56,7 +56,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in single_file_unet_config.items():
if param_name in PARAMS_TO_IGNORE:
continue
......@@ -78,7 +78,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in single_file_unet_config.items():
if param_name in PARAMS_TO_IGNORE:
continue
......@@ -97,7 +97,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in config.items():
if param_name in PARAMS_TO_IGNORE:
continue
......
......@@ -38,7 +38,6 @@ from diffusers.utils.testing_utils import (
get_python_version,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_python39_or_higher,
require_torch_2,
require_torch_gpu,
......@@ -1063,97 +1062,6 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_sf = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
scheduler_type="pndm",
)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
prompt = "bird"
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt,
image=control_image,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
generator = torch.Generator(device="cpu").manual_seed(0)
output_sf = pipe_sf(
prompt,
image=control_image,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
def test_single_file_component_configs(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", variant="fp16", safety_checker=None, controlnet=controlnet
)
controlnet_single_file = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
single_file_pipe = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet_single_file,
scheduler_type="pndm",
)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.controlnet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
# This parameter doesn't appear to be loaded from the config during single file loading,
# so when it is registered to the config it remains a tuple, the default in the class definition.
# `from_pretrained` does load it from the config and converts it to a list when registering it.
if param_name == "conditioning_embedding_out_channels" and isinstance(param_value, tuple):
param_value = list(param_value)
assert (
pipe.controlnet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
@slow
@require_torch_gpu
......
......@@ -39,7 +39,6 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
......@@ -441,56 +440,3 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
)
assert np.abs(expected_image - image).max() < 9e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
scheduler_type="pndm",
)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
prompt = "bird"
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
generator = torch.Generator(device="cpu").manual_seed(0)
output_sf = pipe_sf(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
......@@ -556,55 +556,3 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
)
assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
).resize((512, 512))
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
mask_image=mask_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten())
assert max_diff < 1e-3