Unverified commit 886575ee, authored by Patrick von Platen, committed by GitHub

Refactor controlnet and add img2img and inpaint (#3386)

* refactor controlnet and add img2img and inpaint

* First draft to get pipelines to work

* make style

* Fix more

* Fix more

* More tests

* Fix more

* Make inpainting work

* make style and more tests

* Apply suggestions from code review

* up

* make style

* Fix imports

* Fix more

* Fix more

* Improve examples

* add test

* Make sure import is correctly deprecated

* Make sure everything works in compile mode

* make sure authorship is correctly attributed
parent 9d44e2fb
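
Besides the module move, the user-visible additions are `StableDiffusionControlNetImg2ImgPipeline` and `StableDiffusionControlNetInpaintPipeline`. A minimal usage sketch of the img2img variant; the checkpoint IDs and image URLs below are placeholders rather than values taken from this diff:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from diffusers.utils import load_image

# Attach a canny-edge ControlNet to the new img2img pipeline.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://example.com/init.png")      # img2img starting point (placeholder URL)
control_image = load_image("https://example.com/canny.png")  # ControlNet conditioning (placeholder URL)

# `image` seeds the latents (img2img); `control_image` drives the ControlNet.
result = pipe(
    "a futuristic city at dusk",
    image=init_image,
    control_image=control_image,
    strength=0.8,
    num_inference_steps=30,
).images[0]

# The inpaint variant additionally takes a `mask_image` alongside `image` and `control_image`.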
@@ -148,6 +148,8 @@
       title: Audio Diffusion
     - local: api/pipelines/audioldm
       title: AudioLDM
+    - local: api/pipelines/controlnet
+      title: ControlNet
     - local: api/pipelines/cycle_diffusion
       title: Cycle Diffusion
     - local: api/pipelines/dance_diffusion
@@ -203,8 +205,6 @@
       title: Self-Attention Guidance
     - local: api/pipelines/stable_diffusion/panorama
       title: MultiDiffusion Panorama
-    - local: api/pipelines/stable_diffusion/controlnet
-      title: Text-to-Image Generation with ControlNet Conditioning
     - local: api/pipelines/stable_diffusion/model_editing
       title: Text-to-Image Model Editing
     - local: api/pipelines/stable_diffusion/diffedit
......
@@ -46,7 +46,7 @@ available a colab notebook to directly try them out.
 |---|---|:---:|:---:|
 | [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
 | [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
-| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
+| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
 | [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
......
@@ -53,7 +53,7 @@ The library has three main components:
 |---|---|:---:|
 | [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
 | [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
-| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
+| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
 | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
......
@@ -132,6 +132,8 @@ else:
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
+        StableDiffusionControlNetImg2ImgPipeline,
+        StableDiffusionControlNetInpaintPipeline,
         StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
......
@@ -17,3 +17,13 @@
 # It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works
 from .pipelines import DiffusionPipeline, ImagePipelineOutput  # noqa: F401
+from .utils import deprecate
+
+
+deprecate(
+    "pipelines_utils",
+    "0.22.0",
+    "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
+    standard_warn=False,
+    stacklevel=3,
+)
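
In practice the shim means old imports keep working until 0.22.0 while emitting a deprecation warning. A sketch of the before/after:

# Deprecated: importing from the legacy top-level module now triggers the
# `deprecate(...)` call above at import time, warning until removal in 0.22.0.
from diffusers.pipeline_utils import DiffusionPipeline  # noqa: F401

# Preferred: the canonical location going forward.
from diffusers.pipelines.pipeline_utils import DiffusionPipeline  # noqa: F811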
@@ -44,6 +44,11 @@ except OptionalDependencyNotAvailable:
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
+    from .controlnet import (
+        StableDiffusionControlNetImg2ImgPipeline,
+        StableDiffusionControlNetInpaintPipeline,
+        StableDiffusionControlNetPipeline,
+    )
     from .deepfloyd_if import (
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
@@ -58,7 +63,6 @@ else:
     from .stable_diffusion import (
         CycleDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
-        StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
         StableDiffusionImageVariationPipeline,
@@ -133,8 +137,8 @@ try:
 except OptionalDependencyNotAvailable:
     from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
 else:
+    from .controlnet import FlaxStableDiffusionControlNetPipeline
     from .stable_diffusion import (
-        FlaxStableDiffusionControlNetPipeline,
         FlaxStableDiffusionImg2ImgPipeline,
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
......
from ...utils import (
    OptionalDependencyNotAvailable,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .multicontrolnet import MultiControlNetModel
    from .pipeline_controlnet import StableDiffusionControlNetPipeline
    from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
    from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline


if is_transformers_available() and is_flax_available():
    from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
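
A quick sketch, assuming only the re-exports visible in this diff, of how the relocation stays invisible to end users: everything continues to resolve from the top-level package.

# PyTorch pipelines re-exported from the new `diffusers.pipelines.controlnet` subpackage:
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionControlNetPipeline,
)

# With Flax and Transformers installed, the Flax pipeline resolves the same way:
# from diffusers import FlaxStableDiffusionControlNetPipeline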
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...models.controlnet import ControlNetModel, ControlNetOutput
from ...models.modeling_utils import ModelMixin


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet.

    This module is a wrapper for multiple instances of `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the UNet during the denoising process. You must pass multiple
            `ControlNetModel` instances as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        # Each ControlNet gets its own conditioning image and scale; arguments are
        # passed positionally to match `ControlNetModel.forward`.
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                return_dict,
            )

            # merge samples: residuals from every ControlNet are summed element-wise
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
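
For reviewers who want to see the wrapper from the user side: passing a list of ControlNets to a pipeline builds a `MultiControlNetModel` internally and sums the residuals as above. A sketch under that assumption; the checkpoint IDs are illustrative, not part of this diff:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Two independently trained ControlNets (illustrative checkpoints).
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

# A list of ControlNets is wrapped in MultiControlNetModel by the pipeline.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=[canny, pose], torch_dtype=torch.float16
)

# One conditioning image and one scale per ControlNet:
# out = pipe(prompt, image=[canny_image, pose_image], controlnet_conditioning_scale=[1.0, 0.8])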
Two file diffs are collapsed in this view.
@@ -8,10 +8,10 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
 from . import SemanticStableDiffusionPipelineOutput
......
@@ -45,7 +45,6 @@ else:
     from .pipeline_cycle_diffusion import CycleDiffusionPipeline
     from .pipeline_stable_diffusion import StableDiffusionPipeline
     from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
-    from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
     from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
@@ -130,7 +129,6 @@ if is_transformers_available() and is_flax_available():
     from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
-    from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline
     from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
     from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -212,6 +212,36 @@ class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class StableDiffusionControlNetPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
......
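
These dummies keep imports working when the optional backends are missing. A sketch of the behavior they produce (the checkpoint name is a placeholder):

# In an environment without torch/transformers, the name still imports fine...
from diffusers import StableDiffusionControlNetInpaintPipeline

try:
    # ...but any use hits `requires_backends`, which raises a readable error.
    StableDiffusionControlNetInpaintPipeline.from_pretrained("some/checkpoint")  # placeholder ID
except ImportError as err:
    print(err)  # names the missing backends and how to install them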
@@ -34,7 +34,10 @@ from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_de
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import require_torch_gpu
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..pipeline_params import (
+    TEXT_TO_IMAGE_BATCH_PARAMS,
+    TEXT_TO_IMAGE_PARAMS,
+)
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -42,7 +45,7 @@ torch.backends.cuda.matmul.allow_tf32 = False
 torch.use_deterministic_algorithms(True)
 
 
-class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -155,6 +158,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: add image_params once VaeImageProcessor.preprocess is refactored
 
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -307,7 +311,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt
 @slow
 @require_torch_gpu
-class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
+class ControlNetPipelineSlowTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         gc.collect()
......