Unverified Commit 86ecd4b7 authored by 1lint, committed by GitHub

add from_ckpt method as Mixin (#2318)



* add mixin class for pipeline from original sd ckpt

* Improve

* make style

* merge main into

* Improve more

* fix more

* up

* Apply suggestions from code review

* finish docs

* rename

* make style

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent bdeff4d6
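
For context, the core of this PR is the new `FromCkptMixin.from_ckpt` classmethod shown in the diff below. A minimal usage sketch, based on the example URL used in the new docstring (any single-file `.ckpt`/`.safetensors` checkpoint link should behave the same way):

```py
import torch

from diffusers import StableDiffusionPipeline

# from_ckpt resolves a Hub blob URL to an hf_hub_download call, converts the
# original CompVis-style checkpoint, and returns a ready-to-use pipeline.
pipe = StableDiffusionPipeline.from_ckpt(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse").images[0]
```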
@@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
 ### LoraLoaderMixin
 [[autodoc]] loaders.LoraLoaderMixin
+
+### FromCkptMixin
+
+[[autodoc]] loaders.FromCkptMixin
@@ -308,6 +308,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
 - disable_vae_slicing
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
+- load_textual_inversion

 ## FlaxStableDiffusionControlNetPipeline
 [[autodoc]] FlaxStableDiffusionControlNetPipeline
...
@@ -31,3 +31,6 @@ Available Checkpoints are:
 - disable_attention_slicing
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
+- load_textual_inversion
+- load_lora_weights
+- save_lora_weights
@@ -30,6 +30,10 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
 - disable_attention_slicing
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
+- load_textual_inversion
+- from_ckpt
+- load_lora_weights
+- save_lora_weights

 [[autodoc]] FlaxStableDiffusionImg2ImgPipeline
 - all
...
@@ -31,6 +31,9 @@ Available checkpoints are:
 - disable_attention_slicing
 - enable_xformers_memory_efficient_attention
 - disable_xformers_memory_efficient_attention
+- load_textual_inversion
+- load_lora_weights
+- save_lora_weights

 [[autodoc]] FlaxStableDiffusionInpaintPipeline
 - all
...
@@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png")
 [[autodoc]] StableDiffusionInstructPix2PixPipeline
 - __call__
 - all
+- load_textual_inversion
+- load_lora_weights
+- save_lora_weights
@@ -39,6 +39,10 @@ Available Checkpoints are:
 - disable_xformers_memory_efficient_attention
 - enable_vae_tiling
 - disable_vae_tiling
+- load_textual_inversion
+- from_ckpt
+- load_lora_weights
+- save_lora_weights

 [[autodoc]] FlaxStableDiffusionPipeline
 - all
...
@@ -109,7 +109,6 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
-    from .loaders import TextualInversionLoaderMixin
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
...
@@ -13,9 +13,11 @@
 # limitations under the License.
 import os
 from collections import defaultdict
+from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union

 import torch
+from huggingface_hub import hf_hub_download

 from .models.attention_processor import LoRAAttnProcessor
 from .utils import (
@@ -431,6 +433,7 @@ class TextualInversionLoaderMixin:
         Example:

         To load a textual inversion embedding vector in `diffusers` format:
+
         ```py
         from diffusers import StableDiffusionPipeline
         import torch
@@ -463,6 +466,7 @@
         image = pipe(prompt, num_inference_steps=50).images[0]
         image.save("character.png")
         ```
+
         """
         if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
             raise ValueError(
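
The hunk above shows only the head and tail of the `load_textual_inversion` docstring example; the elided middle presumably loads a base pipeline and then the embedding. A plausible reconstruction (the concept repository and token are illustrative):

```py
from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Load a diffusers-format textual inversion embedding (illustrative repo id).
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

prompt = "A <cat-toy> character standing in a castle"
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("character.png")
```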
@@ -1051,3 +1055,197 @@ class LoraLoaderMixin:
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class FromCkptMixin:
+    """This helper class allows loading an original `.ckpt` Stable Diffusion checkpoint file
+    directly into the respective pipeline class."""
+
+    @classmethod
+    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original `.ckpt`
+        format. The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are
+        deactivated).
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the `.ckpt` file on the Hub. Should be in the format
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
+                    - A path to a *file* containing all pipeline weights.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the
+                dtype will be automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                the cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be
+                any identifier allowed by git.
+            use_safetensors (`bool`, *optional*):
+                If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
+                default), the pipeline will load from `safetensors` weights if they are available *and* if the
+                `safetensors` library is installed. If set to `False`, the pipeline will *not* use `safetensors`.
+            extract_ema (`bool`, *optional*, defaults to `False`):
+                Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
+                weights or not. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality
+                images for inference. Non-EMA weights are usually better for continued fine-tuning.
+            upcast_attention (`bool`, *optional*, defaults to `None`):
+                Whether the attention computation should always be upcasted. This is necessary when running stable
+                diffusion 2.1.
+            image_size (`int`, *optional*, defaults to 512):
+                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable
+                Diffusion v2 Base. Use 768 for Stable Diffusion v2.
+            prediction_type (`str`, *optional*):
+                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and
+                Stable Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
+            num_in_channels (`int`, *optional*, defaults to `None`):
+                The number of input channels. If `None`, it will be automatically inferred.
+            scheduler_type (`str`, *optional*, defaults to `"pndm"`):
+                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral",
+                "dpm", "ddim"]`.
+            load_safety_checker (`bool`, *optional*, defaults to `True`):
+                Whether to load the safety checker or not.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load- and saveable variables (*i.e.*, the pipeline components of the
+                specific pipeline class). The overwritten components are directly passed to the pipeline's
+                `__init__` method. See the example below for more information.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import StableDiffusionPipeline
+        >>> import torch

+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+        ... )

+        >>> # Load pipeline from a local file
+        >>> # (e.g. after the checkpoint was downloaded to ./v1-5-pruned-emaonly.ckpt)
+        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")

+        >>> # Enable float16 and move to GPU
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+        ...     torch_dtype=torch.float16,
+        ... )
+        >>> pipeline.to("cuda")
+        ```
+        """
+        # import here to avoid circular dependency
+        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        extract_ema = kwargs.pop("extract_ema", False)
+        image_size = kwargs.pop("image_size", 512)
+        scheduler_type = kwargs.pop("scheduler_type", "pndm")
+        num_in_channels = kwargs.pop("num_in_channels", None)
+        upcast_attention = kwargs.pop("upcast_attention", None)
+        load_safety_checker = kwargs.pop("load_safety_checker", True)
+        prediction_type = kwargs.pop("prediction_type", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+        pipeline_name = cls.__name__
+        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+        from_safetensors = file_extension == "safetensors"
+
+        if from_safetensors and use_safetensors is False:
+            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+        # TODO: For now we only support stable diffusion
+        stable_unclip = None
+        controlnet = False
+
+        if pipeline_name == "StableDiffusionControlNetPipeline":
+            model_type = "FrozenCLIPEmbedder"
+            controlnet = True
+        elif "StableDiffusion" in pipeline_name:
+            model_type = "FrozenCLIPEmbedder"
+        elif pipeline_name == "StableUnCLIPPipeline":
+            model_type = "FrozenOpenCLIPEmbedder"
+            stable_unclip = "txt2img"
+        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
+            model_type = "FrozenOpenCLIPEmbedder"
+            stable_unclip = "img2img"
+        elif pipeline_name == "PaintByExamplePipeline":
+            model_type = "PaintByExample"
+        elif pipeline_name == "LDMTextToImagePipeline":
+            model_type = "LDMTextToImage"
+        else:
+            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
+
+        # remove huggingface url
+        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+            if pretrained_model_link_or_path.startswith(prefix):
+                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+        ckpt_path = Path(pretrained_model_link_or_path)
+        if not ckpt_path.is_file():
+            # get repo_id and (potentially nested) file path of ckpt in repo
+            repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
+            file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
+
+            if file_path.startswith("blob/"):
+                file_path = file_path[len("blob/") :]
+
+            if file_path.startswith("main/"):
+                file_path = file_path[len("main/") :]
+
+            pretrained_model_link_or_path = hf_hub_download(
+                repo_id,
+                filename=file_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+            )
+
+        pipe = download_from_original_stable_diffusion_ckpt(
+            pretrained_model_link_or_path,
+            pipeline_class=cls,
+            model_type=model_type,
+            stable_unclip=stable_unclip,
+            controlnet=controlnet,
+            from_safetensors=from_safetensors,
+            extract_ema=extract_ema,
+            image_size=image_size,
+            scheduler_type=scheduler_type,
+            num_in_channels=num_in_channels,
+            upcast_attention=upcast_attention,
+            load_safety_checker=load_safety_checker,
+            prediction_type=prediction_type,
+        )
+
+        if torch_dtype is not None:
+            pipe.to(torch_dtype=torch_dtype)
+
+        return pipe
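
The Hub-link handling inside `from_ckpt` is compact, so here is the same resolution logic pulled out into a standalone sketch (`resolve_hub_link` is a hypothetical helper, for illustration only):

```py
from pathlib import Path


def resolve_hub_link(link):
    """Split a Hub blob link into (repo_id, filename), mirroring from_ckpt's parsing."""
    # Strip the host prefix, exactly as from_ckpt does.
    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
        if link.startswith(prefix):
            link = link[len(prefix) :]

    parts = Path(link).parts
    repo_id = "/".join(parts[:2])    # e.g. "runwayml/stable-diffusion-v1-5"
    file_path = "/".join(parts[2:])  # e.g. "blob/main/v1-5-pruned-emaonly.ckpt"
    for marker in ("blob/", "main/"):
        if file_path.startswith(marker):
            file_path = file_path[len(marker) :]
    return repo_id, file_path


print(resolve_hub_link("https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"))
# -> ('runwayml/stable-diffusion-v1-5', 'v1-5-pruned-emaonly.ckpt')
```

The resolved pair is then fed to `hf_hub_download(repo_id, filename=file_path, ...)`, which yields a local file path for the conversion step.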
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -31,35 +31,30 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )

-from diffusers import (
+from ...models import (
     AutoencoderKL,
     ControlNetModel,
+    PriorTransformer,
+    UNet2DConditionModel,
+)
+from ...schedulers import (
     DDIMScheduler,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     HeunDiscreteScheduler,
-    LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
-    PriorTransformer,
-    StableDiffusionControlNetPipeline,
-    StableDiffusionImg2ImgPipeline,
-    StableDiffusionInpaintPipeline,
-    StableDiffusionPipeline,
-    StableUnCLIPImg2ImgPipeline,
-    StableUnCLIPPipeline,
     UnCLIPScheduler,
-    UNet2DConditionModel,
 )
-from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 from ...utils import is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
+from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from ..paint_by_example import PaintByExampleImageEncoder
+from ..pipeline_utils import DiffusionPipeline
+from .safety_checker import StableDiffusionSafetyChecker
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -981,7 +976,6 @@ def download_from_original_stable_diffusion_ckpt(
     image_size: int = 512,
     prediction_type: str = None,
     model_type: str = None,
-    is_img2img: bool = False,
     extract_ema: bool = False,
     scheduler_type: str = "pndm",
     num_in_channels: Optional[int] = None,
@@ -993,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
     clip_stats_path: Optional[str] = None,
     controlnet: Optional[bool] = None,
     load_safety_checker: bool = True,
-) -> StableDiffusionPipeline:
+    pipeline_class: DiffusionPipeline = None,
+) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
     config file.
@@ -1031,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
             Whether the attention computation should always be upcasted. This is necessary when running stable
             diffusion 2.1.
         device (`str`, *optional*, defaults to `None`):
-            The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
-            in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
-            StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`bool`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
         load_safety_checker (`bool`, *optional*, defaults to `True`):
             Whether to load the safety checker or not. Defaults to `True`.
+        pipeline_class (`DiffusionPipeline`, *optional*, defaults to `None`):
+            The pipeline class to use. Pass `None` to determine automatically.
+
+    return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
+    # import pipelines here to avoid circular import error when using from_ckpt method
+    from diffusers import (
+        LDMTextToImagePipeline,
+        PaintByExamplePipeline,
+        StableDiffusionControlNetPipeline,
+        StableDiffusionPipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
+    )
+
+    if pipeline_class is None:
+        pipeline_class = StableDiffusionPipeline
+
     if prediction_type == "v-prediction":
         prediction_type = "v_prediction"
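
With `pipeline_class` replacing the removed `is_img2img` flag, callers now select the pipeline variant explicitly instead of toggling a boolean. A sketch of a direct call under that assumption (the local checkpoint path is illustrative):

```py
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Previously this required is_img2img=True; now the class is passed directly.
pipe = download_from_original_stable_diffusion_ckpt(
    "./v1-5-pruned-emaonly.ckpt",
    pipeline_class=StableDiffusionImg2ImgPipeline,
)
```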
@@ -1198,35 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
             requires_safety_checker=False,
         )
     else:
-        if (
-            hasattr(original_config, "model")
-            and hasattr(original_config.model, "target")
-            and "LatentInpaintDiffusion" in original_config.model.target
-        ):
-            pipe = StableDiffusionInpaintPipeline(
-                vae=vae,
-                text_encoder=text_model,
-                tokenizer=tokenizer,
-                unet=unet,
-                scheduler=scheduler,
-                safety_checker=None,
-                feature_extractor=None,
-                requires_safety_checker=False,
-            )
-        else:
-            if is_img2img:
-                pipe = StableDiffusionImg2ImgPipeline(
-                    vae=vae,
-                    text_encoder=text_model,
-                    tokenizer=tokenizer,
-                    unet=unet,
-                    scheduler=scheduler,
-                    safety_checker=None,
-                    feature_extractor=None,
-                    requires_safety_checker=False,
-                )
-            else:
-                pipe = StableDiffusionPipeline(
+        pipe = pipeline_class(
             vae=vae,
             text_encoder=text_model,
             tokenizer=tokenizer,
...
@@ -1326,33 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     else:
-        if (
-            hasattr(original_config, "model")
-            and hasattr(original_config.model, "target")
-            and "LatentInpaintDiffusion" in original_config.model.target
-        ):
-            pipe = StableDiffusionInpaintPipeline(
-                vae=vae,
-                text_encoder=text_model,
-                tokenizer=tokenizer,
-                unet=unet,
-                scheduler=scheduler,
-                safety_checker=safety_checker,
-                feature_extractor=feature_extractor,
-            )
-        else:
-            if is_img2img:
-                pipe = StableDiffusionImg2ImgPipeline(
-                    vae=vae,
-                    text_encoder=text_model,
-                    tokenizer=tokenizer,
-                    unet=unet,
-                    scheduler=scheduler,
-                    safety_checker=safety_checker,
-                    feature_extractor=feature_extractor,
-                )
-            else:
-                pipe = StableDiffusionPipeline(
+        pipe = pipeline_class(
             vae=vae,
             text_encoder=text_model,
             tokenizer=tokenizer,
...
@@ -1379,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
-) -> StableDiffusionPipeline:
+) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
...
@@ -20,7 +20,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
 """

-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
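
Since `StableDiffusionPipeline` now mixes in all three loaders, checkpoint loading, textual inversion, and LoRA handling compose on a single object. A sketch of that composition (the embedding repo and LoRA path are illustrative):

```py
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_ckpt(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Methods contributed by TextualInversionLoaderMixin and LoraLoaderMixin.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
pipe.load_lora_weights("path/to/lora_dir")
```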
@@ -156,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -23,7 +23,7 @@ from packaging import version
 from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
     return image

-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
     return image

-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
     return mask, masked_image

-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
     return mask

-class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipelineLegacy(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -20,7 +20,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
     return image

-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition, the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
...
@@ -2,21 +2,6 @@
 from ..utils import DummyObject, requires_backends

-class TextualInversionLoaderMixin(metaclass=DummyObject):
-    _backends = ["torch", "transformers"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch", "transformers"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch", "transformers"])
-
 class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...