Unverified commit 886575ee authored by Patrick von Platen, committed by GitHub

Refactor controlnet and add img2img and inpaint (#3386)

* refactor controlnet and add img2img and inpaint

* First draft to get pipelines to work

* make style

* Fix more

* Fix more

* More tests

* Fix more

* Make inpainting work

* make style and more tests

* Apply suggestions from code review

* up

* make style

* Fix imports

* Fix more

* Fix more

* Improve examples

* add test

* Make sure import is correctly deprecated

* Make sure everything works in compile mode

* make sure authorship is correctly attributed
parent 9d44e2fb
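
In short, this PR moves the ControlNet pipelines into a dedicated `diffusers.pipelines.controlnet` subpackage and adds two new variants: `StableDiffusionControlNetImg2ImgPipeline` and `StableDiffusionControlNetInpaintPipeline`. For orientation before the diff, a minimal sketch of the new img2img variant; the model IDs and file paths are illustrative, not part of this commit:

```python
import torch

from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from diffusers.utils import load_image

# Illustrative checkpoints: any SD 1.x base model plus a matching ControlNet.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

init_image = load_image("input.png")     # image to transform (the img2img input)
control_image = load_image("edges.png")  # ControlNet conditioning, e.g. Canny edges

image = pipe(
    "a futuristic city at night",
    image=init_image,
    control_image=control_image,
    strength=0.8,  # how far the result may deviate from init_image
).images[0]
image.save("out.png")
```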
@@ -148,6 +148,8 @@
       title: Audio Diffusion
     - local: api/pipelines/audioldm
       title: AudioLDM
+    - local: api/pipelines/controlnet
+      title: ControlNet
     - local: api/pipelines/cycle_diffusion
       title: Cycle Diffusion
     - local: api/pipelines/dance_diffusion
@@ -203,8 +205,6 @@
       title: Self-Attention Guidance
     - local: api/pipelines/stable_diffusion/panorama
       title: MultiDiffusion Panorama
-    - local: api/pipelines/stable_diffusion/controlnet
-      title: Text-to-Image Generation with ControlNet Conditioning
     - local: api/pipelines/stable_diffusion/model_editing
       title: Text-to-Image Model Editing
     - local: api/pipelines/stable_diffusion/diffedit
...
@@ -46,7 +46,7 @@ available a colab notebook to directly try them out.
 |---|---|:---:|:---:|
 | [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
 | [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
-| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
+| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
 | [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
...
@@ -53,7 +53,7 @@ The library has three main components:
 |---|---|:---:|
 | [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
 | [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |
-| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
+| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation |
 | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
 | [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
 | [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
...
@@ -132,6 +132,8 @@ else:
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
+        StableDiffusionControlNetImg2ImgPipeline,
+        StableDiffusionControlNetInpaintPipeline,
         StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
...
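
Both new pipelines are re-exported from the top-level `diffusers` namespace above, so downstream code never needs to reach into subpackages:

```python
# All three ControlNet pipelines resolve directly from the package root.
from diffusers import (
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionControlNetPipeline,
)
```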
@@ -17,3 +17,13 @@
 # It only exists so that temporarily `from diffusers.pipeline_utils import DiffusionPipeline` works
 from .pipelines import DiffusionPipeline, ImagePipelineOutput  # noqa: F401
+from .utils import deprecate
+
+
+deprecate(
+    "pipelines_utils",
+    "0.22.0",
+    "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
+    standard_warn=False,
+    stacklevel=3,
+)
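
The shim keeps the legacy `diffusers.pipeline_utils` module importable until 0.22.0 while steering users toward the new location. The intended migration, sketched:

```python
# Deprecated: still works until 0.22.0, but now emits a deprecation warning
# at import time via the `deprecate` call above.
from diffusers.pipeline_utils import DiffusionPipeline

# Preferred going forward:
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
```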
@@ -44,6 +44,11 @@ except OptionalDependencyNotAvailable:
 else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .audioldm import AudioLDMPipeline
+    from .controlnet import (
+        StableDiffusionControlNetImg2ImgPipeline,
+        StableDiffusionControlNetInpaintPipeline,
+        StableDiffusionControlNetPipeline,
+    )
     from .deepfloyd_if import (
         IFImg2ImgPipeline,
         IFImg2ImgSuperResolutionPipeline,
@@ -58,7 +63,6 @@ else:
     from .stable_diffusion import (
         CycleDiffusionPipeline,
         StableDiffusionAttendAndExcitePipeline,
-        StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
         StableDiffusionImageVariationPipeline,
@@ -133,8 +137,8 @@ try:
 except OptionalDependencyNotAvailable:
     from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
 else:
+    from .controlnet import FlaxStableDiffusionControlNetPipeline
     from .stable_diffusion import (
-        FlaxStableDiffusionControlNetPipeline,
         FlaxStableDiffusionImg2ImgPipeline,
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
...
from ...utils import (
    OptionalDependencyNotAvailable,
    is_flax_available,
    is_torch_available,
    is_transformers_available,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Fall back to dummy placeholder classes so the import itself never fails.
    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .multicontrolnet import MultiControlNetModel
    from .pipeline_controlnet import StableDiffusionControlNetPipeline
    from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
    from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline


# The Flax pipeline is only exposed when both transformers and flax are present.
if is_transformers_available() and is_flax_available():
    from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
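
The guard above means the subpackage degrades gracefully: missing backends swap in dummy objects instead of breaking the import. A small sketch of how this behaves, assuming a standard torch + transformers install with Flax optional:

```python
import diffusers.pipelines.controlnet as controlnet_pkg

# Available whenever torch and transformers are installed:
print(hasattr(controlnet_pkg, "StableDiffusionControlNetPipeline"))  # True

# Defined only when transformers and flax are both installed; otherwise the
# attribute is simply absent rather than raising at package import time:
print(hasattr(controlnet_pkg, "FlaxStableDiffusionControlNetPipeline"))
```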
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...models.controlnet import ControlNetModel, ControlNetOutput
from ...models.modeling_utils import ModelMixin


class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet.

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the UNet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        # Run each ControlNet on its own conditioning image and scale.
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                return_dict,
            )

            # Merge samples: keep the first net's residuals, then sum the rest in.
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample
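
The wrapper above is the core of multi-ControlNet support: `forward()` runs every net on its own conditioning image and scale, then sums the down-block and mid-block residuals. Users reach it implicitly by passing a list of controlnets to a pipeline; a hedged sketch with illustrative checkpoints and conditioning images:

```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

canny_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pose_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")

# A list of controlnets is wrapped in MultiControlNetModel internally.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=[canny_net, pose_net]
)

canny_image = load_image("canny.png")  # one conditioning image per controlnet,
pose_image = load_image("pose.png")    # zipped with the per-net scales in forward()

image = pipe(
    "a dancer on a rooftop",
    image=[canny_image, pose_image],
    controlnet_conditioning_scale=[1.0, 0.8],
).images[0]
```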
(Two large file diffs are collapsed in this view.)
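
To complement the img2img sketch above, here is a minimal, hedged sketch of the new inpaint variant, `StableDiffusionControlNetInpaintPipeline`; model IDs and file paths are again illustrative:

```python
import torch

from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = load_image("room.png")     # original image
mask = load_image("mask.png")      # white pixels mark the region to repaint
control = load_image("edges.png")  # ControlNet conditioning image

result = pipe(
    "a red sofa in a bright living room",
    image=image,
    mask_image=mask,
    control_image=control,
).images[0]
result.save("inpainted.png")
```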
@@ -8,10 +8,10 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...pipeline_utils import DiffusionPipeline
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
 from . import SemanticStableDiffusionPipelineOutput
...
@@ -45,7 +45,6 @@ else:
     from .pipeline_cycle_diffusion import CycleDiffusionPipeline
     from .pipeline_stable_diffusion import StableDiffusionPipeline
     from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
-    from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline
     from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
     from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
     from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
@@ -130,7 +129,6 @@ if is_transformers_available() and is_flax_available():
     from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
-    from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline
     from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
     from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
...
@@ -212,6 +212,36 @@ class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class StableDiffusionControlNetPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...
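
These dummy definitions mirror the real pipeline classes so that `import diffusers` never hard-fails when PyTorch or transformers are absent; the error is deferred to first use. A rough sketch of the behavior, assuming an environment without those backends:

```python
# Importing still succeeds: the name resolves to the dummy class above.
from diffusers import StableDiffusionControlNetInpaintPipeline

# Any attempt to actually use it (construction, from_config, from_pretrained)
# calls requires_backends, which raises ImportError naming the missing
# backends and how to install them.
pipe = StableDiffusionControlNetInpaintPipeline()  # raises ImportError
```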
@@ -34,7 +34,10 @@ from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import require_torch_gpu
 
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..pipeline_params import (
+    TEXT_TO_IMAGE_BATCH_PARAMS,
+    TEXT_TO_IMAGE_PARAMS,
+)
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -42,7 +45,7 @@ torch.backends.cuda.matmul.allow_tf32 = False
 torch.use_deterministic_algorithms(True)
 
-class StableDiffusionControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -155,6 +158,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
 
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -307,7 +311,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 @slow
 @require_torch_gpu
-class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase):
+class ControlNetPipelineSlowTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         gc.collect()
...