Unverified Commit fda1531d authored by UmerHA's avatar UmerHA Committed by GitHub
Browse files

Fixing implementation of ControlNet-XS (#6772)



* CheckIn - created DownSubBlocks

* Added extra channels, implemented subblock fwd

* Fixed connection sizes

* checkin

* Removed iter, next in forward

* Models for SD21 & SDXL run through

* Added back pipelines, cleared up connections

* Cleaned up connection creation

* added debug logs

* updated logs

* logs: added input loading

* Update umer_debug_logger.py

* log: Loading hint

* Update umer_debug_logger.py

* added logs

* Changed debug logging

* debug: added more logs

* Fixed num_norm_groups

* Debug: Logging all of SDXL input

* Update umer_debug_logger.py

* debug: updated logs

* checkim

* Readded tests

* Removed debug logs

* Fixed Slow Tests

* Added value ckecks | Updated model_cpu_offload_seq

* accelerate-offloading works ; fast tests work

* Made unet & addon explicit in controlnet

* Updated slow tests

* Added dtype/device to ControlNetXS

* Filled in test model paths

* Added image_encoder/feature_extractor to XL pipe

* Fixed fast tests

* Added comments and docstrings

* Fixed copies

* Added docs ; Updates slow tests

* Moved changes to UNetMidBlock2DCrossAttn

* tiny cleanups

* Removed stray prints

* Removed ip adapters + freeU

- Removed ip adapters + freeU as they don't make sense for ControlNet-XS
- Fixed imports of UNet components

* Fixed test_save_load_float16

* Make style, quality, fix-copies

* Changed loading/saving API for ControlNetXS

- Changed loading/saving API for ControlNetXS
- other small fixes

* Removed ControlNet-XS from research examples

* Make style, quality, fix-copies

* Small fixes

- deleted ControlNetXSModel.init_original
- added time_embedding_mix to StableDiffusionControlNetXSPipeline .from_pretrained / StableDiffusionXLControlNetXSPipeline.from_pretrained
- fixed copy hints

* checkin May 11 '23

* CheckIn Mar 12 '24

* Fixed tests for SD

* Added tests for UNetControlNetXSModel

* Fixed SDXL tests

* cleanup

* Delete Pipfile

* CheckIn Mar 20

Started replacing sub blocks  by `ControlNetXSCrossAttnDownBlock2D` and `ControlNetXSCrossAttnUplock2D`

* check-in Mar 23

* checkin 24 Mar

* Created init for UNetCnxs and CnxsAddon

* CheckIn

* Made from_modules, from_unet and no_control work

* make style,quality,fix-copies & small changes

* Fixed freezing

* Added gradient ckpt'ing; fixed tests

* Fix slow tests(+compile) ; clear naming confusion

* Don't create UNet in init ; removed class_emb

* Incorporated review feedback

- Deleted get_base_pipeline /  get_controlnet_addon for pipes
- Pipes inherit from StableDiffusionXLPipeline
- Made module dicts for cnxs-addon's down/mid/up classes
- Added support for qkv fusion and freeU

* Make style, quality, fix-copies

* Implemented review feedback

* Removed compatibility check for vae/ctrl embedding

* make style, quality, fix-copies

* Delete Pipfile

* Integrated review feedback

- Importing ControlNetConditioningEmbedding now
- get_down/mid/up_block_addon now outside class
- renamed `do_control` to `apply_control`

* Reduced size of test tensors

For this, added `norm_num_groups` as parameter everywhere

* Renamed cnxs-`Addon` to cnxs-`Adapter`

- `ControlNetXSAddon` -> `ControlNetXSAdapter`
- `ControlNetXSAddonDownBlockComponents` -> `DownBlockControlNetXSAdapter`, and similarly for mid/up
- `get_mid_block_addon` -> `get_mid_block_adapter`, and similarly for mid/up

* Fixed save_pretrained/from_pretrained bug

* Removed redundant code

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent cf6e0407
......@@ -282,6 +282,10 @@
title: ControlNet
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnetxs
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet-XS
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
......@@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionControlNetXSPipeline
[[autodoc]] StableDiffusionControlNetXSPipeline
- all
- __call__
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
\ No newline at end of file
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
......@@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
\ No newline at end of file
<Tip warning={true}>
🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
</Tip>
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## StableDiffusionXLControlNetXSPipeline
[[autodoc]] StableDiffusionXLControlNetXSPipeline
- all
- __call__
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
# !pip install opencv-python transformers accelerate
import argparse
import cv2
import numpy as np
import torch
from controlnetxs import ControlNetXSModel
from PIL import Image
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
from diffusers.utils import load_image
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
)
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
parser.add_argument(
"--image_path",
type=str,
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
)
parser.add_argument("--num_inference_steps", type=int, default=50)
args = parser.parse_args()
prompt = args.prompt
negative_prompt = args.negative_prompt
# download an image
image = load_image(args.image_path)
# initialize the models and pipeline
controlnet_conditioning_scale = args.controlnet_conditioning_scale
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# get canny image
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
num_inference_steps = args.num_inference_steps
# generate image
image = pipe(
prompt,
controlnet_conditioning_scale=controlnet_conditioning_scale,
image=canny_image,
num_inference_steps=num_inference_steps,
).images[0]
image.save("cnxs_sd.canny.png")
# !pip install opencv-python transformers accelerate
import argparse
import cv2
import numpy as np
import torch
from controlnetxs import ControlNetXSModel
from PIL import Image
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
from diffusers.utils import load_image
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
)
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
parser.add_argument(
"--image_path",
type=str,
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
)
parser.add_argument("--num_inference_steps", type=int, default=50)
args = parser.parse_args()
prompt = args.prompt
negative_prompt = args.negative_prompt
# download an image
image = load_image(args.image_path)
# initialize the models and pipeline
controlnet_conditioning_scale = args.controlnet_conditioning_scale
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# get canny image
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
num_inference_steps = args.num_inference_steps
# generate image
image = pipe(
prompt,
controlnet_conditioning_scale=controlnet_conditioning_scale,
image=canny_image,
num_inference_steps=num_inference_steps,
).images[0]
image.save("cnxs_sdxl.canny.png")
......@@ -80,6 +80,7 @@ else:
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
......@@ -94,6 +95,7 @@ else:
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
"UNetControlNetXSModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
......@@ -270,6 +272,7 @@ else:
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetXSPipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
......@@ -293,6 +296,7 @@ else:
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
"StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
......@@ -474,6 +478,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
......@@ -487,6 +492,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetControlNetXSModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
......@@ -642,6 +648,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionControlNetXSPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
......@@ -665,6 +672,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
......
......@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
......@@ -68,6 +69,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
......
This diff is collapsed.
......@@ -746,6 +746,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self,
in_channels: int,
temb_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
......@@ -753,6 +754,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
......@@ -764,6 +766,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
......@@ -772,14 +778,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
resnet_groups_out = resnet_groups_out or resnet_groups
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
......@@ -794,11 +803,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
......@@ -808,8 +817,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
......@@ -817,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
......
......@@ -134,6 +134,12 @@ else:
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["controlnet_xs"].extend(
[
"StableDiffusionControlNetXSPipeline",
"StableDiffusionXLControlNetXSPipeline",
]
)
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
......@@ -378,6 +384,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .controlnet_xs import (
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -2238,6 +2238,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self,
in_channels: int,
temb_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
......@@ -2245,6 +2246,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
......@@ -2256,6 +2258,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
......@@ -2264,14 +2270,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
resnet_groups_out = resnet_groups_out or resnet_groups
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
......@@ -2286,11 +2295,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
......@@ -2300,8 +2309,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
......@@ -2309,11 +2318,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
)
resnets.append(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
......
......@@ -92,6 +92,21 @@ class ControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ControlNetXSAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
......@@ -287,6 +302,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UNetControlNetXSModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNetMotionModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -902,6 +902,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......@@ -1247,6 +1262,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
import torch
from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetControlNetXSModel
main_input_name = "sample"
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
conditioning_scale = 1
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"controlnet_cond": controlnet_cond,
"conditioning_scale": conditioning_scale,
}
@property
def input_shape(self):
return (4, 16, 16)
@property
def output_shape(self):
return (4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 16,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
"block_out_channels": (4, 8),
"cross_attention_dim": 8,
"transformer_layers_per_block": 1,
"num_attention_heads": 2,
"norm_num_groups": 4,
"upcast_attention": False,
"ctrl_block_out_channels": [2, 4],
"ctrl_num_attention_heads": 4,
"ctrl_max_norm_num_groups": 2,
"ctrl_conditioning_embedding_out_channels": (2, 2),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_unet(self):
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
return UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=8,
norm_num_groups=4,
use_linear_projection=True,
)
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
def test_from_unet(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
model = UNetControlNetXSModel.from_unet(unet, controlnet)
model_state_dict = model.state_dict()
def assert_equal_weights(module, weight_dict_prefix):
for param_name, param_value in module.named_parameters():
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
# # check unet
# everything expect down,mid,up blocks
modules_from_unet = [
"time_embedding",
"conv_in",
"conv_norm_out",
"conv_out",
]
for p in modules_from_unet:
assert_equal_weights(getattr(unet, p), "base_" + p)
optional_modules_from_unet = [
"class_embedding",
"add_time_proj",
"add_embedding",
]
for p in optional_modules_from_unet:
if hasattr(unet, p) and getattr(unet, p) is not None:
assert_equal_weights(getattr(unet, p), "base_" + p)
# down blocks
assert len(unet.down_blocks) == len(model.down_blocks)
for i, d in enumerate(unet.down_blocks):
assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
if hasattr(d, "attentions"):
assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None:
assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")
# mid block
assert_equal_weights(unet.mid_block, "mid_block.base_midblock")
# up blocks
assert len(unet.up_blocks) == len(model.up_blocks)
for i, u in enumerate(unet.up_blocks):
assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
if hasattr(u, "attentions"):
assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None:
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
# # check controlnet
# everything expect down,mid,up blocks
modules_from_controlnet = {
"controlnet_cond_embedding": "controlnet_cond_embedding",
"conv_in": "ctrl_conv_in",
"control_to_base_for_conv_in": "control_to_base_for_conv_in",
}
optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None:
assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
# down blocks
assert len(controlnet.down_blocks) == len(model.down_blocks)
for i, d in enumerate(controlnet.down_blocks):
assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
if d.attentions is not None:
assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
if d.downsamplers is not None:
assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")
# mid block
assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")
# up blocks
assert len(controlnet.up_connections) == len(model.up_blocks)
for i, u in enumerate(controlnet.up_connections):
assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")
def test_freeze_unet(self):
def assert_frozen(module):
for p in module.parameters():
assert not p.requires_grad
def assert_unfrozen(module):
for p in module.parameters():
assert p.requires_grad
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = UNetControlNetXSModel(**init_dict)
model.freeze_unet_params()
# # check unet
# everything expect down,mid,up blocks
modules_from_unet = [
model.base_time_embedding,
model.base_conv_in,
model.base_conv_norm_out,
model.base_conv_out,
]
for m in modules_from_unet:
assert_frozen(m)
optional_modules_from_unet = [
model.base_add_time_proj,
model.base_add_embedding,
]
for m in optional_modules_from_unet:
if m is not None:
assert_frozen(m)
# down blocks
for i, d in enumerate(model.down_blocks):
assert_frozen(d.base_resnets)
if isinstance(d.base_attentions, nn.ModuleList): # attentions can be list of Nones
assert_frozen(d.base_attentions)
if d.base_downsamplers is not None:
assert_frozen(d.base_downsamplers)
# mid block
assert_frozen(model.mid_block.base_midblock)
# up blocks
for i, u in enumerate(model.up_blocks):
assert_frozen(u.resnets)
if isinstance(u.attentions, nn.ModuleList): # attentions can be list of Nones
assert_frozen(u.attentions)
if u.upsamplers is not None:
assert_frozen(u.upsamplers)
# # check controlnet
# everything expect down,mid,up blocks
modules_from_controlnet = [
model.controlnet_cond_embedding,
model.ctrl_conv_in,
model.control_to_base_for_conv_in,
]
optional_modules_from_controlnet = [model.ctrl_time_embedding]
for m in modules_from_controlnet:
assert_unfrozen(m)
for m in optional_modules_from_controlnet:
if m is not None:
assert_unfrozen(m)
# down blocks
for d in model.down_blocks:
assert_unfrozen(d.ctrl_resnets)
assert_unfrozen(d.base_to_ctrl)
assert_unfrozen(d.ctrl_to_base)
if isinstance(d.ctrl_attentions, nn.ModuleList): # attentions can be list of Nones
assert_unfrozen(d.ctrl_attentions)
if d.ctrl_downsamplers is not None:
assert_unfrozen(d.ctrl_downsamplers)
# mid block
assert_unfrozen(model.mid_block.base_to_ctrl)
assert_unfrozen(model.mid_block.ctrl_midblock)
assert_unfrozen(model.mid_block.ctrl_to_base)
# up blocks
for u in model.up_blocks:
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
model_class_copy = copy.copy(UNetControlNetXSModel)
modules_with_gc_enabled = {}
# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value
def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
EXPECTED_SET = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
def test_forward_no_control(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
model = UNetControlNetXSModel.from_unet(unet, controlnet)
unet = unet.to(torch_device)
model = model.to(torch_device)
input_ = self.dummy_input
control_specific_input = ["controlnet_cond", "conditioning_scale"]
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
with torch.no_grad():
unet_output = unet(**input_for_unet).sample.cpu()
unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()
assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4
def test_time_embedding_mixing(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
controlnet_mix_time = self.get_dummy_controlnet_from_unet(
unet, time_embedding_mix=0.5, learn_time_embedding=True
)
model = UNetControlNetXSModel.from_unet(unet, controlnet)
model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)
unet = unet.to(torch_device)
model = model.to(torch_device)
model_mix_time = model_mix_time.to(torch_device)
input_ = self.dummy_input
with torch.no_grad():
output = model(**input_).sample
output_mix_time = model_mix_time(**input_).sample
assert output.shape == output_mix_time.shape
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment