Unverified Commit 6ea83608 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Stable Diffusion Inpainting] Deprecate inpainting pipeline in favor of official one (#903)



* finish

* up

* Apply suggestions from code review
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>

* Update src/diffusers/pipeline_utils.py

* Finish
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
parent bd216073
...@@ -52,6 +52,7 @@ if is_torch_available() and is_transformers_available(): ...@@ -52,6 +52,7 @@ if is_torch_available() and is_transformers_available():
LDMTextToImagePipeline, LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline, StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline, StableDiffusionPipeline,
) )
else: else:
......
...@@ -40,6 +40,7 @@ from .utils import ( ...@@ -40,6 +40,7 @@ from .utils import (
ONNX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
BaseOutput, BaseOutput,
deprecate,
is_transformers_available, is_transformers_available,
logging, logging,
) )
...@@ -413,6 +414,25 @@ class DiffusionPipeline(ConfigMixin): ...@@ -413,6 +414,25 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config_dict["_diffusers_version"]).base_version
) <= version.parse("0.5.1"):
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
pipeline_class = StableDiffusionInpaintPipelineLegacy
deprecation_message = (
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
# some modules can be passed directly to the init # some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs` # in this case they are already instantiated in `kwargs`
# extract them here # extract them here
......
...@@ -16,6 +16,7 @@ if is_torch_available() and is_transformers_available(): ...@@ -16,6 +16,7 @@ if is_torch_available() and is_transformers_available():
from .stable_diffusion import ( from .stable_diffusion import (
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline, StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline, StableDiffusionPipeline,
) )
......
...@@ -31,6 +31,7 @@ if is_transformers_available() and is_torch_available(): ...@@ -31,6 +31,7 @@ if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_onnx_available(): if is_transformers_available() and is_onnx_available():
......
...@@ -82,8 +82,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -82,8 +82,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = ( deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
...@@ -223,6 +221,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -223,6 +221,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
# TODO(Suraj) - adapt to your use case
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif isinstance(prompt, list):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment