Unverified Commit fda1531d authored by UmerHA's avatar UmerHA Committed by GitHub

Fixing implementation of ControlNet-XS (#6772)



* CheckIn - created DownSubBlocks

* Added extra channels, implemented subblock fwd

* Fixed connection sizes

* checkin

* Removed iter, next in forward

* Models for SD21 & SDXL run through

* Added back pipelines, cleared up connections

* Cleaned up connection creation

* added debug logs

* updated logs

* logs: added input loading

* Update umer_debug_logger.py

* log: Loading hint

* Update umer_debug_logger.py

* added logs

* Changed debug logging

* debug: added more logs

* Fixed num_norm_groups

* Debug: Logging all of SDXL input

* Update umer_debug_logger.py

* debug: updated logs

* checkin

* Readded tests

* Removed debug logs

* Fixed Slow Tests

* Added value checks | Updated model_cpu_offload_seq

* accelerate-offloading works ; fast tests work

* Made unet & addon explicit in controlnet

* Updated slow tests

* Added dtype/device to ControlNetXS

* Filled in test model paths

* Added image_encoder/feature_extractor to XL pipe

* Fixed fast tests

* Added comments and docstrings

* Fixed copies

* Added docs ; Updates slow tests

* Moved changes to UNetMidBlock2DCrossAttn

* tiny cleanups

* Removed stray prints

* Removed ip adapters + freeU

- Removed ip adapters + freeU as they don't make sense for ControlNet-XS
- Fixed imports of UNet components

* Fixed test_save_load_float16

* Make style, quality, fix-copies

* Changed loading/saving API for ControlNetXS

- Changed loading/saving API for ControlNetXS
- other small fixes

* Removed ControlNet-XS from research examples

* Make style, quality, fix-copies

* Small fixes

- deleted ControlNetXSModel.init_original
- added time_embedding_mix to StableDiffusionControlNetXSPipeline.from_pretrained / StableDiffusionXLControlNetXSPipeline.from_pretrained
- fixed copy hints

* checkin May 11 '23

* CheckIn Mar 12 '24

* Fixed tests for SD

* Added tests for UNetControlNetXSModel

* Fixed SDXL tests

* cleanup

* Delete Pipfile

* CheckIn Mar 20

Started replacing sub blocks with `ControlNetXSCrossAttnDownBlock2D` and `ControlNetXSCrossAttnUpBlock2D`

* check-in Mar 23

* checkin 24 Mar

* Created init for UNetCnxs and CnxsAddon

* CheckIn

* Made from_modules, from_unet and no_control work

* make style,quality,fix-copies & small changes

* Fixed freezing

* Added gradient ckpt'ing; fixed tests

* Fix slow tests(+compile) ; clear naming confusion

* Don't create UNet in init ; removed class_emb

* Incorporated review feedback

- Deleted get_base_pipeline /  get_controlnet_addon for pipes
- Pipes inherit from StableDiffusionXLPipeline
- Made module dicts for cnxs-addon's down/mid/up classes
- Added support for qkv fusion and freeU

* Make style, quality, fix-copies

* Implemented review feedback

* Removed compatibility check for vae/ctrl embedding

* make style, quality, fix-copies

* Delete Pipfile

* Integrated review feedback

- Importing ControlNetConditioningEmbedding now
- get_down/mid/up_block_addon now outside class
- renamed `do_control` to `apply_control`

* Reduced size of test tensors

For this, added `norm_num_groups` as parameter everywhere

* Renamed cnxs-`Addon` to cnxs-`Adapter`

- `ControlNetXSAddon` -> `ControlNetXSAdapter`
- `ControlNetXSAddonDownBlockComponents` -> `DownBlockControlNetXSAdapter`, and similarly for mid/up
- `get_mid_block_addon` -> `get_mid_block_adapter`, and similarly for down/up

* Fixed save_pretrained/from_pretrained bug

* Removed redundant code

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent cf6e0407
@@ -282,6 +282,10 @@
title: ControlNet
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnetxs
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet-XS
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/)
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionControlNetXSPipeline
[[autodoc]] StableDiffusionControlNetXSPipeline
- all
- __call__
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/)
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
<Tip warning={true}>
🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
</Tip>
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionXLControlNetXSPipeline
[[autodoc]] StableDiffusionXLControlNetXSPipeline
- all
- __call__
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.normalization import GroupNorm
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
Downsample2D,
ResnetBlock2D,
Transformer2DModel,
UpBlock2D,
Upsample2D,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetXSOutput(BaseOutput):
"""
The output of [`ControlNetXSModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model
output, but is already the final output.
"""
sample: torch.FloatTensor = None
# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
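The stride-2 convolutions above bring the image-space hint down to the latent resolution the quoted paper text describes: after `conv_in`, each pair of convolutions halves the spatial size once. A minimal standalone sketch of that arithmetic (hypothetical helper, not part of the library):

```python
def embedding_output_size(input_size: int, block_out_channels=(16, 32, 96, 256)) -> int:
    # conv_in and the first conv of each pair are 3x3 with padding 1 (resolution kept);
    # the second conv of each pair has stride 2, so each pair halves the resolution.
    # With the default block_out_channels there are len(block_out_channels) - 1 = 3 halvings.
    num_halvings = len(block_out_channels) - 1
    return input_size // (2 ** num_halvings)

print(embedding_output_size(512))  # 64: a 512x512 hint matches SD's 64x64 latent space
```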
class ControlNetXSModel(ModelMixin, ConfigMixin):
r"""
A ControlNet-XS model
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
methods implemented for all models (such as downloading or saving).
Most of the parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation
of [`UNet2DConditionModel`] for them.
Parameters:
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
time_embedding_input_dim (`int`, defaults to 320):
Dimension of input into time embedding. Needs to be same as in the base model.
time_embedding_dim (`int`, defaults to 1280):
Dimension of output from time embedding. Needs to be same as in the base model.
learn_embedding (`bool`, defaults to `False`):
Whether to use the time embedding of the control model. If yes, the time embedding is a linear interpolation of
the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**0.3`.
time_embedding_mix (`float`, defaults to 1.0):
Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the
control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used.
base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`):
Channel sizes of each subblock of the base model. Use `_gather_subblock_sizes` on your base model to compute them.
"""
@classmethod
def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True):
"""
Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS).
Parameters:
base_model (`UNet2DConditionModel`):
Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL.
is_sdxl (`bool`, defaults to `True`):
Whether passed `base_model` is a StableDiffusion-XL model.
"""
def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int):
"""
Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why).
The original ControlNet-XS model, however, defines the number of attention heads.
That's why we compute the head dimensions needed to get the desired number of attention heads.
"""
block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels]
dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels]
return dim_attn_heads
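As a standalone check of the helper above (the channel sizes below are illustrative stand-ins for the base model's `block_out_channels`, not values read from a real checkpoint):

```python
import math

def get_dim_attn_heads(block_out_channels, size_ratio, num_attn_heads):
    # Shrink the base channels by size_ratio, then pick the per-head dimension
    # that yields the desired *number* of heads, since diffusers configures
    # the head dimension rather than the head count.
    scaled = [int(size_ratio * c) for c in block_out_channels]
    return [math.ceil(c / num_attn_heads) for c in scaled]

# SDXL-like channels with the paper's settings (ratio 0.1, 64 heads):
print(get_dim_attn_heads((320, 640, 1280), 0.1, 64))          # [1, 1, 2]
# SD-like channels with ratio 0.0125 and 8 heads:
print(get_dim_attn_heads((320, 640, 1280, 1280), 0.0125, 8))  # [1, 1, 2, 2]
```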
if is_sdxl:
return ControlNetXSModel.from_unet(
base_model,
time_embedding_mix=0.95,
learn_embedding=True,
size_ratio=0.1,
conditioning_embedding_out_channels=(16, 32, 96, 256),
num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64),
)
else:
return ControlNetXSModel.from_unet(
base_model,
time_embedding_mix=1.0,
learn_embedding=True,
size_ratio=0.0125,
conditioning_embedding_out_channels=(16, 32, 96, 256),
num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8),
)
@classmethod
def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str):
"""To create correctly sized connections between base and control model, we need to know
the input and output channels of each subblock.
Parameters:
unet (`UNet2DConditionModel`):
UNet whose subblock channel sizes are to be gathered.
base_or_control (`str`):
Needs to be either "base" or "control". If "base", decoder is also considered.
"""
if base_or_control not in ["base", "control"]:
raise ValueError("`base_or_control` needs to be either `base` or `control`")
channel_sizes = {"down": [], "mid": [], "up": []}
# input convolution
channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels))
# encoder blocks
for module in unet.down_blocks:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
for r in module.resnets:
channel_sizes["down"].append((r.in_channels, r.out_channels))
if module.downsamplers:
channel_sizes["down"].append(
(module.downsamplers[0].channels, module.downsamplers[0].out_channels)
)
else:
raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.")
# middle block
channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels))
# decoder blocks
if base_or_control == "base":
for module in unet.up_blocks:
if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
for r in module.resnets:
channel_sizes["up"].append((r.in_channels, r.out_channels))
else:
raise ValueError(
f"Encountered unknown module of type {type(module)} while creating ControlNet-XS."
)
return channel_sizes
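To illustrate the encoder bookkeeping above without real diffusers modules, here is a toy trace with dict stand-ins for blocks (all channel sizes hypothetical):

```python
def gather_down_sizes(conv_in, down_blocks):
    # Mirrors the encoder walk above: conv_in first, then each resnet's
    # (in, out) channels, then the optional downsampler of each block.
    sizes = [conv_in]
    for block in down_blocks:
        sizes.extend(block["resnets"])
        if block.get("downsampler"):
            sizes.append(block["downsampler"])
    return sizes

blocks = [
    {"resnets": [(320, 320), (320, 320)], "downsampler": (320, 320)},
    {"resnets": [(320, 640), (640, 640)], "downsampler": None},
]
print(gather_down_sizes((4, 320), blocks))
# [(4, 320), (320, 320), (320, 320), (320, 320), (320, 640), (640, 640)]
```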
@register_to_config
def __init__(
self,
conditioning_channels: int = 3,
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
controlnet_conditioning_channel_order: str = "rgb",
time_embedding_input_dim: int = 320,
time_embedding_dim: int = 1280,
time_embedding_mix: float = 1.0,
learn_embedding: bool = False,
base_model_channel_sizes: Dict[str, List[Tuple[int]]] = {
"down": [
(4, 320),
(320, 320),
(320, 320),
(320, 320),
(320, 640),
(640, 640),
(640, 640),
(640, 1280),
(1280, 1280),
],
"mid": [(1280, 1280)],
"up": [
(2560, 1280),
(2560, 1280),
(1920, 1280),
(1920, 640),
(1280, 640),
(960, 640),
(960, 320),
(640, 320),
(640, 320),
],
},
sample_size: Optional[int] = None,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
upcast_attention: bool = False,
):
super().__init__()
# 1 - Create control unet
self.control_model = UNet2DConditionModel(
sample_size=sample_size,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
attention_head_dim=num_attention_heads,
use_linear_projection=True,
upcast_attention=upcast_attention,
time_embedding_dim=time_embedding_dim,
)
# 2 - Do model surgery on control model
# 2.1 - Allow to use the same time information as the base model
adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)
# 2.2 - Allow for information infusion from base model
# We concat the output of each base encoder subblocks to the input of the next control encoder subblock
# (We ignore the 1st element, as it represents the `conv_in`.)
extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]]
it_extra_input_channels = iter(extra_input_channels)
for b, block in enumerate(self.control_model.down_blocks):
for r in range(len(block.resnets)):
increase_block_input_in_encoder_resnet(
self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels)
)
if block.downsamplers:
increase_block_input_in_encoder_downsampler(
self.control_model, block_no=b, by=next(it_extra_input_channels)
)
increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1])
# 2.3 - Make group norms work with modified channel sizes
adjust_group_norms(self.control_model)
# 3 - Gather Channel Sizes
self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control")
self.ch_inout_base = base_model_channel_sizes
# 4 - Build connections between base and control model
self.down_zero_convs_out = nn.ModuleList([])
self.down_zero_convs_in = nn.ModuleList([])
self.middle_block_out = nn.ModuleList([])
self.middle_block_in = nn.ModuleList([])
self.up_zero_convs_out = nn.ModuleList([])
self.up_zero_convs_in = nn.ModuleList([])
for ch_io_base in self.ch_inout_base["down"]:
self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1]))
for i in range(len(self.ch_inout_ctrl["down"])):
self.down_zero_convs_out.append(
self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1])
)
self.middle_block_out = self._make_zero_conv(
self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1]
)
self.up_zero_convs_out.append(
self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1])
)
for i in range(1, len(self.ch_inout_ctrl["down"])):
self.up_zero_convs_out.append(
self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1])
)
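The connection counts built above can be sketched standalone, assuming (as in the default config) that base and control model have the same number of encoder subblocks: one base-to-control zero conv per base encoder subblock, one control-to-base zero conv per control encoder subblock, one for the mid block, and one per control encoder subblock feeding the base decoder.

```python
def count_zero_convs(num_down_subblocks: int):
    # down_zero_convs_in: base -> control, one per base encoder subblock
    down_in = num_down_subblocks
    # down_zero_convs_out: control -> base on the encoder side
    down_out = num_down_subblocks
    # middle_block_out: control mid -> base mid
    mid_out = 1
    # up_zero_convs_out: last encoder output plus one per remaining encoder subblock
    up_out = 1 + (num_down_subblocks - 1)
    return down_in, down_out, mid_out, up_out

# With the default base_model_channel_sizes["down"] (9 entries):
down_in, down_out, mid_out, up_out = count_zero_convs(9)
print(down_out + mid_out + up_out)  # 19: the n_connections scaled in forward()
```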
# 5 - Create conditioning hint embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
# In the minimal implementation setting, we only need the control model up to the mid block
del self.control_model.up_blocks
del self.control_model.conv_norm_out
del self.control_model.conv_out
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
conditioning_channels: int = 3,
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
controlnet_conditioning_channel_order: str = "rgb",
learn_embedding: bool = False,
time_embedding_mix: float = 1.0,
block_out_channels: Optional[Tuple[int]] = None,
size_ratio: Optional[float] = None,
num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
norm_num_groups: Optional[int] = None,
):
r"""
Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it.
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
learn_embedding (`bool`, defaults to `False`):
Whether to use the time embedding of the control model. If yes, the time embedding is a linear interpolation
of the time embeddings of the control and base model with interpolation parameter
`time_embedding_mix**0.3`.
time_embedding_mix (`float`, defaults to 1.0):
Linear interpolation parameter used if `learn_embedding` is `True`.
block_out_channels (`Tuple[int]`, *optional*):
Down blocks output channels in control model. Either this or `size_ratio` must be given.
size_ratio (float, *optional*):
When given, block_out_channels is set to a relative fraction of the base model's block_out_channels.
Either this or `block_out_channels` must be given.
num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
The dimension of the attention heads (despite the name; see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why).
norm_num_groups (int, *optional*, defaults to `None`):
The number of groups to use for the normalization of the control unet. If `None`,
`int(unet.config.norm_num_groups * size_ratio)` is taken.
"""
# Check input
fixed_size = block_out_channels is not None
relative_size = size_ratio is not None
if not (fixed_size ^ relative_size):
raise ValueError(
"Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
)
# Create model
if block_out_channels is None:
block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels]
# Check that attention heads and group norms match channel sizes
# - attention heads
def attn_heads_match_channel_sizes(attn_heads, channel_sizes):
if isinstance(attn_heads, (tuple, list)):
return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes))
else:
return all(c % attn_heads == 0 for c in channel_sizes)
num_attention_heads = num_attention_heads or unet.config.attention_head_dim
if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels):
raise ValueError(
f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually."
)
# - group norms
def group_norms_match_channel_sizes(num_groups, channel_sizes):
return all(c % num_groups == 0 for c in channel_sizes)
if norm_num_groups is None:
if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels):
norm_num_groups = unet.config.norm_num_groups
else:
norm_num_groups = min(block_out_channels)
if group_norms_match_channel_sizes(norm_num_groups, block_out_channels):
print(
f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all `block_out_channels` ({block_out_channels}). Set it explicitly to remove this message."
)
else:
raise ValueError(
f"`block_out_channels` ({block_out_channels}) don't match the base model's `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all `block_out_channels`."
)
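The group-norm fallback logic above can be sketched as a standalone helper (hypothetical, mirroring but not part of the library): reuse the base model's group count when it divides every channel size, otherwise try `min(block_out_channels)`, otherwise require an explicit value.

```python
def pick_norm_num_groups(base_num_groups, block_out_channels):
    def divides_all(num_groups):
        return all(c % num_groups == 0 for c in block_out_channels)

    # Prefer the base model's group count if it divides every channel size.
    if divides_all(base_num_groups):
        return base_num_groups
    # Fallback: the smallest channel count divides itself, and may divide the rest.
    fallback = min(block_out_channels)
    if divides_all(fallback):
        return fallback
    raise ValueError("Pass `norm_num_groups` explicitly.")

print(pick_norm_num_groups(32, [320, 640, 1280]))  # 32 divides all -> keep 32
print(pick_norm_num_groups(32, [4, 8, 16]))        # 32 divides none -> min = 4
```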
def get_time_emb_input_dim(unet: UNet2DConditionModel):
return unet.time_embedding.linear_1.in_features
def get_time_emb_dim(unet: UNet2DConditionModel):
return unet.time_embedding.linear_2.out_features
# Clone params from base unet if
# (i) it's required to build SD or SDXL, and
# (ii) it's not used for the time embedding (as time embedding of control model is never used), and
# (iii) it's not set further below anyway
to_keep = [
"cross_attention_dim",
"down_block_types",
"sample_size",
"transformer_layers_per_block",
"up_block_types",
"upcast_attention",
]
kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep}
kwargs.update(block_out_channels=block_out_channels)
kwargs.update(num_attention_heads=num_attention_heads)
kwargs.update(norm_num_groups=norm_num_groups)
# Add controlnetxs-specific params
kwargs.update(
conditioning_channels=conditioning_channels,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
time_embedding_input_dim=get_time_emb_input_dim(unet),
time_embedding_dim=get_time_emb_dim(unet),
time_embedding_mix=time_embedding_mix,
learn_embedding=learn_embedding,
base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"),
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
return cls(**kwargs)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
return self.control_model.attn_processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
self.control_model.set_attn_processor(processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.control_model.set_default_attn_processor()
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
self.control_model.set_attention_slice(slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (UNet2DConditionModel)):
if value:
module.enable_gradient_checkpointing()
else:
module.disable_gradient_checkpointing()
def forward(
self,
base_model: UNet2DConditionModel,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[ControlNetXSOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
Args:
base_model (`UNet2DConditionModel`):
The base unet model we want to control.
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor, e.g. a conditioning image of shape `(batch_size, conditioning_channels, height, width)`.
conditioning_scale (`float`, defaults to `1.0`):
How much the control model affects the base model outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# scale control strength
n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out)
scale_list = torch.full((n_connections,), conditioning_scale)
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = base_model.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
if self.config.learn_embedding:
ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond)
base_temb = base_model.time_embedding(t_emb, timestep_cond)
interpolation_param = self.config.time_embedding_mix**0.3
temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
else:
temb = base_model.time_embedding(t_emb, timestep_cond)
# added time & text embeddings
aug_emb = None
if base_model.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if base_model.config.class_embed_type == "timestep":
class_labels = base_model.time_proj(class_labels)
class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
temb = temb + class_emb
if base_model.config.addition_embed_type is not None:
if base_model.config.addition_embed_type == "text":
aug_emb = base_model.add_embedding(encoder_hidden_states)
elif base_model.config.addition_embed_type == "text_image":
raise NotImplementedError()
elif base_model.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = base_model.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(temb.dtype)
aug_emb = base_model.add_embedding(add_embeds)
elif base_model.config.addition_embed_type == "image":
raise NotImplementedError()
elif base_model.config.addition_embed_type == "image_hint":
raise NotImplementedError()
temb = temb + aug_emb if aug_emb is not None else temb
# text embeddings
cemb = encoder_hidden_states
# Preparation
guided_hint = self.controlnet_cond_embedding(controlnet_cond)
h_ctrl = h_base = sample
hs_base, hs_ctrl = [], []
it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map(
iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out)
)
scales = iter(scale_list)
base_down_subblocks = to_sub_blocks(base_model.down_blocks)
ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks)
base_mid_subblocks = to_sub_blocks([base_model.mid_block])
ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block])
base_up_subblocks = to_sub_blocks(base_model.up_blocks)
# Cross Control
# 0 - conv in
h_base = base_model.conv_in(h_base)
h_ctrl = self.control_model.conv_in(h_ctrl)
if guided_hint is not None:
h_ctrl += guided_hint
h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
hs_base.append(h_base)
hs_ctrl.append(h_ctrl)
# 1 - down
for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks):
h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
hs_base.append(h_base)
hs_ctrl.append(h_ctrl)
# 2 - mid
h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks):
h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base
# 3 - up
for i, m_base in enumerate(base_up_subblocks):
h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder
h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder
h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs)
h_base = base_model.conv_norm_out(h_base)
h_base = base_model.conv_act(h_base)
h_base = base_model.conv_out(h_base)
if not return_dict:
return (h_base,)
return ControlNetXSOutput(sample=h_base)
def _make_zero_conv(self, in_channels, out_channels=None):
# keep a running record of channel sizes
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
return zero_module(nn.Conv2d(self.in_channels, self.out_channels, 1, padding=0))
@torch.no_grad()
def _check_if_vae_compatible(self, vae: AutoencoderKL):
condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1)
vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
compatible = condition_downscale_factor == vae_downscale_factor
return compatible, condition_downscale_factor, vae_downscale_factor
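The compatibility check above boils down to comparing two powers of two: the conditioning embedding and the VAE must downscale the input image by the same total factor. A minimal stand-alone sketch (the channel tuples below are illustrative defaults, not read from real configs):

```python
# Each extra channel stage halves the spatial resolution once, so the total
# downscale factor is 2 ** (number of stages - 1).
def downscale_factor(out_channels):
    return 2 ** (len(out_channels) - 1)

conditioning_embedding_out_channels = (16, 32, 96, 256)  # illustrative default
vae_block_out_channels = (128, 256, 512, 512)            # illustrative default

cond_factor = downscale_factor(conditioning_embedding_out_channels)  # 8
vae_factor = downscale_factor(vae_block_out_channels)                # 8
compatible = cond_factor == vae_factor
```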
class SubBlock(nn.ModuleList):
"""A SubBlock is the largest piece of either the base or the control model that is executed independently of the other model.
Before each subblock, information is concatenated from base to control. After each subblock, information is added from control to base.
"""
def __init__(self, ms, *args, **kwargs):
if not is_iterable(ms):
ms = [ms]
super().__init__(ms, *args, **kwargs)
def forward(
self,
x: torch.Tensor,
temb: torch.Tensor,
cemb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
"""Iterate through children and pass correct information to each."""
for m in self:
if isinstance(m, ResnetBlock2D):
x = m(x, temb)
elif isinstance(m, Transformer2DModel):
x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample
elif isinstance(m, Downsample2D):
x = m(x)
elif isinstance(m, Upsample2D):
x = m(x)
else:
raise ValueError(
f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`"
)
return x
def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int):
unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim)
def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
r = unet.down_blocks[block_no].resnets[resnet_idx]
old_norm1, old_conv1 = r.norm1, r.conv1
# norm
norm_args = "num_groups num_channels eps affine".split(" ")
for a in norm_args:
assert hasattr(old_norm1, a)
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
# conv1
conv1_args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
conv1_args.append("lora_layer")
for a in conv1_args:
assert hasattr(old_conv1, a)
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
# conv_shortcut
# as we changed the input size of the block, the input and output sizes are likely different,
# therefore we need a conv_shortcut (simply adding won't work)
conv_shortcut_args_kwargs = {
"in_channels": conv1_kwargs["in_channels"],
"out_channels": conv1_kwargs["out_channels"],
# default arguments from resnet.__init__
"kernel_size": 1,
"stride": 1,
"padding": 0,
"bias": True,
}
# swap old with new modules
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
)
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
)
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
old_down = unet.down_blocks[block_no].downsamplers[0].conv
args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
args.append("lora_layer")
for a in args:
assert hasattr(old_down, a)
kwargs = {a: getattr(old_down, a) for a in args}
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
kwargs["in_channels"] += by # surgery done here
# swap old with new modules
unet.down_blocks[block_no].downsamplers[0].conv = (
nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
)
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
"""Increase channels sizes to allow for additional concatted information from base model"""
m = unet.mid_block.resnets[0]
old_norm1, old_conv1 = m.norm1, m.conv1
# norm
norm_args = "num_groups num_channels eps affine".split(" ")
for a in norm_args:
assert hasattr(old_norm1, a)
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
norm_kwargs["num_channels"] += by # surgery done here
conv1_args = [
"in_channels",
"out_channels",
"kernel_size",
"stride",
"padding",
"dilation",
"groups",
"bias",
"padding_mode",
]
if not USE_PEFT_BACKEND:
conv1_args.append("lora_layer")
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
conv1_kwargs["in_channels"] += by # surgery done here
# conv_shortcut
# as we changed the input size of the block, the input and output sizes are likely different,
# therefore we need a conv_shortcut (simply adding won't work)
conv_shortcut_args_kwargs = {
"in_channels": conv1_kwargs["in_channels"],
"out_channels": conv1_kwargs["out_channels"],
# default arguments from resnet.__init__
"kernel_size": 1,
"stride": 1,
"padding": 0,
"bias": True,
}
# swap old with new modules
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
unet.mid_block.resnets[0].conv1 = (
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
)
unet.mid_block.resnets[0].conv_shortcut = (
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
)
unet.mid_block.resnets[0].in_channels += by # surgery done here
def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32):
def find_denominator(number, start):
if start >= number:
return number
while start != 0:
residual = number % start
if residual == 0:
return start
start -= 1
for block in [*unet.down_blocks, unet.mid_block]:
# resnets
for r in block.resnets:
if r.norm1.num_groups < max_num_group:
r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group)
if r.norm2.num_groups < max_num_group:
r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group)
# transformers
if hasattr(block, "attentions"):
for a in block.attentions:
if a.norm.num_groups < max_num_group:
a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group)
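`find_denominator` walks downward from `start` until it hits a divisor of `number`, so the adjusted group count is always the largest divisor of the channel count not exceeding the cap. A few worked values, using a stand-alone copy of the helper above:

```python
def find_denominator(number, start):
    # largest divisor of `number` that is <= `start`
    if start >= number:
        return number
    while start != 0:
        if number % start == 0:
            return start
        start -= 1
    return 1

# 32 divides 64, so the cap itself is returned
assert find_denominator(64, start=32) == 32
# 60 is not divisible by 32 or 31; 30 is its largest divisor <= 32
assert find_denominator(60, start=32) == 30
# a channel count below the cap becomes its own group count
assert find_denominator(20, start=32) == 20
```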
def is_iterable(o):
if isinstance(o, str):
return False
try:
iter(o)
return True
except TypeError:
return False
def to_sub_blocks(blocks):
if not is_iterable(blocks):
blocks = [blocks]
sub_blocks = []
for b in blocks:
if hasattr(b, "resnets"):
if hasattr(b, "attentions") and b.attentions is not None:
for r, a in zip(b.resnets, b.attentions):
sub_blocks.append([r, a])
num_resnets = len(b.resnets)
num_attns = len(b.attentions)
if num_resnets > num_attns:
# we can have more resnets than attentions, so add each resnet as separate subblock
for i in range(num_attns, num_resnets):
sub_blocks.append([b.resnets[i]])
else:
for r in b.resnets:
sub_blocks.append([r])
# upsamplers are part of the same subblock
if hasattr(b, "upsamplers") and b.upsamplers is not None:
for u in b.upsamplers:
sub_blocks[-1].extend([u])
# downsamplers are own subblock
if hasattr(b, "downsamplers") and b.downsamplers is not None:
for d in b.downsamplers:
sub_blocks.append([d])
return list(map(SubBlock, sub_blocks))
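To see the grouping rule of `to_sub_blocks` in isolation, here is the same splitting applied by hand to stand-in objects (plain classes standing in for the diffusers blocks; only the attribute names matter):

```python
from types import SimpleNamespace

class R: pass  # stands in for ResnetBlock2D
class A: pass  # stands in for Transformer2DModel
class D: pass  # stands in for Downsample2D

# a down block shaped like CrossAttnDownBlock2D: 2 resnets, 2 attentions, 1 downsampler
block = SimpleNamespace(resnets=[R(), R()], attentions=[A(), A()], downsamplers=[D()])

sub_blocks = []
# each resnet is paired with its attention into one subblock
for r, a in zip(block.resnets, block.attentions):
    sub_blocks.append([r, a])
# each downsampler forms its own subblock
for d in block.downsamplers:
    sub_blocks.append([d])

# result: [[resnet, attn], [resnet, attn], [downsampler]]
```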
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
# !pip install opencv-python transformers accelerate
import argparse
import cv2
import numpy as np
import torch
from controlnetxs import ControlNetXSModel
from PIL import Image
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
from diffusers.utils import load_image
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
)
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
parser.add_argument(
"--image_path",
type=str,
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
)
parser.add_argument("--num_inference_steps", type=int, default=50)
args = parser.parse_args()
prompt = args.prompt
negative_prompt = args.negative_prompt
# download an image
image = load_image(args.image_path)
# initialize the models and pipeline
controlnet_conditioning_scale = args.controlnet_conditioning_scale
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# get canny image
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
num_inference_steps = args.num_inference_steps
# generate image
image = pipe(
prompt,
controlnet_conditioning_scale=controlnet_conditioning_scale,
image=canny_image,
num_inference_steps=num_inference_steps,
).images[0]
image.save("cnxs_sd.canny.png")
# !pip install opencv-python transformers accelerate
import argparse
import cv2
import numpy as np
import torch
from controlnetxs import ControlNetXSModel
from PIL import Image
from pipeline_controlnet_xs import StableDiffusionXLControlNetXSPipeline
from diffusers.utils import load_image
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
)
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
parser.add_argument(
"--image_path",
type=str,
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
)
parser.add_argument("--num_inference_steps", type=int, default=50)
args = parser.parse_args()
prompt = args.prompt
negative_prompt = args.negative_prompt
# download an image
image = load_image(args.image_path)
# initialize the models and pipeline
controlnet_conditioning_scale = args.controlnet_conditioning_scale
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# get canny image
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
num_inference_steps = args.num_inference_steps
# generate image
image = pipe(
prompt,
controlnet_conditioning_scale=controlnet_conditioning_scale,
image=canny_image,
num_inference_steps=num_inference_steps,
).images[0]
image.save("cnxs_sdxl.canny.png")
@@ -80,6 +80,7 @@ else:
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
@@ -94,6 +95,7 @@ else:
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
"UNetControlNetXSModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
@@ -270,6 +272,7 @@ else:
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetXSPipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
@@ -293,6 +296,7 @@ else:
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
"StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
@@ -474,6 +478,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
@@ -487,6 +492,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetControlNetXSModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
@@ -642,6 +648,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetXSPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
@@ -665,6 +672,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -68,6 +69,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import gcd
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import FloatTensor, nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, is_torch_version, logging
from ..utils.torch_utils import apply_freeu
from .attention_processor import Attention, AttentionProcessor
from .controlnet import ControlNetConditioningEmbedding
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
Downsample2D,
ResnetBlock2D,
Transformer2DModel,
UNetMidBlock2DCrossAttn,
Upsample2D,
)
from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetXSOutput(BaseOutput):
"""
The output of [`UNetControlNetXSModel`].
Args:
sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
model output, but is already the final output.
"""
sample: FloatTensor = None
class DownBlockControlNetXSAdapter(nn.Module):
"""Components that together with corresponding components from the base model will form a
`ControlNetXSCrossAttnDownBlock2D`"""
def __init__(
self,
resnets: nn.ModuleList,
base_to_ctrl: nn.ModuleList,
ctrl_to_base: nn.ModuleList,
attentions: Optional[nn.ModuleList] = None,
downsampler: Optional[nn.Conv2d] = None,
):
super().__init__()
self.resnets = resnets
self.base_to_ctrl = base_to_ctrl
self.ctrl_to_base = ctrl_to_base
self.attentions = attentions
self.downsamplers = downsampler
class MidBlockControlNetXSAdapter(nn.Module):
"""Components that together with corresponding components from the base model will form a
`ControlNetXSCrossAttnMidBlock2D`"""
def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
super().__init__()
self.midblock = midblock
self.base_to_ctrl = base_to_ctrl
self.ctrl_to_base = ctrl_to_base
class UpBlockControlNetXSAdapter(nn.Module):
"""Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`"""
def __init__(self, ctrl_to_base: nn.ModuleList):
super().__init__()
self.ctrl_to_base = ctrl_to_base
def get_down_block_adapter(
base_in_channels: int,
base_out_channels: int,
ctrl_in_channels: int,
ctrl_out_channels: int,
temb_channels: int,
max_norm_num_groups: Optional[int] = 32,
has_crossattn=True,
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024,
add_downsample: bool = True,
upcast_attention: Optional[bool] = False,
):
num_layers = 2 # only support sd + sdxl
resnets = []
attentions = []
ctrl_to_base = []
base_to_ctrl = []
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
base_in_channels = base_in_channels if i == 0 else base_out_channels
ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
# Before the resnet/attention application, information is concatted from base to control.
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
resnets.append(
ResnetBlock2D(
in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
out_channels=ctrl_out_channels,
temb_channels=temb_channels,
groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups),
groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
eps=1e-5,
)
)
if has_crossattn:
attentions.append(
Transformer2DModel(
num_attention_heads,
ctrl_out_channels // num_attention_heads,
in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
)
)
# After the resnet/attention application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
if add_downsample:
# Before the downsampler application, information is concatted from base to control
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
downsamplers = Downsample2D(
ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
)
# After the downsampler application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
else:
downsamplers = None
down_block_components = DownBlockControlNetXSAdapter(
resnets=nn.ModuleList(resnets),
base_to_ctrl=nn.ModuleList(base_to_ctrl),
ctrl_to_base=nn.ModuleList(ctrl_to_base),
)
if has_crossattn:
down_block_components.attentions = nn.ModuleList(attentions)
if downsamplers is not None:
down_block_components.downsamplers = downsamplers
return down_block_components
def get_mid_block_adapter(
base_channels: int,
ctrl_channels: int,
temb_channels: Optional[int] = None,
max_norm_num_groups: Optional[int] = 32,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024,
upcast_attention: bool = False,
):
# Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels
base_to_ctrl = make_zero_conv(base_channels, base_channels)
midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=ctrl_channels + base_channels,
out_channels=ctrl_channels,
temb_channels=temb_channels,
# number of norm groups must divide both in_channels and out_channels
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=True,
upcast_attention=upcast_attention,
)
# After the midblock application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base)
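The `resnet_groups` computation above is worth unpacking: the mid block's first norm sees `ctrl_channels + base_channels` input channels while later norms see `ctrl_channels`, so the group count must divide both. A sketch of the arithmetic (`find_largest_factor` is defined elsewhere in this file; the copy below is a stand-in consistent with how it is called, and the channel numbers are illustrative):

```python
from math import gcd

def find_largest_factor(number, max_factor):
    # largest divisor of `number` that is <= `max_factor`
    factor = max_factor
    if factor >= number:
        return number
    while factor != 0:
        if number % factor == 0:
            return factor
        factor -= 1
    return 1

# e.g. a small control model (16 channels) paired with an SD base mid block (1280 channels)
ctrl_channels, base_channels = 16, 1280
groups = find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_factor=32)
# gcd(16, 1296) = 16 and 16 <= 32, so 16 groups divide both channel counts
```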
def get_up_block_adapter(
out_channels: int,
prev_output_channel: int,
ctrl_skip_channels: List[int],
):
ctrl_to_base = []
num_layers = 3 # only support sd + sdxl
for i in range(num_layers):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base))
class ControlNetXSAdapter(ModelMixin, ConfigMixin):
r"""
A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
`UNet2DConditionModel` base model).
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
methods implemented for all models (such as downloading or saving).
Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. Its
default parameters are compatible with StableDiffusion.
Parameters:
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of the conditioning image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
time_embedding_mix (`float`, defaults to 1.0):
If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
embedding is used. Otherwise, both are combined.
learn_time_embedding (`bool`, defaults to `False`):
Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time
embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base
model's time embedding.
num_attention_heads (`Union[int, Tuple[int]]`, defaults to 4):
The number of attention heads.
block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`):
The tuple of output channels for each block.
base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`):
The tuple of output channels for each block in the base unet.
cross_attention_dim (`int`, defaults to 1024):
The dimension of the cross attention features.
down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`):
The tuple of downsample blocks to use.
sample_size (`int`, defaults to 96):
Height and width of input/output sample.
transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
upcast_attention (`bool`, defaults to `True`):
Whether the attention computation should always be upcasted.
max_norm_num_groups (`int`, defaults to 32):
Maximum number of groups in group normalization. The actual number will be the largest divisor of the
respective channels that is <= max_norm_num_groups.
"""
@register_to_config
def __init__(
self,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
time_embedding_mix: float = 1.0,
learn_time_embedding: bool = False,
num_attention_heads: Union[int, Tuple[int]] = 4,
block_out_channels: Tuple[int] = (4, 8, 16, 16),
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
cross_attention_dim: int = 1024,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
sample_size: Optional[int] = 96,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
upcast_attention: bool = True,
max_norm_num_groups: int = 32,
):
super().__init__()
time_embedding_input_dim = base_block_out_channels[0]
time_embedding_dim = base_block_out_channels[0] * 4
# Check inputs
if conditioning_channel_order not in ["rgb", "bgr"]:
raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(transformer_layers_per_block, (list, tuple)):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if not isinstance(cross_attention_dim, (list, tuple)):
cross_attention_dim = [cross_attention_dim] * len(down_block_types)
# see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
if not isinstance(num_attention_heads, (list, tuple)):
num_attention_heads = [num_attention_heads] * len(down_block_types)
if len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# Create conditioning hint embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
# time
if learn_time_embedding:
self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim)
else:
self.time_embedding = None
self.down_blocks = nn.ModuleList([])
self.up_connections = nn.ModuleList([])
# input
self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0])
# down
base_out_channels = base_block_out_channels[0]
ctrl_out_channels = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
base_in_channels = base_out_channels
base_out_channels = base_block_out_channels[i]
ctrl_in_channels = ctrl_out_channels
ctrl_out_channels = block_out_channels[i]
has_crossattn = "CrossAttn" in down_block_type
is_final_block = i == len(down_block_types) - 1
self.down_blocks.append(
get_down_block_adapter(
base_in_channels=base_in_channels,
base_out_channels=base_out_channels,
ctrl_in_channels=ctrl_in_channels,
ctrl_out_channels=ctrl_out_channels,
temb_channels=time_embedding_dim,
max_norm_num_groups=max_norm_num_groups,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block[i],
num_attention_heads=num_attention_heads[i],
cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block,
upcast_attention=upcast_attention,
)
)
# mid
self.mid_block = get_mid_block_adapter(
base_channels=base_block_out_channels[-1],
ctrl_channels=block_out_channels[-1],
temb_channels=time_embedding_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
num_attention_heads=num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention,
)
# up
# The skip connection channels are the output of the conv_in and of all the down subblocks
ctrl_skip_channels = [block_out_channels[0]]
for i, out_channels in enumerate(block_out_channels):
number_of_subblocks = (
3 if i < len(block_out_channels) - 1 else 2
) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
reversed_base_block_out_channels = list(reversed(base_block_out_channels))
base_out_channels = reversed_base_block_out_channels[0]
for i in range(len(down_block_types)):
prev_base_output_channel = base_out_channels
base_out_channels = reversed_base_block_out_channels[i]
ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
self.up_connections.append(
get_up_block_adapter(
out_channels=base_out_channels,
prev_output_channel=prev_base_output_channel,
ctrl_skip_channels=ctrl_skip_channels_,
)
)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
size_ratio: Optional[float] = None,
block_out_channels: Optional[List[int]] = None,
num_attention_heads: Optional[List[int]] = None,
learn_time_embedding: bool = False,
time_embedding_mix: float = 1.0,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
):
r"""
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it.
size_ratio (float, *optional*, defaults to `None`):
When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
or `block_out_channels` must be given.
block_out_channels (`List[int]`, *optional*, defaults to `None`):
Down blocks output channels in control model. Either this or `size_ratio` must be given.
num_attention_heads (`List[int]`, *optional*, defaults to `None`):
The dimension of the attention heads. The naming seems a bit confusing and it is, see
https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
learn_time_embedding (`bool`, defaults to `False`):
Whether the `ControlNetXSAdapter` should learn a time embedding.
time_embedding_mix (`float`, defaults to 1.0):
If 1, then only the control adapter's time embedding is used. If 0, then only the base unet's time
embedding is used. Otherwise, both are combined.
conditioning_channels (`int`, defaults to 3):
Number of channels of conditioning input (e.g. an image)
conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
"""
# Check input
fixed_size = block_out_channels is not None
relative_size = size_ratio is not None
if not (fixed_size ^ relative_size):
raise ValueError(
"Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
)
# Create model
block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels]
if num_attention_heads is None:
# The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
num_attention_heads = unet.config.attention_head_dim
model = cls(
conditioning_channels=conditioning_channels,
conditioning_channel_order=conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
time_embedding_mix=time_embedding_mix,
learn_time_embedding=learn_time_embedding,
num_attention_heads=num_attention_heads,
block_out_channels=block_out_channels,
base_block_out_channels=unet.config.block_out_channels,
cross_attention_dim=unet.config.cross_attention_dim,
down_block_types=unet.config.down_block_types,
sample_size=unet.config.sample_size,
transformer_layers_per_block=unet.config.transformer_layers_per_block,
upcast_attention=unet.config.upcast_attention,
max_norm_num_groups=unet.config.norm_num_groups,
)
# ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
model.to(unet.dtype)
return model
def forward(self, *args, **kwargs):
raise ValueError(
"A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
)
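As a sketch of how `size_ratio` in `from_unet` maps the base UNet's channels to the adapter's channels (0.0125 is the ratio that reproduces the tiny default adapter; the base channels below are the SD 2.1 UNet defaults):

```python
# Derive the adapter's down-block channels as a fraction of the base UNet's.
base_block_out_channels = [320, 640, 1280, 1280]  # SD 2.1 UNet defaults
size_ratio = 0.0125
block_out_channels = [int(c * size_ratio) for c in base_block_out_channels]
# matches the adapter defaults (4, 8, 16, 16)
```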
class UNetControlNetXSModel(ModelMixin, ConfigMixin):
r"""
A UNet fused with a ControlNet-XS adapter model.
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
methods implemented for all models (such as downloading or saving).
`UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. Its default parameters are
compatible with StableDiffusion.
Its parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in
`ControlNetXSAdapter`. See their documentation for details.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
# unet configs
sample_size: Optional[int] = 96,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
num_attention_heads: Union[int, Tuple[int]] = 8,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
upcast_attention: bool = True,
time_cond_proj_dim: Optional[int] = None,
projection_class_embeddings_input_dim: Optional[int] = None,
# additional controlnet configs
time_embedding_mix: float = 1.0,
ctrl_conditioning_channels: int = 3,
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
ctrl_conditioning_channel_order: str = "rgb",
ctrl_learn_time_embedding: bool = False,
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
ctrl_max_norm_num_groups: int = 32,
):
super().__init__()
if time_embedding_mix < 0 or time_embedding_mix > 1:
raise ValueError("`time_embedding_mix` needs to be between 0 and 1.")
if time_embedding_mix < 1 and not ctrl_learn_time_embedding:
raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`")
if addition_embed_type is not None and addition_embed_type != "text_time":
raise ValueError(
"As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`."
)
if not isinstance(transformer_layers_per_block, (list, tuple)):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if not isinstance(cross_attention_dim, (list, tuple)):
cross_attention_dim = [cross_attention_dim] * len(down_block_types)
if not isinstance(num_attention_heads, (list, tuple)):
num_attention_heads = [num_attention_heads] * len(down_block_types)
if not isinstance(ctrl_num_attention_heads, (list, tuple)):
ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types)
base_num_attention_heads = num_attention_heads
self.in_channels = 4
# # Input
self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=ctrl_block_out_channels[0],
block_out_channels=ctrl_conditioning_embedding_out_channels,
conditioning_channels=ctrl_conditioning_channels,
)
self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1)
self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0])
# # Time
time_embed_input_dim = block_out_channels[0]
time_embed_dim = block_out_channels[0] * 4
self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
self.base_time_embedding = TimestepEmbedding(
time_embed_input_dim,
time_embed_dim,
cond_proj_dim=time_cond_proj_dim,
)
self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
if addition_embed_type is None:
self.base_add_time_proj = None
self.base_add_embedding = None
else:
self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# # Create down blocks
down_blocks = []
base_out_channels = block_out_channels[0]
ctrl_out_channels = ctrl_block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
base_in_channels = base_out_channels
base_out_channels = block_out_channels[i]
ctrl_in_channels = ctrl_out_channels
ctrl_out_channels = ctrl_block_out_channels[i]
has_crossattn = "CrossAttn" in down_block_type
is_final_block = i == len(down_block_types) - 1
down_blocks.append(
ControlNetXSCrossAttnDownBlock2D(
base_in_channels=base_in_channels,
base_out_channels=base_out_channels,
ctrl_in_channels=ctrl_in_channels,
ctrl_out_channels=ctrl_out_channels,
temb_channels=time_embed_dim,
norm_num_groups=norm_num_groups,
ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block[i],
base_num_attention_heads=base_num_attention_heads[i],
ctrl_num_attention_heads=ctrl_num_attention_heads[i],
cross_attention_dim=cross_attention_dim[i],
add_downsample=not is_final_block,
upcast_attention=upcast_attention,
)
)
# # Create mid block
self.mid_block = ControlNetXSCrossAttnMidBlock2D(
base_channels=block_out_channels[-1],
ctrl_channels=ctrl_block_out_channels[-1],
temb_channels=time_embed_dim,
norm_num_groups=norm_num_groups,
ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
transformer_layers_per_block=transformer_layers_per_block[-1],
base_num_attention_heads=base_num_attention_heads[-1],
ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
cross_attention_dim=cross_attention_dim[-1],
upcast_attention=upcast_attention,
)
# # Create up blocks
up_blocks = []
rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
rev_num_attention_heads = list(reversed(base_num_attention_heads))
rev_cross_attention_dim = list(reversed(cross_attention_dim))
# The skip connection channels are the output of the conv_in and of all the down subblocks
ctrl_skip_channels = [ctrl_block_out_channels[0]]
for i, out_channels in enumerate(ctrl_block_out_channels):
number_of_subblocks = (
3 if i < len(ctrl_block_out_channels) - 1 else 2
) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
reversed_block_out_channels = list(reversed(block_out_channels))
out_channels = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = out_channels
out_channels = reversed_block_out_channels[i]
in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
has_crossattn = "CrossAttn" in up_block_type
is_final_block = i == len(block_out_channels) - 1
up_blocks.append(
ControlNetXSCrossAttnUpBlock2D(
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
ctrl_skip_channels=ctrl_skip_channels_,
temb_channels=time_embed_dim,
resolution_idx=i,
has_crossattn=has_crossattn,
transformer_layers_per_block=rev_transformer_layers_per_block[i],
num_attention_heads=rev_num_attention_heads[i],
cross_attention_dim=rev_cross_attention_dim[i],
add_upsample=not is_final_block,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
)
self.down_blocks = nn.ModuleList(down_blocks)
self.up_blocks = nn.ModuleList(up_blocks)
self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups)
self.base_conv_act = nn.SiLU()
self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1)
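The skip-connection bookkeeping in `__init__` above (one entry for the ctrl `conv_in`, then 3 subblocks per down block except 2 for the last, consumed 3 at a time in reverse by the up blocks) can be traced with plain lists; the values correspond to the default `ctrl_block_out_channels`:

```python
ctrl_block_out_channels = [4, 8, 16, 16]
ctrl_skip_channels = [ctrl_block_out_channels[0]]  # skip from the ctrl conv_in
for i, out_channels in enumerate(ctrl_block_out_channels):
    # every down block has 3 subblocks, except the last one (no downsampler): 2
    n_subblocks = 3 if i < len(ctrl_block_out_channels) - 1 else 2
    ctrl_skip_channels.extend([out_channels] * n_subblocks)

# each of the 4 up blocks consumes 3 skip channels, popped in reverse order
per_up_block = [[ctrl_skip_channels.pop() for _ in range(3)] for _ in range(4)]
```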
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet: Optional[ControlNetXSAdapter] = None,
size_ratio: Optional[float] = None,
ctrl_block_out_channels: Optional[List[float]] = None,
time_embedding_mix: Optional[float] = None,
ctrl_optional_kwargs: Optional[Dict] = None,
):
r"""
Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model we want to control.
controlnet (`ControlNetXSAdapter`):
The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
adapter will be created.
size_ratio (float, *optional*, defaults to `None`):
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
where this parameter is called `block_out_channels`.
time_embedding_mix (`float`, *optional*, defaults to `None`):
Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
Passed to the `init` of the new controlnet if no controlnet was given.
"""
if controlnet is None:
ctrl_optional_kwargs = ctrl_optional_kwargs or {}
controlnet = ControlNetXSAdapter.from_unet(
unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs
)
else:
if any(
o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
):
raise ValueError(
"When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
)
# # get params
params_for_unet = [
"sample_size",
"down_block_types",
"up_block_types",
"block_out_channels",
"norm_num_groups",
"cross_attention_dim",
"transformer_layers_per_block",
"addition_embed_type",
"addition_time_embed_dim",
"upcast_attention",
"time_cond_proj_dim",
"projection_class_embeddings_input_dim",
]
params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet}
# The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
params_for_unet["num_attention_heads"] = unet.config.attention_head_dim
params_for_controlnet = [
"conditioning_channels",
"conditioning_embedding_out_channels",
"conditioning_channel_order",
"learn_time_embedding",
"block_out_channels",
"num_attention_heads",
"max_norm_num_groups",
]
params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}
params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix
# # create model
model = cls.from_config({**params_for_unet, **params_for_controlnet})
# # load weights
# from unet
modules_from_unet = [
"time_embedding",
"conv_in",
"conv_norm_out",
"conv_out",
]
for m in modules_from_unet:
getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
optional_modules_from_unet = [
"add_time_proj",
"add_embedding",
]
for m in optional_modules_from_unet:
if hasattr(unet, m) and getattr(unet, m) is not None:
getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
# from controlnet
model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
if controlnet.time_embedding is not None:
model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())
# from both
model.down_blocks = nn.ModuleList(
ControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
for b, c in zip(unet.down_blocks, controlnet.down_blocks)
)
model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
model.up_blocks = nn.ModuleList(
ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
for b, c in zip(unet.up_blocks, controlnet.up_connections)
)
# ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
model.to(unet.dtype)
return model
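The config merging performed in `from_unet` can be illustrated with plain dictionaries (toy configs, not real model configs): the unet keys are copied as-is, while the controlnet keys get a `ctrl_` prefix so that shared names like `block_out_channels` do not clash.

```python
unet_config = {"sample_size": 96, "block_out_channels": (320, 640, 1280, 1280)}
ctrl_config = {"block_out_channels": (4, 8, 16, 16), "learn_time_embedding": False}

params_for_unet = {k: v for k, v in unet_config.items()}
params_for_controlnet = {"ctrl_" + k: v for k, v in ctrl_config.items()}
# both channel tuples survive in the merged config, under different keys
merged = {**params_for_unet, **params_for_controlnet}
```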
def freeze_unet_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze the parts belonging to the base UNet
base_parts = [
"base_time_proj",
"base_time_embedding",
"base_add_time_proj",
"base_add_embedding",
"base_conv_in",
"base_conv_norm_out",
"base_conv_act",
"base_conv_out",
]
base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
for d in self.down_blocks:
d.freeze_base_params()
self.mid_block.freeze_base_params()
for u in self.up_blocks:
u.freeze_base_params()
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: Optional[torch.Tensor] = None,
conditioning_scale: Optional[float] = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
return_dict: bool = True,
apply_control: bool = True,
) -> Union[ControlNetXSOutput, Tuple]:
"""
The [`UNetControlNetXSModel`] forward method.
Args:
sample (`FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`FloatTensor`):
The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
conditioning_scale (`float`, defaults to `1.0`):
How much the control model affects the base model outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnetxs.ControlNetXSOutput`] instead of a plain tuple.
apply_control (`bool`, defaults to `True`):
If `False`, the input is run only through the base model.
Returns:
[`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
# check channel order
if self.config.ctrl_conditioning_channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.base_time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
if self.config.ctrl_learn_time_embedding and apply_control:
ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond)
base_temb = self.base_time_embedding(t_emb, timestep_cond)
interpolation_param = self.config.time_embedding_mix**0.3
temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
else:
temb = self.base_time_embedding(t_emb)
# added time & text embeddings
aug_emb = None
if self.config.addition_embed_type is None:
pass
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.base_add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(temb.dtype)
aug_emb = self.base_add_embedding(add_embeds)
else:
raise ValueError(
f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported."
)
temb = temb + aug_emb if aug_emb is not None else temb
# text embeddings
cemb = encoder_hidden_states
# Preparation
h_ctrl = h_base = sample
hs_base, hs_ctrl = [], []
# Cross Control
guided_hint = self.controlnet_cond_embedding(controlnet_cond)
# 1 - conv in & down
h_base = self.base_conv_in(h_base)
h_ctrl = self.ctrl_conv_in(h_ctrl)
if guided_hint is not None:
h_ctrl += guided_hint
if apply_control:
h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base
hs_base.append(h_base)
hs_ctrl.append(h_ctrl)
for down in self.down_blocks:
h_base, h_ctrl, residual_hb, residual_hc = down(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrl,
temb=temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
apply_control=apply_control,
)
hs_base.extend(residual_hb)
hs_ctrl.extend(residual_hc)
# 2 - mid
h_base, h_ctrl = self.mid_block(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrl,
temb=temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
apply_control=apply_control,
)
# 3 - up
for up in self.up_blocks:
n_resnets = len(up.resnets)
skips_hb = hs_base[-n_resnets:]
skips_hc = hs_ctrl[-n_resnets:]
hs_base = hs_base[:-n_resnets]
hs_ctrl = hs_ctrl[:-n_resnets]
h_base = up(
hidden_states=h_base,
res_hidden_states_tuple_base=skips_hb,
res_hidden_states_tuple_ctrl=skips_hc,
temb=temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
apply_control=apply_control,
)
# 4 - conv out
h_base = self.base_conv_norm_out(h_base)
h_base = self.base_conv_act(h_base)
h_base = self.base_conv_out(h_base)
if not return_dict:
return (h_base,)
return ControlNetXSOutput(sample=h_base)
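The forward pass above threads the base and control streams through three stages, pushing one residual per subblock onto `hs_base`/`hs_ctrl` on the way down and popping them in LIFO groups of `len(up.resnets)` on the way up. A minimal, self-contained sketch of that stack discipline (block layouts and names here are illustrative, not taken from the library):

```python
# Sketch of the skip-connection bookkeeping in the forward pass above.
# Block layouts are illustrative; real counts come from the UNet config.

def down_pass(down_block_sizes):
    # conv_in contributes one residual, then each down block appends one
    # residual per subblock (resnets plus an optional downsampler)
    skips = ["conv_in"]
    for block_idx, n_subblocks in enumerate(down_block_sizes):
        for sub_idx in range(n_subblocks):
            skips.append(f"down{block_idx}.sub{sub_idx}")
    return skips

def up_pass(skips, up_block_resnet_counts):
    # each up block pops len(up.resnets) residuals off the end of the
    # stack -- the same slicing as `hs_base[-n_resnets:]` above
    consumed = []
    for n_resnets in up_block_resnet_counts:
        consumed.append(skips[-n_resnets:])
        skips = skips[:-n_resnets]
    return consumed, skips

# SD-style layout: 3 down blocks with downsamplers + 1 without
skips = down_pass([3, 3, 3, 2])
consumed, leftover = up_pass(skips, [3, 3, 3, 3])
```

The down pass produces exactly as many residuals as the up pass consumes, so `leftover` ends up empty.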
class ControlNetXSCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
base_in_channels: int,
base_out_channels: int,
ctrl_in_channels: int,
ctrl_out_channels: int,
temb_channels: int,
norm_num_groups: int = 32,
ctrl_max_norm_num_groups: int = 32,
has_crossattn=True,
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
base_num_attention_heads: Optional[int] = 1,
ctrl_num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024,
add_downsample: bool = True,
upcast_attention: Optional[bool] = False,
):
super().__init__()
base_resnets = []
base_attentions = []
ctrl_resnets = []
ctrl_attentions = []
ctrl_to_base = []
base_to_ctrl = []
num_layers = 2 # only support sd + sdxl
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
base_in_channels = base_in_channels if i == 0 else base_out_channels
ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
# Before the resnet/attention application, information is concatted from base to control.
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
base_resnets.append(
ResnetBlock2D(
in_channels=base_in_channels,
out_channels=base_out_channels,
temb_channels=temb_channels,
groups=norm_num_groups,
)
)
ctrl_resnets.append(
ResnetBlock2D(
in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
out_channels=ctrl_out_channels,
temb_channels=temb_channels,
groups=find_largest_factor(
ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups
),
groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
eps=1e-5,
)
)
if has_crossattn:
base_attentions.append(
Transformer2DModel(
base_num_attention_heads,
base_out_channels // base_num_attention_heads,
in_channels=base_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
)
ctrl_attentions.append(
Transformer2DModel(
ctrl_num_attention_heads,
ctrl_out_channels // ctrl_num_attention_heads,
in_channels=ctrl_out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
upcast_attention=upcast_attention,
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
)
)
# After the resnet/attention application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
if add_downsample:
# Before the downsampler application, information is concatted from base to control
# Concat doesn't require change in number of channels
base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
self.base_downsamplers = Downsample2D(
base_out_channels, use_conv=True, out_channels=base_out_channels, name="op"
)
self.ctrl_downsamplers = Downsample2D(
ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
)
# After the downsampler application, information is added from control to base
# Addition requires change in number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
else:
self.base_downsamplers = None
self.ctrl_downsamplers = None
self.base_resnets = nn.ModuleList(base_resnets)
self.ctrl_resnets = nn.ModuleList(ctrl_resnets)
self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers
self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers
self.base_to_ctrl = nn.ModuleList(base_to_ctrl)
self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
self.gradient_checkpointing = False
@classmethod
def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter):
# get params
def get_first_cross_attention(block):
return block.attentions[0].transformer_blocks[0].attn2
base_in_channels = base_downblock.resnets[0].in_channels
base_out_channels = base_downblock.resnets[0].out_channels
ctrl_in_channels = (
ctrl_downblock.resnets[0].in_channels - base_in_channels
) # base channels are concatted to ctrl channels in init
ctrl_out_channels = ctrl_downblock.resnets[0].out_channels
temb_channels = base_downblock.resnets[0].time_emb_proj.in_features
num_groups = base_downblock.resnets[0].norm1.num_groups
ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups
if hasattr(base_downblock, "attentions"):
has_crossattn = True
transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks)
base_num_attention_heads = get_first_cross_attention(base_downblock).heads
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
else:
has_crossattn = False
transformer_layers_per_block = None
base_num_attention_heads = None
ctrl_num_attention_heads = None
cross_attention_dim = None
upcast_attention = None
add_downsample = base_downblock.downsamplers is not None
# create model
model = cls(
base_in_channels=base_in_channels,
base_out_channels=base_out_channels,
ctrl_in_channels=ctrl_in_channels,
ctrl_out_channels=ctrl_out_channels,
temb_channels=temb_channels,
norm_num_groups=num_groups,
ctrl_max_norm_num_groups=ctrl_num_groups,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block,
base_num_attention_heads=base_num_attention_heads,
ctrl_num_attention_heads=ctrl_num_attention_heads,
cross_attention_dim=cross_attention_dim,
add_downsample=add_downsample,
upcast_attention=upcast_attention,
)
# load weights
model.base_resnets.load_state_dict(base_downblock.resnets.state_dict())
model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict())
if has_crossattn:
model.base_attentions.load_state_dict(base_downblock.attentions.state_dict())
model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict())
if add_downsample:
model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict())
model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict())
model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict())
model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict())
return model
def freeze_base_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze base part
base_parts = [self.base_resnets]
if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
base_parts.append(self.base_attentions)
if self.base_downsamplers is not None:
base_parts.append(self.base_downsamplers)
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
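`freeze_base_params` uses an unfreeze-everything-then-freeze-the-base-parts idiom, so any submodule not explicitly listed as a base part stays trainable by default. A hedged, stand-alone illustration of the same pattern on a toy module (the names `base_resnet`/`ctrl_resnet` are invented for this sketch, not the actual block):

```python
import torch.nn as nn

class ToyCrossBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_resnet = nn.Conv2d(4, 4, 3, padding=1)   # "pretrained" part
        self.ctrl_resnet = nn.Conv2d(8, 4, 3, padding=1)   # control part
        self.ctrl_to_base = nn.Conv2d(4, 4, 1)             # connection

    def freeze_base_params(self):
        # unfreeze everything first, then freeze only the base part,
        # so parts not explicitly listed remain trainable
        for p in self.parameters():
            p.requires_grad = True
        for p in self.base_resnet.parameters():
            p.requires_grad = False

block = ToyCrossBlock()
block.freeze_base_params()
trainable = [n for n, p in block.named_parameters() if p.requires_grad]
```

After the call, only the control part and the connection carry gradients, which is exactly what ControlNet-XS fine-tuning needs.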
def forward(
self,
hidden_states_base: FloatTensor,
temb: FloatTensor,
encoder_hidden_states: Optional[FloatTensor] = None,
hidden_states_ctrl: Optional[FloatTensor] = None,
conditioning_scale: Optional[float] = 1.0,
attention_mask: Optional[FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
apply_control: bool = True,
) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
h_base = hidden_states_base
h_ctrl = hidden_states_ctrl
base_output_states = ()
ctrl_output_states = ()
base_blocks = list(zip(self.base_resnets, self.base_attentions))
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
):
# concat base -> ctrl
if apply_control:
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
# apply base subblock
if self.training and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_base = torch.utils.checkpoint.checkpoint(
create_custom_forward(b_res),
h_base,
temb,
**ckpt_kwargs,
)
else:
h_base = b_res(h_base, temb)
if b_attn is not None:
h_base = b_attn(
h_base,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# apply ctrl subblock
if apply_control:
if self.training and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_ctrl = torch.utils.checkpoint.checkpoint(
create_custom_forward(c_res),
h_ctrl,
temb,
**ckpt_kwargs,
)
else:
h_ctrl = c_res(h_ctrl, temb)
if c_attn is not None:
h_ctrl = c_attn(
h_ctrl,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# add ctrl -> base
if apply_control:
h_base = h_base + c2b(h_ctrl) * conditioning_scale
base_output_states = base_output_states + (h_base,)
ctrl_output_states = ctrl_output_states + (h_ctrl,)
if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler
b2c = self.base_to_ctrl[-1]
c2b = self.ctrl_to_base[-1]
# concat base -> ctrl
if apply_control:
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
# apply base subblock
h_base = self.base_downsamplers(h_base)
# apply ctrl subblock
if apply_control:
h_ctrl = self.ctrl_downsamplers(h_ctrl)
# add ctrl -> base
if apply_control:
h_base = h_base + c2b(h_ctrl) * conditioning_scale
base_output_states = base_output_states + (h_base,)
ctrl_output_states = ctrl_output_states + (h_ctrl,)
return h_base, h_ctrl, base_output_states, ctrl_output_states
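Each iteration of the loop above performs the same three-step exchange: concat base→ctrl before the subblock, run both subblocks, then add ctrl→base through a zero-initialized 1×1 conv scaled by `conditioning_scale`. A self-contained sketch of that exchange (the channel sizes and plain-conv "subblocks" are stand-ins for the real ResNet/attention pairs):

```python
import torch
import torch.nn as nn

def make_zero_conv(in_ch, out_ch):
    conv = nn.Conv2d(in_ch, out_ch, 1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv

base_ch, ctrl_ch, scale = 8, 4, 0.5
b2c = make_zero_conv(base_ch, base_ch)   # concat keeps the channel count
c2b = make_zero_conv(ctrl_ch, base_ch)   # addition needs a channel change
base_sub = nn.Conv2d(base_ch, base_ch, 3, padding=1)            # stand-in subblock
ctrl_sub = nn.Conv2d(ctrl_ch + base_ch, ctrl_ch, 3, padding=1)  # ctrl sees the concat

h_base = torch.randn(1, base_ch, 16, 16)
h_ctrl = torch.randn(1, ctrl_ch, 16, 16)

h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)  # 1. concat base -> ctrl
h_base = base_sub(h_base)                         # 2. apply both subblocks
h_ctrl = ctrl_sub(h_ctrl)
delta = c2b(h_ctrl) * scale                       # 3. add ctrl -> base
h_base = h_base + delta
```

Because both connections are zero convolutions, `delta` is exactly zero at initialization, so the freshly wired-up model initially reproduces the frozen base UNet.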
class ControlNetXSCrossAttnMidBlock2D(nn.Module):
def __init__(
self,
base_channels: int,
ctrl_channels: int,
temb_channels: Optional[int] = None,
norm_num_groups: int = 32,
ctrl_max_norm_num_groups: int = 32,
transformer_layers_per_block: int = 1,
base_num_attention_heads: Optional[int] = 1,
ctrl_num_attention_heads: Optional[int] = 1,
cross_attention_dim: Optional[int] = 1024,
upcast_attention: bool = False,
):
super().__init__()
# Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels
self.base_to_ctrl = make_zero_conv(base_channels, base_channels)
self.base_midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=base_channels,
temb_channels=temb_channels,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=base_num_attention_heads,
use_linear_projection=True,
upcast_attention=upcast_attention,
)
self.ctrl_midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=ctrl_channels + base_channels,
out_channels=ctrl_channels,
temb_channels=temb_channels,
# number of norm groups must divide both in_channels and out_channels
resnet_groups=find_largest_factor(
gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
),
cross_attention_dim=cross_attention_dim,
num_attention_heads=ctrl_num_attention_heads,
use_linear_projection=True,
upcast_attention=upcast_attention,
)
# After the midblock application, information is added from control to base
# Addition requires change in number of channels
self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
self.gradient_checkpointing = False
@classmethod
def from_modules(
cls,
base_midblock: UNetMidBlock2DCrossAttn,
ctrl_midblock: MidBlockControlNetXSAdapter,
):
base_to_ctrl = ctrl_midblock.base_to_ctrl
ctrl_to_base = ctrl_midblock.ctrl_to_base
ctrl_midblock = ctrl_midblock.midblock
# get params
def get_first_cross_attention(midblock):
return midblock.attentions[0].transformer_blocks[0].attn2
base_channels = ctrl_to_base.out_channels
ctrl_channels = ctrl_to_base.in_channels
transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
num_groups = base_midblock.resnets[0].norm1.num_groups
ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
base_num_attention_heads = get_first_cross_attention(base_midblock).heads
ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
# create model
model = cls(
base_channels=base_channels,
ctrl_channels=ctrl_channels,
temb_channels=temb_channels,
norm_num_groups=num_groups,
ctrl_max_norm_num_groups=ctrl_num_groups,
transformer_layers_per_block=transformer_layers_per_block,
base_num_attention_heads=base_num_attention_heads,
ctrl_num_attention_heads=ctrl_num_attention_heads,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
# load weights
model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
model.base_midblock.load_state_dict(base_midblock.state_dict())
model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())
return model
def freeze_base_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze base part
for param in self.base_midblock.parameters():
param.requires_grad = False
def forward(
self,
hidden_states_base: FloatTensor,
temb: FloatTensor,
encoder_hidden_states: FloatTensor,
hidden_states_ctrl: Optional[FloatTensor] = None,
conditioning_scale: Optional[float] = 1.0,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[FloatTensor] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
apply_control: bool = True,
) -> Tuple[FloatTensor, FloatTensor]:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
h_base = hidden_states_base
h_ctrl = hidden_states_ctrl
joint_args = {
"temb": temb,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"cross_attention_kwargs": cross_attention_kwargs,
"encoder_attention_mask": encoder_attention_mask,
}
if apply_control:
h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl
h_base = self.base_midblock(h_base, **joint_args) # apply base mid block
if apply_control:
h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block
h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base
return h_base, h_ctrl
class ControlNetXSCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
ctrl_skip_channels: List[int],
temb_channels: int,
norm_num_groups: int = 32,
resolution_idx: Optional[int] = None,
has_crossattn=True,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 1,
cross_attention_dim: int = 1024,
add_upsample: bool = True,
upcast_attention: bool = False,
):
super().__init__()
resnets = []
attentions = []
ctrl_to_base = []
num_layers = 3 # only support sd + sdxl
self.has_cross_attention = has_crossattn
self.num_attention_heads = num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
groups=norm_num_groups,
)
)
if has_crossattn:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
use_linear_projection=True,
upcast_attention=upcast_attention,
norm_num_groups=norm_num_groups,
)
)
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
if add_upsample:
self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
else:
self.upsamplers = None
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
@classmethod
def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
# get params
def get_first_cross_attention(block):
return block.attentions[0].transformer_blocks[0].attn2
out_channels = base_upblock.resnets[0].out_channels
in_channels = base_upblock.resnets[-1].in_channels - out_channels
prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections]
temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
num_groups = base_upblock.resnets[0].norm1.num_groups
resolution_idx = base_upblock.resolution_idx
if hasattr(base_upblock, "attentions"):
has_crossattn = True
transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
num_attention_heads = get_first_cross_attention(base_upblock).heads
cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
else:
has_crossattn = False
transformer_layers_per_block = None
num_attention_heads = None
cross_attention_dim = None
upcast_attention = None
add_upsample = base_upblock.upsamplers is not None
# create model
model = cls(
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channels,
ctrl_skip_channels=ctrl_skip_channelss,
temb_channels=temb_channels,
norm_num_groups=num_groups,
resolution_idx=resolution_idx,
has_crossattn=has_crossattn,
transformer_layers_per_block=transformer_layers_per_block,
num_attention_heads=num_attention_heads,
cross_attention_dim=cross_attention_dim,
add_upsample=add_upsample,
upcast_attention=upcast_attention,
)
# load weights
model.resnets.load_state_dict(base_upblock.resnets.state_dict())
if has_crossattn:
model.attentions.load_state_dict(base_upblock.attentions.state_dict())
if add_upsample:
model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())
return model
def freeze_base_params(self) -> None:
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
tuning."""
# Unfreeze everything
for param in self.parameters():
param.requires_grad = True
# Freeze base part
base_parts = [self.resnets]
if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones
base_parts.append(self.attentions)
if self.upsamplers is not None:
base_parts.append(self.upsamplers)
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
def forward(
self,
hidden_states: FloatTensor,
res_hidden_states_tuple_base: Tuple[FloatTensor, ...],
res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...],
temb: FloatTensor,
encoder_hidden_states: Optional[FloatTensor] = None,
conditioning_scale: Optional[float] = 1.0,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[FloatTensor] = None,
upsample_size: Optional[int] = None,
encoder_attention_mask: Optional[FloatTensor] = None,
apply_control: bool = True,
) -> FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
# FreeU: Only operate on the first two stages
if is_freeu_enabled:
return apply_freeu(
self.resolution_idx,
hidden_states,
res_h_base,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
else:
return hidden_states, res_h_base
for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
self.resnets,
self.attentions,
self.ctrl_to_base,
reversed(res_hidden_states_tuple_base),
reversed(res_hidden_states_tuple_ctrl),
):
if apply_control:
hidden_states += c2b(res_h_ctrl) * conditioning_scale
hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
if self.training and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
if attn is not None:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.upsamplers is not None:
hidden_states = self.upsamplers(hidden_states, upsample_size)
return hidden_states
def make_zero_conv(in_channels, out_channels=None):
return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
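`zero_module` implements the zero-convolution trick from ControlNet: at initialization the layer outputs exactly zero, so attaching it cannot disturb the pretrained model, yet the gradient of the output with respect to the weights depends on the input rather than the weights, so the connection can still learn to open up. A quick check of both properties:

```python
import torch
import torch.nn as nn

def zero_module(module):
    # same as above: zero out every parameter of the module
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

conv = zero_module(nn.Conv2d(3, 5, kernel_size=1))
x = torch.randn(2, 3, 8, 8, requires_grad=True)
out = conv(x)

# output is exactly zero at init ...
out_max = float(out.abs().max())

# ... but gradients w.r.t. the zero conv's weights are not zero,
# because d(out)/d(weight) depends on the input, not the weight
out.sum().backward()
grad_max = float(conv.weight.grad.abs().max())
```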
def find_largest_factor(number, max_factor):
factor = max_factor
if factor >= number:
return number
while factor != 0:
residual = number % factor
if residual == 0:
return factor
factor -= 1
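`find_largest_factor` exists because GroupNorm requires the group count to divide the channel count: the control branch can end up with channel widths that the default of 32 doesn't divide (e.g. after concatenating base channels), so the code searches downward for the largest divisor not exceeding the requested maximum. The mid block additionally takes a `gcd` first so that a single group count divides both its in and out channel counts. A behavior sketch (the function body is copied from above; the channel numbers are illustrative):

```python
from math import gcd

def find_largest_factor(number, max_factor):
    factor = max_factor
    if factor >= number:
        return number
    while factor != 0:
        residual = number % factor
        if residual == 0:
            return factor
        factor -= 1

groups_even = find_largest_factor(640, max_factor=32)  # 32 already divides 640
groups_odd = find_largest_factor(33, max_factor=32)    # falls back to 11
groups_small = find_largest_factor(7, max_factor=32)   # channel count itself
# mid-block style usage: one group count valid for both in and out channels
groups_mid = find_largest_factor(gcd(320, 320 + 64), max_factor=32)
```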
@@ -746,6 +746,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self,
in_channels: int,
temb_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +754,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
@@ -764,6 +766,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +778,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
resnet_groups_out = resnet_groups_out or resnet_groups
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
- out_channels=in_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
@@ -794,11 +803,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions.append(
Transformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
+ norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
@@ -808,8 +817,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions.append(
DualTransformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -817,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
)
resnets.append(
ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
+ in_channels=out_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
- groups=resnet_groups,
+ groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
@@ -134,6 +134,12 @@ else:
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["controlnet_xs"].extend(
[
"StableDiffusionControlNetXSPipeline",
"StableDiffusionXLControlNetXSPipeline",
]
)
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -378,6 +384,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .controlnet_xs import (
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
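The `__init__.py` above uses diffusers' `_LazyModule`: importing the package only records an import structure, and the module object placed into `sys.modules` loads each pipeline class on first attribute access. A simplified, hypothetical re-implementation of that pattern using only the standard library (`_LazyModule` itself carries more machinery, e.g. dummy-object handling):

```python
import importlib
import types

class LazyModule(types.ModuleType):
    """Toy version of the lazy-import pattern: resolve symbols on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map exported symbol -> submodule that actually defines it
        self._symbol_to_module = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, name):
        # only called when normal attribute lookup fails
        submodule = self._symbol_to_module.get(name)
        if submodule is None:
            raise AttributeError(name)
        value = getattr(importlib.import_module(submodule), name)
        setattr(self, name, value)  # cache so the next access is direct
        return value

# stand-in import structure; diffusers maps e.g. "pipeline_controlnet_xs"
# to ["StableDiffusionControlNetXSPipeline"]
lazy = LazyModule("fake_pkg", {"json": ["dumps"], "math": ["sqrt"]})
```

Nothing under `fake_pkg` is imported until `lazy.sqrt` or `lazy.dumps` is first touched, which keeps top-level `import` cheap.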
@@ -19,30 +19,75 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
- from controlnetxs import ControlNetXSModel
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
- from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
- from diffusers.models.lora import adjust_lora_scale_text_encoder
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
- from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
- from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.utils import (
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
+ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+ from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+ from ...models.lora import adjust_lora_scale_text_encoder
+ from ...schedulers import KarrasDiffusionSchedulers
+ from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
- from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+ from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
>>> import cv2
>>> from PIL import Image
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
>>> negative_prompt = "low quality, bad quality, sketches"
>>> # download an image
>>> image = load_image(
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
... )
>>> # initialize the models and pipeline
>>> controlnet_conditioning_scale = 0.5
>>> controlnet = ControlNetXSAdapter.from_pretrained(
... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.enable_model_cpu_offload()
>>> # get canny image
>>> image = np.array(image)
>>> image = cv2.Canny(image, 100, 200)
>>> image = image[:, :, None]
>>> image = np.concatenate([image, image, image], axis=2)
>>> canny_image = Image.fromarray(image)
>>> # generate image
>>> image = pipe(
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
... ).images[0]
```
"""
class StableDiffusionControlNetXSPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
......@@ -56,7 +101,7 @@ class StableDiffusionControlNetXSPipeline(
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
......@@ -66,9 +111,9 @@ class StableDiffusionControlNetXSPipeline(
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
controlnet ([`ControlNetXSAdapter`]):
A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
......@@ -80,17 +125,18 @@ class StableDiffusionControlNetXSPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
controlnet: ControlNetXSAdapter,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
......@@ -98,6 +144,9 @@ class StableDiffusionControlNetXSPipeline(
):
super().__init__()
if isinstance(unet, UNet2DConditionModel):
unet = UNetControlNetXSModel.from_unet(unet, controlnet)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
......@@ -114,14 +163,6 @@ class StableDiffusionControlNetXSPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
......@@ -403,20 +444,19 @@ class StableDiffusionControlNetXSPipeline(
self,
prompt,
image,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
......@@ -445,25 +485,16 @@ class StableDiffusionControlNetXSPipeline(
f" {negative_prompt_embeds.shape}."
)
# Check `image` and `controlnet_conditioning_scale`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.unet, torch._dynamo.eval_frame.OptimizedModule
)
if (
isinstance(self.unet, UNetControlNetXSModel)
or is_compiled
and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
......@@ -563,7 +594,33 @@ class StableDiffusionControlNetXSPipeline(
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
def guidance_scale(self):
return self._guidance_scale
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
def clip_skip(self):
return self._clip_skip
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
......@@ -581,13 +638,13 @@ class StableDiffusionControlNetXSPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
control_guidance_end: float = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
The call function to the pipeline for generation.
......@@ -595,7 +652,7 @@ class StableDiffusionControlNetXSPipeline(
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
......@@ -639,12 +696,6 @@ class StableDiffusionControlNetXSPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
......@@ -659,7 +710,15 @@ class StableDiffusionControlNetXSPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
......@@ -669,21 +728,27 @@ class StableDiffusionControlNetXSPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......@@ -713,6 +778,7 @@ class StableDiffusionControlNetXSPipeline(
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
......@@ -720,27 +786,24 @@ class StableDiffusionControlNetXSPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=unet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
height, width = image.shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......@@ -757,42 +820,33 @@ class StableDiffusionControlNetXSPipeline(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
is_controlnet_compiled = is_compiled_module(self.unet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
apply_control = (
i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
)
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=True,
apply_control=apply_control,
).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -801,12 +855,18 @@ class StableDiffusionControlNetXSPipeline(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
......
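The `apply_control` window computed in the denoising loop above is easy to illustrate in isolation. The following is a hypothetical standalone sketch (the function name and the plain-Python list types are assumptions for illustration, not diffusers API): control guidance is applied only while the normalized step position falls inside `[control_guidance_start, control_guidance_end]`.

```python
def apply_control_schedule(num_steps, start=0.0, end=1.0):
    """Return one boolean per denoising step, mirroring the pipeline's
    per-step check: i / num_steps >= start and (i + 1) / num_steps <= end."""
    flags = []
    for i in range(num_steps):
        apply_control = (i / num_steps >= start) and ((i + 1) / num_steps <= end)
        flags.append(apply_control)
    return flags

# With the defaults (start=0.0, end=1.0) every step is controlled.
assert all(apply_control_schedule(10))
# With end=0.5 only the first half of the steps is controlled.
print(apply_control_schedule(4, start=0.0, end=0.5))  # → [True, True, False, False]
```

Note that with `apply_control` the conditioning is toggled inside a single `UNetControlNetXSModel` forward call, whereas the old code chose between two entirely different modules (`self.unet` vs. `self.controlnet`).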
......@@ -19,41 +19,94 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
)
from diffusers.utils.import_utils import is_invisible_watermark_available
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
>>> import cv2
>>> from PIL import Image
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
>>> negative_prompt = "low quality, bad quality, sketches"
>>> # download an image
>>> image = load_image(
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
... )
>>> # initialize the models and pipeline
>>> controlnet_conditioning_scale = 0.5
>>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
>>> controlnet = ControlNetXSAdapter.from_pretrained(
... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.enable_model_cpu_offload()
>>> # get canny image
>>> image = np.array(image)
>>> image = cv2.Canny(image, 100, 200)
>>> image = image[:, :, None]
>>> image = np.concatenate([image, image, image], axis=2)
>>> canny_image = Image.fromarray(image)
>>> # generate image
>>> image = pipe(
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
... ).images[0]
```
"""
class StableDiffusionXLControlNetXSPipeline(
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
FromSingleFileMixin,
......@@ -66,9 +119,8 @@ class StableDiffusionXLControlNetXSPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
......@@ -83,9 +135,9 @@ class StableDiffusionXLControlNetXSPipeline(
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
controlnet ([`ControlNetXSAdapter`]):
A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
......@@ -98,9 +150,15 @@ class StableDiffusionXLControlNetXSPipeline(
watermarker is used.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"feature_extractor",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
......@@ -109,21 +167,17 @@ class StableDiffusionXLControlNetXSPipeline(
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
controlnet: ControlNetXSAdapter,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
if isinstance(unet, UNet2DConditionModel):
unet = UNetControlNetXSModel.from_unet(unet, controlnet)
self.register_modules(
vae=vae,
......@@ -134,6 +188,7 @@ class StableDiffusionXLControlNetXSPipeline(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
......@@ -417,15 +472,21 @@ class StableDiffusionXLControlNetXSPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
......@@ -474,25 +535,16 @@ class StableDiffusionXLControlNetXSPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
# Check `image` and `controlnet_conditioning_scale`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.unet, torch._dynamo.eval_frame.OptimizedModule
)
if (
isinstance(self.unet, UNetControlNetXSModel)
or is_compiled
and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
......@@ -593,7 +645,6 @@ class StableDiffusionXLControlNetXSPipeline(
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
......@@ -602,7 +653,7 @@ class StableDiffusionXLControlNetXSPipeline(
passed_add_embed_dim = (
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
......@@ -632,7 +683,33 @@ class StableDiffusionXLControlNetXSPipeline(
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
def guidance_scale(self):
return self._guidance_scale
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
def clip_skip(self):
return self._clip_skip
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
......@@ -654,8 +731,6 @@ class StableDiffusionXLControlNetXSPipeline(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
......@@ -667,6 +742,9 @@ class StableDiffusionXLControlNetXSPipeline(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
......@@ -677,7 +755,7 @@ class StableDiffusionXLControlNetXSPipeline(
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
......@@ -735,12 +813,6 @@ class StableDiffusionXLControlNetXSPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
......@@ -783,6 +855,15 @@ class StableDiffusionXLControlNetXSPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
......@@ -791,7 +872,24 @@ class StableDiffusionXLControlNetXSPipeline(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is
returned, otherwise a `tuple` is returned containing the output images.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
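The block above shows the backwards-compatibility shim: legacy `callback`/`callback_steps` arguments are popped from `**kwargs` and routed through `deprecate` instead of being hard errors. A minimal pure-Python sketch of the same pattern, using the stdlib `warnings` module in place of the diffusers `deprecate` helper (the `call` function is hypothetical):

```python
import warnings

def call(*, prompt, **kwargs):
    # Pop legacy arguments so they never reach the new code path.
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)
    if callback is not None or callback_steps is not None:
        warnings.warn(
            "`callback`/`callback_steps` are deprecated; use `callback_on_step_end`",
            FutureWarning,
        )
    return prompt, callback, callback_steps

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = call(prompt="hi", callback=print, callback_steps=2)
assert result[2] == 2 and len(caught) == 1  # one deprecation warning emitted
```

Old call sites keep working for now, while new call sites pay no cost: passing neither legacy argument emits no warning.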
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
......@@ -808,8 +906,14 @@ class StableDiffusionXLControlNetXSPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......@@ -850,7 +954,7 @@ class StableDiffusionXLControlNetXSPipeline(
)
# 4. Prepare image
if isinstance(controlnet, ControlNetXSModel):
if isinstance(unet, UNetControlNetXSModel):
image = self.prepare_image(
image=image,
width=width,
......@@ -858,7 +962,7 @@ class StableDiffusionXLControlNetXSPipeline(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
dtype=unet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
height, width = image.shape[-2:]
......@@ -870,7 +974,7 @@ class StableDiffusionXLControlNetXSPipeline(
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
......@@ -928,14 +1032,14 @@ class StableDiffusionXLControlNetXSPipeline(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
is_unet_compiled = is_compiled_module(self.unet)
is_controlnet_compiled = is_compiled_module(self.controlnet)
self._num_timesteps = len(timesteps)
is_controlnet_compiled = is_compiled_module(self.unet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
if is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
......@@ -944,30 +1048,20 @@ class StableDiffusionXLControlNetXSPipeline(
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# predict the noise residual
dont_control = (
i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
apply_control = (
i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
)
if dont_control:
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=True,
).sample
else:
noise_pred = self.controlnet(
base_model=self.unet,
sample=latent_model_input,
timestep=t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=True,
).sample
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
conditioning_scale=controlnet_conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=True,
apply_control=apply_control,
).sample
# perform guidance
if do_classifier_free_guidance:
......@@ -977,6 +1071,16 @@ class StableDiffusionXLControlNetXSPipeline(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
......@@ -984,6 +1088,11 @@ class StableDiffusionXLControlNetXSPipeline(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
......
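The refactor in this hunk folds the old `dont_control` branch into a single `apply_control` flag passed to the unified `UNetControlNetXSModel`. The gating window can be sketched in isolation (pure Python; `should_apply_control` is a hypothetical helper name, not part of the diffusers API):

```python
def should_apply_control(i: int, num_timesteps: int, start: float, end: float) -> bool:
    """Mirror of the apply_control condition in the denoising loop above:
    control is applied only while step i falls inside the [start, end]
    fraction of the total timesteps."""
    return i / num_timesteps >= start and (i + 1) / num_timesteps <= end

# With the defaults (start=0.0, end=1.0) control is applied at every step;
# with start=0.2, end=0.8 the first and last 20% of steps run without control.
full = [should_apply_control(i, 10, 0.0, 1.0) for i in range(10)]
window = [should_apply_control(i, 10, 0.2, 0.8) for i in range(10)]
```

This also explains why `control_guidance_start`/`control_guidance_end` no longer need a separate plain-UNet code path: the merged model simply skips its control branch when the flag is `False`.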
......@@ -2238,6 +2238,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self,
in_channels: int,
temb_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
......@@ -2245,6 +2246,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
......@@ -2256,6 +2258,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
......@@ -2264,14 +2270,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
resnet_groups_out = resnet_groups_out or resnet_groups
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
......@@ -2286,11 +2295,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
......@@ -2300,8 +2309,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
......@@ -2309,11 +2318,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
)
resnets.append(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
......
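The new `out_channels` and `resnet_groups_out` arguments above default to their input-side counterparts, and the per-head attention dimension is now derived from `out_channels` instead of `in_channels`. A minimal sketch of that resolution logic (hypothetical helper, not a diffusers function):

```python
def resolve_mid_block_dims(in_channels, num_attention_heads, out_channels=None,
                           resnet_groups=32, resnet_groups_out=None):
    # Same fallbacks as in UNetMidBlockFlatCrossAttn.__init__ above.
    out_channels = out_channels or in_channels
    resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
    resnet_groups_out = resnet_groups_out or resnet_groups
    # The transformer head dim now follows the output channel count.
    attention_head_dim = out_channels // num_attention_heads
    return out_channels, resnet_groups, resnet_groups_out, attention_head_dim
```

When the new arguments are omitted, everything collapses to the old behavior, so existing configs keep working.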
......@@ -92,6 +92,21 @@ class ControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ControlNetXSAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
......@@ -287,6 +302,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UNetControlNetXSModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNetMotionModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -902,6 +902,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......@@ -1247,6 +1262,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
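The stub classes above follow the diffusers dummy-object pattern: when a required backend is missing, touching the class raises a clear error immediately instead of a late, obscure `ImportError`. A simplified sketch of the mechanism (names and error text are illustrative, not the actual `DummyObject` implementation):

```python
class _DummyObject(type):
    """Metaclass that blocks instantiation when the listed backends are absent."""
    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires the following backends: {cls._backends}"
        )

class StubPipeline(metaclass=_DummyObject):
    # Each stub declares which backends it needs; the real dummies also
    # guard from_config / from_pretrained the same way.
    _backends = ["torch", "transformers"]
```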
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
import torch
from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetControlNetXSModel
main_input_name = "sample"
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
conditioning_scale = 1
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"controlnet_cond": controlnet_cond,
"conditioning_scale": conditioning_scale,
}
@property
def input_shape(self):
return (4, 16, 16)
@property
def output_shape(self):
return (4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 16,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
"block_out_channels": (4, 8),
"cross_attention_dim": 8,
"transformer_layers_per_block": 1,
"num_attention_heads": 2,
"norm_num_groups": 4,
"upcast_attention": False,
"ctrl_block_out_channels": [2, 4],
"ctrl_num_attention_heads": 4,
"ctrl_max_norm_num_groups": 2,
"ctrl_conditioning_embedding_out_channels": (2, 2),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_unet(self):
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
return UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=8,
norm_num_groups=4,
use_linear_projection=True,
)
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
def test_from_unet(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
model = UNetControlNetXSModel.from_unet(unet, controlnet)
model_state_dict = model.state_dict()
def assert_equal_weights(module, weight_dict_prefix):
for param_name, param_value in module.named_parameters():
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
# # check unet
# everything except down, mid, and up blocks
modules_from_unet = [
"time_embedding",
"conv_in",
"conv_norm_out",
"conv_out",
]
for p in modules_from_unet:
assert_equal_weights(getattr(unet, p), "base_" + p)
optional_modules_from_unet = [
"class_embedding",
"add_time_proj",
"add_embedding",
]
for p in optional_modules_from_unet:
if hasattr(unet, p) and getattr(unet, p) is not None:
assert_equal_weights(getattr(unet, p), "base_" + p)
# down blocks
assert len(unet.down_blocks) == len(model.down_blocks)
for i, d in enumerate(unet.down_blocks):
assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
if hasattr(d, "attentions"):
assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None:
assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")
# mid block
assert_equal_weights(unet.mid_block, "mid_block.base_midblock")
# up blocks
assert len(unet.up_blocks) == len(model.up_blocks)
for i, u in enumerate(unet.up_blocks):
assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
if hasattr(u, "attentions"):
assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None:
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
# # check controlnet
# everything except down, mid, and up blocks
modules_from_controlnet = {
"controlnet_cond_embedding": "controlnet_cond_embedding",
"conv_in": "ctrl_conv_in",
"control_to_base_for_conv_in": "control_to_base_for_conv_in",
}
optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None:
assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
# down blocks
assert len(controlnet.down_blocks) == len(model.down_blocks)
for i, d in enumerate(controlnet.down_blocks):
assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
if d.attentions is not None:
assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
if d.downsamplers is not None:
assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")
# mid block
assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")
# up blocks
assert len(controlnet.up_connections) == len(model.up_blocks)
for i, u in enumerate(controlnet.up_connections):
assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")
def test_freeze_unet(self):
def assert_frozen(module):
for p in module.parameters():
assert not p.requires_grad
def assert_unfrozen(module):
for p in module.parameters():
assert p.requires_grad
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = UNetControlNetXSModel(**init_dict)
model.freeze_unet_params()
# # check unet
# everything except down, mid, and up blocks
modules_from_unet = [
model.base_time_embedding,
model.base_conv_in,
model.base_conv_norm_out,
model.base_conv_out,
]
for m in modules_from_unet:
assert_frozen(m)
optional_modules_from_unet = [
model.base_add_time_proj,
model.base_add_embedding,
]
for m in optional_modules_from_unet:
if m is not None:
assert_frozen(m)
# down blocks
for i, d in enumerate(model.down_blocks):
assert_frozen(d.base_resnets)
if isinstance(d.base_attentions, nn.ModuleList): # attentions can be list of Nones
assert_frozen(d.base_attentions)
if d.base_downsamplers is not None:
assert_frozen(d.base_downsamplers)
# mid block
assert_frozen(model.mid_block.base_midblock)
# up blocks
for i, u in enumerate(model.up_blocks):
assert_frozen(u.resnets)
if isinstance(u.attentions, nn.ModuleList): # attentions can be list of Nones
assert_frozen(u.attentions)
if u.upsamplers is not None:
assert_frozen(u.upsamplers)
# # check controlnet
# everything except down, mid, and up blocks
modules_from_controlnet = [
model.controlnet_cond_embedding,
model.ctrl_conv_in,
model.control_to_base_for_conv_in,
]
optional_modules_from_controlnet = [model.ctrl_time_embedding]
for m in modules_from_controlnet:
assert_unfrozen(m)
for m in optional_modules_from_controlnet:
if m is not None:
assert_unfrozen(m)
# down blocks
for d in model.down_blocks:
assert_unfrozen(d.ctrl_resnets)
assert_unfrozen(d.base_to_ctrl)
assert_unfrozen(d.ctrl_to_base)
if isinstance(d.ctrl_attentions, nn.ModuleList): # attentions can be list of Nones
assert_unfrozen(d.ctrl_attentions)
if d.ctrl_downsamplers is not None:
assert_unfrozen(d.ctrl_downsamplers)
# mid block
assert_unfrozen(model.mid_block.base_to_ctrl)
assert_unfrozen(model.mid_block.ctrl_midblock)
assert_unfrozen(model.mid_block.ctrl_to_base)
# up blocks
for u in model.up_blocks:
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
model_class_copy = copy.copy(UNetControlNetXSModel)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_forward_no_control(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
model = UNetControlNetXSModel.from_unet(unet, controlnet)
unet = unet.to(torch_device)
model = model.to(torch_device)
input_ = self.dummy_input
control_specific_input = ["controlnet_cond", "conditioning_scale"]
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
with torch.no_grad():
unet_output = unet(**input_for_unet).sample.cpu()
unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()
assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4
def test_time_embedding_mixing(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
controlnet_mix_time = self.get_dummy_controlnet_from_unet(
unet, time_embedding_mix=0.5, learn_time_embedding=True
)
model = UNetControlNetXSModel.from_unet(unet, controlnet)
model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)
unet = unet.to(torch_device)
model = model.to(torch_device)
model_mix_time = model_mix_time.to(torch_device)
input_ = self.dummy_input
with torch.no_grad():
output = model(**input_).sample
output_mix_time = model_mix_time(**input_).sample
assert output.shape == output_mix_time.shape
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass
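`test_from_unet` above checks that every base UNet weight reappears in the merged model under a `base_`-prefixed name (at the top level and inside each block), while adapter weights get `ctrl_` names. The base-weight part of that scheme can be sketched roughly as follows (hypothetical helper; the real mapping also covers downsamplers and the adapter modules):

```python
def unet_to_merged_key(name: str) -> str:
    """Rough sketch of where a UNet2DConditionModel parameter is expected to
    live inside a UNetControlNetXSModel state dict."""
    if name.startswith("mid_block."):
        # mid_block.* -> mid_block.base_midblock.*
        return name.replace("mid_block.", "mid_block.base_midblock.", 1)
    if name.startswith("down_blocks."):
        # down_blocks.{i}.resnets.* -> down_blocks.{i}.base_resnets.*
        blocks, idx, rest = name.split(".", 2)
        return f"{blocks}.{idx}.base_{rest}"
    if name.startswith("up_blocks."):
        # up-block weights keep their names; only the adapter adds ctrl_to_base
        return name
    # top-level modules such as conv_in or time_embedding get a plain prefix
    return f"base_{name}"
```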
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import traceback
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetXSAdapter,
DDIMScheduler,
LCMScheduler,
StableDiffusionControlNetXSPipeline,
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
load_numpy,
require_python39_or_higher,
require_torch_2,
require_torch_gpu,
run_test_in_subprocess,
slow,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ...models.autoencoders.test_models_vae import (
get_asym_autoencoder_kl_config,
get_autoencoder_kl_config,
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
)
enable_full_determinism()
# Will be run via run_test_in_subprocess
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
error = None
try:
_ = in_queue.get(timeout=timeout)
controlnet = ControlNetXSAdapter.from_pretrained(
"UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base",
controlnet=controlnet,
safety_checker=None,
torch_dtype=torch.float16,
)
pipe.to("cuda")
pipe.set_progress_bar_config(disable=None)
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
image = output.images[0]
assert image.shape == (512, 512, 3)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
)
expected_image = np.resize(expected_image, (512, 512, 3))
assert np.abs(expected_image - image).max() < 1.0
except Exception:
error = f"{traceback.format_exc()}"
results = {"error": error}
out_queue.put(results, timeout=timeout)
out_queue.join()
class ControlNetXSPipelineFastTests(
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionControlNetXSPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_attention_slicing = False
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=8,
norm_num_groups=4,
time_cond_proj_dim=time_cond_proj_dim,
use_linear_projection=True,
)
torch.manual_seed(0)
controlnet = ControlNetXSAdapter.from_unet(
unet=unet,
size_ratio=1,
learn_time_embedding=True,
conditioning_embedding_out_channels=(2, 2),
)
torch.manual_seed(0)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"controlnet": controlnet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
controlnet_embedder_scale_factor = 2
image = randn_tensor(
(1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
generator=generator,
device=torch.device(device),
)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
"image": image,
}
return inputs
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_controlnet_lcm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=8)
sd_pipe = StableDiffusionControlNetXSPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 16, 16, 3)
expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_multi_vae(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
block_out_channels = pipe.vae.config.block_out_channels
norm_num_groups = pipe.vae.config.norm_num_groups
vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
configs = [
get_autoencoder_kl_config(block_out_channels, norm_num_groups),
get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
get_consistency_vae_config(block_out_channels, norm_num_groups),
get_autoencoder_tiny_config(block_out_channels),
]
out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
for vae_cls, config in zip(vae_classes, configs):
vae = vae_cls(**config)
vae = vae.to(torch_device)
components["vae"] = vae
vae_pipe = self.pipeline_class(**components)
# the pipeline creates a new UNetControlNetXSModel under the hood, which isn't on the device yet.
# So we need to move the new pipeline to the device.
vae_pipe.to(torch_device)
vae_pipe.set_progress_bar_config(disable=None)
out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
assert out_vae_np.shape == out_np.shape
@slow
@require_torch_gpu
class ControlNetXSPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_canny(self):
controlnet = ControlNetXSAdapter.from_pretrained(
"UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
image = output.images[0]
assert image.shape == (768, 512, 3)
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
assert np.allclose(original_image, expected_image, atol=1e-04)
def test_depth(self):
controlnet = ControlNetXSAdapter.from_pretrained(
"UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "Stormtrooper's lecture"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
)
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
image = output.images[0]
assert image.shape == (512, 512, 3)
original_image = image[-3:, -3:, -1].flatten()
expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
assert np.allclose(original_image, expected_image, atol=1e-04)
@require_python39_or_higher
@require_torch_2
def test_stable_diffusion_compile(self):
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)