Unverified Commit 3eb498e7 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] add: controlnet support for SDXL (#4038)

* add: controlnet sdxl.

* modifications to controlnet.

* run styling.

* add: __init__.pys

* incorporate https://github.com/huggingface/diffusers/pull/4019

 changes.

* run make fix-copies.

* resize the conditioning images.

* remove autocast.

* run styling.

* disable autocast.

* debugging

* device placement.

* back to autocast.

* remove comment.

* save some memory by reusing the vae and unet in the pipeline.

* apply styling.

* Allow low precision sd xl

* finish

* finish

* changes to accommodate the improved VAE.

* modifications to how we handle vae encoding in the training.

* make style

* make existing controlnet fast tests pass.

* change vae checkpoint cli arg.

* fix: vae pretrained paths.

* fix: steps in get_scheduler().

* debugging.

* debugging./

* fix: weight conversion.

* add: docs.

* add: limited tests./

* add: datasets to the requirements.

* update docstrings and incorporate the usage of watermarking.

* incorporate fix from #4083

* fix watermarking dependency handling.

* run make-fix-copies.

* Empty-Commit

* Update requirements_sdxl.txt

* remove vae upcasting part.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* run make style

* run make fix-copies.

* disable suppot for multicontrolnet.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* run make fix-copies.

* dtyle/.

* fix-copies.

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent c6e56e92
...@@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained( ...@@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
# speed up diffusion process with faster scheduler and memory optimization # speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed # remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
control_image = load_image("./conditioning_image_1.png") control_image = load_image("./conditioning_image_1.png")
...@@ -287,7 +287,6 @@ generator = torch.manual_seed(0) ...@@ -287,7 +287,6 @@ generator = torch.manual_seed(0)
image = pipe( image = pipe(
prompt, num_inference_steps=20, generator=generator, image=control_image prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0] ).images[0]
image.save("./output.png") image.save("./output.png")
``` ```
...@@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile ...@@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`). Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident). Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
## Support for Stable Diffusion XL
We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.
# DreamBooth training example for Stable Diffusion XL (SDXL)
The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/controlnet` folder and run
```bash
pip install -r requirements_sdxl.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
## Circle filling dataset
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
## Training
Our training examples use two test conditioning images. They can be downloaded by running
```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export OUTPUT_DIR="path to save model"
accelerate launch train_controlnet_sdxl.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--mixed_precision="fp16" \
--resolution=1024 \
--learning_rate=1e-5 \
--max_train_steps=15000 \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=100 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42 \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
Our experiments were conducted on a single 40GB A100 GPU.
### Inference
Once training is done, we can perform inference like so:
```python
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
base_model_path = "stabilityai/stable-diffusion-xl-base-0.9"
controlnet_path = "path to controlnet"
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()
control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"
# generate image
generator = torch.manual_seed(0)
image = pipe(
prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]
image.save("./output.png")
```
## Notes
### Specifying a better VAE
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
\ No newline at end of file
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
wandb
This diff is collapsed.
...@@ -199,6 +199,7 @@ except OptionalDependencyNotAvailable: ...@@ -199,6 +199,7 @@ except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else: else:
from .pipelines import ( from .pipelines import (
StableDiffusionXLControlNetPipeline,
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline, StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
......
...@@ -21,7 +21,7 @@ from torch.nn import functional as F ...@@ -21,7 +21,7 @@ from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .unet_2d_blocks import ( from .unet_2d_blocks import (
CrossAttnDownBlock2D, CrossAttnDownBlock2D,
...@@ -131,12 +131,25 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -131,12 +131,25 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The epsilon to use for the normalization. The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280): cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features. The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads. The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`): use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`): class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0): num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.
...@@ -177,10 +190,15 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -177,10 +190,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
norm_num_groups: Optional[int] = 32, norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
cross_attention_dim: int = 1280, cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8, attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None, num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False, use_linear_projection: bool = False,
class_embed_type: Optional[str] = None, class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None, num_class_embeds: Optional[int] = None,
upcast_attention: bool = False, upcast_attention: bool = False,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
...@@ -188,6 +206,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -188,6 +206,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb", controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False, global_pool_conditions: bool = False,
addition_embed_type_num_heads=64,
): ):
super().__init__() super().__init__()
...@@ -215,6 +234,9 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -215,6 +234,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
) )
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input # input
conv_in_kernel = 3 conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2 conv_in_padding = (conv_in_kernel - 1) // 2
...@@ -224,16 +246,43 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -224,16 +246,43 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# time # time
time_embed_dim = block_out_channels[0] * 4 time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0] timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding( self.time_embedding = TimestepEmbedding(
timestep_input_dim, timestep_input_dim,
time_embed_dim, time_embed_dim,
act_fn=act_fn, act_fn=act_fn,
) )
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding # class embedding
if class_embed_type is None and num_class_embeds is not None: if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
...@@ -257,6 +306,29 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -257,6 +306,29 @@ class ControlNetModel(ModelMixin, ConfigMixin):
else: else:
self.class_embedding = None self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding # control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding( self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0], conditioning_embedding_channels=block_out_channels[0],
...@@ -291,6 +363,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -291,6 +363,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block = get_down_block( down_block = get_down_block(
down_block_type, down_block_type,
num_layers=layers_per_block, num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel, in_channels=input_channel,
out_channels=output_channel, out_channels=output_channel,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
...@@ -327,6 +400,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -327,6 +400,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
self.controlnet_mid_block = controlnet_block self.controlnet_mid_block = controlnet_block
self.mid_block = UNetMidBlock2DCrossAttn( self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel, in_channels=mid_block_channel,
temb_channels=time_embed_dim, temb_channels=time_embed_dim,
resnet_eps=norm_eps, resnet_eps=norm_eps,
...@@ -356,7 +430,22 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -356,7 +430,22 @@ class ControlNetModel(ModelMixin, ConfigMixin):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable. where applicable.
""" """
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
controlnet = cls( controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels, in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos, flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift, freq_shift=unet.config.freq_shift,
...@@ -542,6 +631,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -542,6 +631,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_labels: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False, guess_mode: bool = False,
return_dict: bool = True, return_dict: bool = True,
...@@ -564,7 +654,9 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -564,7 +654,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`. A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`): guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
...@@ -618,6 +710,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -618,6 +710,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=sample.dtype) t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond) emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None: if self.class_embedding is not None:
if class_labels is None: if class_labels is None:
...@@ -629,6 +722,30 @@ class ControlNetModel(ModelMixin, ConfigMixin): ...@@ -629,6 +722,30 @@ class ControlNetModel(ModelMixin, ConfigMixin):
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb emb = emb + class_emb
if "addition_embed_type" in self.config:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
......
...@@ -120,6 +120,7 @@ try: ...@@ -120,6 +120,7 @@ try:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
else: else:
from .controlnet import StableDiffusionXLControlNetPipeline
from .stable_diffusion_xl import ( from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline, StableDiffusionXLInpaintPipeline,
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
is_flax_available, is_flax_available,
is_invisible_watermark_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
) )
if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
...@@ -45,17 +45,17 @@ class MultiControlNetModel(ModelMixin): ...@@ -45,17 +45,17 @@ class MultiControlNetModel(ModelMixin):
) -> Union[ControlNetOutput, Tuple]: ) -> Union[ControlNetOutput, Tuple]:
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
down_samples, mid_sample = controlnet( down_samples, mid_sample = controlnet(
sample, sample=sample,
timestep, timestep=timestep,
encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
image, controlnet_cond=image,
scale, conditioning_scale=scale,
class_labels, class_labels=class_labels,
timestep_cond, timestep_cond=timestep_cond,
attention_mask, attention_mask=attention_mask,
cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
guess_mode, guess_mode=guess_mode,
return_dict, return_dict=return_dict,
) )
# merge samples # merge samples
......
...@@ -2,6 +2,21 @@ ...@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "invisible_watermark"])
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "invisible_watermark"] _backends = ["torch", "transformers", "invisible_watermark"]
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
AutoencoderKL,
ControlNetModel,
EulerDiscreteScheduler,
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
from diffusers.utils import randn_tensor, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
)
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
)
enable_full_determinism()
class ControlNetPipelineSDXLFastTests(
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
torch.manual_seed(0)
controlnet = ControlNetModel(
block_out_channels=(32, 64),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
conditioning_embedding_out_channels=(16, 32),
# SD2-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
torch.manual_seed(0)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SD2-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
components = {
"unet": unet,
"controlnet": controlnet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
controlnet_embedder_scale_factor = 2
image = randn_tensor(
(1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
generator=generator,
device=torch.device(device),
)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
"image": image,
}
return inputs
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment