Unverified Commit 6b1abba1 authored by Patrick von Platen, committed by GitHub

Add controlnet and vae from single file (#4084)



* Add controlnet from single file

* Updates

* make style

* finish

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 470f51cd
@@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
## FromSingleFileMixin
[[autodoc]] loaders.FromSingleFileMixin
## FromOriginalControlnetMixin
[[autodoc]] loaders.FromOriginalControlnetMixin
## FromOriginalVAEMixin
[[autodoc]] loaders.FromOriginalVAEMixin
@@ -6,6 +6,18 @@ The abstract from the paper is:
*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
## Loading from the original format
By default, [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
```py
from diffusers import AutoencoderKL
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"  # can also be a local file
model = AutoencoderKL.from_single_file(url)
```
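If the VAE accompanies a Stable Diffusion 2.x model or SDXL, the docstring of [`FromOriginalVAEMixin.from_single_file`] recommends passing `image_size` and `scaling_factor` explicitly. A minimal sketch, assuming a checkpoint for which the defaults do not apply (the URL above is reused only as a placeholder and the values are illustrative):
```py
from diffusers import AutoencoderKL

# hypothetical override: set these to match the model the VAE was trained with
model = AutoencoderKL.from_single_file(
    "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
    image_size=512,  # 512 for SD v1.x and the SD v2 base model, 768 for SD v2
    scaling_factor=0.18215,  # latent scaling factor used when the model was trained
)
```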
## AutoencoderKL
[[autodoc]] AutoencoderKL
@@ -6,6 +6,21 @@ The abstract from the paper is:
*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
## Loading from the original format
By default, [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"  # can also be a local path
controlnet = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"  # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
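The resulting pipeline can then be used like any other ControlNet pipeline. A minimal usage sketch, assuming a CUDA GPU is available; the prompt and the Canny conditioning image below are only examples:
```py
import torch
from diffusers.utils import load_image

# example Canny edge image used as the ControlNet conditioning input
canny_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

pipe = pipe.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("bird", image=canny_image, generator=generator, num_inference_steps=30).images[0]
image.save("bird_controlnet.png")
```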
## ControlNetModel
[[autodoc]] ControlNetModel
@@ -14,9 +14,12 @@
import os
import warnings
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import requests
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
@@ -42,10 +45,13 @@ from .utils import (
HF_HUB_OFFLINE,
_get_model_file,
deprecate,
is_accelerate_available,
is_omegaconf_available,
is_safetensors_available,
is_transformers_available,
logging,
)
from .utils.import_utils import BACKENDS_MAPPING
if is_safetensors_available():
@@ -54,6 +60,9 @@ if is_safetensors_available():
if is_transformers_available():
from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__)
@@ -1319,8 +1328,8 @@ class FromSingleFileMixin:
@classmethod
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
@@ -1430,6 +1439,7 @@ class FromSingleFileMixin:
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None)
controlnet = kwargs.pop("controlnet", None)
tokenizer = kwargs.pop("tokenizer", None)
torch_dtype = kwargs.pop("torch_dtype", None)
@@ -1446,11 +1456,18 @@ class FromSingleFileMixin:
# TODO: For now we only support stable diffusion
stable_unclip = None
model_type = None
if pipeline_name in [
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
]:
from .models.controlnet import ControlNetModel
from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
# Model type will be inferred from the checkpoint.
if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)):
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
elif "StableDiffusion" in pipeline_name: elif "StableDiffusion" in pipeline_name:
# Model type will be inferred from the checkpoint. # Model type will be inferred from the checkpoint.
pass pass
...@@ -1519,3 +1536,339 @@ class FromSingleFileMixin: ...@@ -1519,3 +1536,339 @@ class FromSingleFileMixin:
pipe.to(torch_dtype=torch_dtype) pipe.to(torch_dtype=torch_dtype)
return pipe return pipe
class FromOriginalVAEMixin:
@classmethod
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained VAE weights saved in the original `.ckpt` or
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
<Tip warning={true}>
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load
a VAE that accompanies a Stable Diffusion model of version 2 or higher, or SDXL.
</Tip>
Examples:
```py
from diffusers import AutoencoderKL
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
"""
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
from .models import AutoencoderKL
# import here to avoid circular dependency
from .pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
)
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
image_size = kwargs.pop("image_size", None)
scaling_factor = kwargs.pop("scaling_factor", None)
kwargs.pop("upcast_attention", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"
if from_safetensors and use_safetensors is False:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
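# load the raw state dict, either from a safetensors file or a regular PyTorch checkpoint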
if from_safetensors:
from safetensors import safe_open
checkpoint = {}
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
if config_file is None:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(config_file)
# default to sd-v1-5
image_size = image_size or 512
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
if scaling_factor is None:
if (
"model" in original_config
and "params" in original_config.model
and "scale_factor" in original_config.model.params
):
vae_scaling_factor = original_config.model.params.scale_factor
else:
vae_scaling_factor = 0.18215 # default SD scaling factor
vae_config["scaling_factor"] = vae_scaling_factor
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
vae = AutoencoderKL(**vae_config)
if is_accelerate_available():
for param_name, param in converted_vae_checkpoint.items():
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
else:
vae.load_state_dict(converted_vae_checkpoint)
if torch_dtype is not None:
vae.to(torch_dtype=torch_dtype)
return vae
class FromOriginalControlnetMixin:
@classmethod
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
Examples:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"  # can also be a local path
controlnet = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"  # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
"""
# import here to avoid circular dependency
from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
num_in_channels = kwargs.pop("num_in_channels", None)
use_linear_projection = kwargs.pop("use_linear_projection", None)
revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", None)
upcast_attention = kwargs.pop("upcast_attention", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"
if from_safetensors and use_safetensors is False:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
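# no config provided: fall back to the reference ControlNet v1.5 config from the original repository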
if config_file is None:
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
config_file = BytesIO(requests.get(config_url).content)
image_size = image_size or 512
controlnet = download_controlnet_from_original_ckpt(
pretrained_model_link_or_path,
original_config_file=config_file,
image_size=image_size,
extract_ema=extract_ema,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
from_safetensors=from_safetensors,
use_linear_projection=use_linear_projection,
)
if torch_dtype is not None:
controlnet.to(torch_dtype=torch_dtype)
return controlnet
@@ -18,6 +18,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, apply_forward_hook
from .attention_processor import AttentionProcessor, AttnProcessor
from .modeling_utils import ModelMixin
@@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -19,6 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlnetMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
@@ -100,7 +101,7 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
"""
A ControlNet model.
@@ -24,7 +24,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -90,7 +90,9 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionControlNetPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -24,7 +24,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -116,7 +116,9 @@ def prepare_image(image):
return image
class StableDiffusionControlNetImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -25,7 +25,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
return mask, masked_image
class StableDiffusionControlNetInpaintPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint(
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint(
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
controlnet = ControlNetModel(**ctrlnet_config)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
@@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint(
skip_extract_state_dict=skip_extract_state_dict,
)
controlnet.load_state_dict(converted_ctrl_checkpoint)
return controlnet
def download_from_original_stable_diffusion_ckpt(
@@ -1181,7 +1181,7 @@ def download_from_original_stable_diffusion_ckpt(
)
if pipeline_class is None:
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
if prediction_type == "v-prediction": if prediction_type == "v-prediction":
prediction_type = "v_prediction" prediction_type = "v_prediction"
...@@ -1288,8 +1288,7 @@ def download_from_original_stable_diffusion_ckpt( ...@@ -1288,8 +1288,7 @@ def download_from_original_stable_diffusion_ckpt(
if controlnet is None: if controlnet is None:
controlnet = "control_stage_config" in original_config.model.params controlnet = "control_stage_config" in original_config.model.params
if controlnet: controlnet = convert_controlnet_checkpoint(
controlnet_model = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
) )
@@ -1400,13 +1399,13 @@ def download_from_original_stable_diffusion_ckpt(
if stable_unclip is None:
if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
@@ -1503,12 +1502,12 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor = None
if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
@@ -1623,7 +1622,7 @@ def download_controlnet_from_original_ckpt(
if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")
controlnet = convert_controlnet_checkpoint(
checkpoint,
original_config,
checkpoint_path,
@@ -1634,4 +1633,4 @@ def download_controlnet_from_original_ckpt(
cross_attention_dim=cross_attention_dim,
)
return controlnet
@@ -199,7 +199,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
torch_dtype=torch_dtype,
revision=revision,
)
model.to(torch_device)
return model
@@ -383,3 +383,22 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
def test_stable_diffusion_model_local(self):
model_id = "stabilityai/sd-vae-ft-mse"
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
image = self.get_sd_image(33)
with torch.no_grad():
sample_1 = model_1(image).sample
sample_2 = model_2(image).sample
assert sample_1.shape == sample_2.shape
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
@@ -752,6 +752,42 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
@slow
@require_torch_gpu
@@ -401,3 +401,49 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
)
assert np.abs(expected_image - image).max() < 9e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
@@ -543,3 +543,54 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
)
assert np.abs(expected_image - image).max() < 9e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
).resize((512, 512))
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
mask_image=mask_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3