Unverified Commit a0597f33 authored by Will Berman, committed by GitHub

t2i pipeline (#3932)



* Quick implementation of t2i-adapter

Load adapter module with from_pretrained

Prototyping generalized adapter framework

Write up doc string for sideload framework (WIP) + some minor updates to implementation

Update adapter models

Remove old adapter optional args in UNet

Add StableDiffusionAdapterPipeline unit test

Handle cpu offload in StableDiffusionAdapterPipeline

Auto correct coding style

Update model repo name to "RzZ/sd-v1-4-adapter-pipeline"

Refactor MultiAdapter for better compatibility with the config system

Export MultiAdapter

Create pipeline document template from controlnet

Create dummy objects

Supporting new AdapterLight model

Fix StableDiffusionAdapterPipeline common pipeline test

[WIP] Update adapter pipeline document

Handle num_inference_steps in StableDiffusionAdapterPipeline

Update definition of Adapter "channels_in"

Update documents

Apply code style

Fix doc typo and merge error

Update doc string and example

Quality of life improvement

Remove redundant code and file from prototyping

Remove unused package

Remove comments

Fix title

Fix typo

Add conditioning scale arg

Bring back old implementation

Offload sideload

Add supplementary info to document

Update src/diffusers/models/adapter.py
Co-authored-by: Will Berman <wlbberman@gmail.com>

Update MultiAdapter constructor

Swap out custom checkpoint and update pipeline constructor

Update document

Apply suggestions from code review
Co-authored-by: Will Berman <wlbberman@gmail.com>

Correcting style

Following single-file policy

Update auto size in image preprocess func

Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py
Co-authored-by: Will Berman <wlbberman@gmail.com>

fix copies

Update adapter pipeline behavior

Add adapter_conditioning_scale doc string

Add the missing doc string

Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Fix a few bugs from suggestions

Handle L-mode PIL image as control image

Rename to differentiate adapter resblock

Update src/diffusers/models/adapter.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Fix typo

Update adapter parameter name

Update test case and code style

Fix copies

Fix typo

Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py
Co-authored-by: Will Berman <wlbberman@gmail.com>

Update Adapter class name

Add checkpoint converting script

Fix style

Fix-copies

Remove dev script

Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Updates for parameter rename

Fix convert_adapter

remove main

fix diff

more

refactoring

more

more

small fixes

refactor

tests

more slow tests

more tests

Update docs/source/en/api/pipelines/overview.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

add community contributor to docs

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Update docs/source/en/api/pipelines/stable_diffusion/adapter.mdx
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

fix

remove from_adapters

license

paper link

docs

more url fixes

more docs

fix

fixes

fix

fix

* fix sample inplace add

* additional_kwargs -> additional_residuals

* move t2i adapter pipeline to own module

* preprocess -> _preprocess_adapter_image

* add TencentArc to license

* fix example code links

* add image converter and fix example doc string

* fix links

* clearer additional residual application

---------
Co-authored-by: HimariO <dsfhe49854@gmail.com>
parent 39299546
@@ -196,12 +196,12 @@
title: DDIM
- local: api/pipelines/ddpm
title: DDPM
- local: api/pipelines/deepfloyd_if
title: DeepFloyd IF
- local: api/pipelines/diffedit
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/deepfloyd_if
title: DeepFloyd IF
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
@@ -255,6 +255,8 @@
title: Super-Resolution
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth)
- local: api/pipelines/stable_diffusion/adapter
title: Stable Diffusion T2I-adapter
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
......
@@ -66,6 +66,7 @@ available a colab notebook to directly try them out.
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [semantic_stable_diffusion](./semantic_stable_diffusion) | [**SEGA: Instructing Diffusion using Semantic Dimensions**](https://arxiv.org/abs/2301.12247) | Text-to-Image Generation |
| [stable_diffusion_adapter](./stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation with Adapters | -
| [stable_diffusion_text2img](./stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion_img2img](./stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion_inpaint](./stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Text-to-Image Generation with Adapter Conditioning
## Overview
[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.
Using the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details.
The abstract of the paper is the following:
*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate structure control is needed. In this paper, we aim to "dig out" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and small T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. In this way, we can train various adapters according to different conditions, and achieve rich control and editing effects. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.*
This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .
## Available Pipelines:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
## Usage example
In the following we give a simple example of how to use a *T2IAdapter* checkpoint with Diffusers for inference.
All adapters use the same pipeline.
1. Images are first converted into the appropriate *control image* format.
2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`].
Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1).
```python
from diffusers.utils import load_image
image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")
```
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png)
Then we can create our color palette by simply resizing it to 8 by 8 pixels and then scaling it back to the original size.
```python
from PIL import Image
color_palette = image.resize((8, 8))
color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
```
Let's take a look at the processed image.
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_palette.png)
Next, create the adapter pipeline:
```py
import torch
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

# load the adapter in the same dtype as the rest of the pipeline to avoid a dtype mismatch
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    adapter=adapter,
    torch_dtype=torch.float16,
)
pipe.to("cuda")
```
Finally, pass the prompt and control image to the pipeline:
```py
# fix the random seed, so you will get the same result as the example
generator = torch.manual_seed(7)
out_image = pipe(
"At night, glowing cubes in front of the beach",
image=color_palette,
generator=generator,
).images[0]
```
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_output.png)
## Available checkpoints
Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).
### T2I-Adapter with Stable Diffusion 1.4
| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|---|---|---|---|
|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1)<br/> *Trained with spatial color palette* | An image with an 8x8 color palette.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_output.png"/></a>|
|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_output.png"/></a>|
|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1)<br/> *Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_output.png"/></a>|
|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1)<br/> *Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_output.png"/></a>|
|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1)<br/> *Trained with OpenPose bone image* | An [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_output.png"/></a>|
|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1)<br/> *Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_output.png"/></a>|
|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1)<br/>*Trained with semantic segmentation* | A [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_output.png"/></a> |
|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2)||
|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)||
|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)||
|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)||
## Combining multiple adapters
[`MultiAdapter`] can be used for applying multiple conditionings at once.
Here we use the keypose adapter for the character posture and the depth adapter for creating the scene.
```py
import torch
from PIL import Image
from diffusers.utils import load_image
cond_keypose = load_image(
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
)
cond_depth = load_image(
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
)
cond = [[cond_keypose, cond_depth]]
prompt = ["A man walking in an office room with a nice view"]
```
The two control images look as follows:
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png)
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png)
`MultiAdapter` combines keypose and depth adapters.
`adapter_conditioning_scale` balances the relative influence of the different adapters.
```py
from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter

adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
    ]
)
adapters = adapters.to(torch.float16)

pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    adapter=adapters,
)
pipe.to("cuda")

images = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]).images
```
![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_depth_sample_output.png)
## T2I Adapter vs ControlNet
T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet).
T2I-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process, since its output depends on the control image alone rather than on the current denoising step.
However, T2I-Adapter performs slightly worse than ControlNet.
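Because the adapter's features are independent of the timestep, a pipeline can compute them once before the denoising loop and reuse them at every step. Below is a minimal sketch of this pattern; the variable names (`adapter`, `control_image_tensor`, `prompt_embeds`, `latents`, `scheduler`) are illustrative, not the pipeline's actual internals.

```py
# Compute the adapter features once, outside the denoising loop.
adapter_state = adapter(control_image_tensor)  # list of multi-scale feature maps

for t in scheduler.timesteps:
    # Reuse the same adapter features at every step. A fresh list is passed
    # each iteration because the UNet consumes (pops) the residuals.
    noise_pred = unet(
        latents,
        t,
        encoder_hidden_states=prompt_embeds,
        down_block_additional_residuals=[state.clone() for state in adapter_state],
    ).sample
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```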
## StableDiffusionAdapterPipeline
[[autodoc]] StableDiffusionAdapterPipeline
- all
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
@@ -69,6 +69,7 @@ The library has three main components:
| [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation |
| [stable_diffusion_adapter](./api/pipelines/stable_diffusion/adapter) | [**T2I-Adapter**](https://arxiv.org/abs/2302.08453) | Image-to-Image Text-Guided Generation | -
| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |
| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |
| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |
......
@@ -59,6 +59,7 @@ For convenience, we provide a table to denote which methods are inference-only a
| [Custom Diffusion](#custom-diffusion) | ❌ | ✅ | |
| [Model Editing](#model-editing) | ✅ | ❌ | |
| [DiffEdit](#diffedit) | ✅ | ❌ | |
| [T2I-Adapter](#t2i-adapter) | ✅ | ❌ | |
## Instruct Pix2Pix
@@ -215,4 +216,13 @@ To know more details, check out the [official doc](../api/pipelines/stable_diffu
[DiffEdit](../api/pipelines/stable_diffusion/diffedit) allows for semantic editing of input images along with
input prompts while preserving the original input images as much as possible.
To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing).
## T2I-Adapter
[Paper](https://arxiv.org/abs/2302.08453)
[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition.
There are 8 canonical pre-trained adapters, trained on different conditionings such as edge detection, sketch,
depth maps, and semantic segmentation.
See [here](../api/pipelines/stable_diffusion/adapter) for more information on how to use it.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Conversion script for the T2I-Adapter checkpoints.
"""
import argparse
import torch
from diffusers import T2IAdapter
def convert_adapter(src_state, in_channels):
original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1
assert original_body_length == 8
# (0, 1) -> channels 1
assert src_state["body.0.block1.weight"].shape == (320, 320, 3, 3)
# (2, 3) -> channels 2
assert src_state["body.2.in_conv.weight"].shape == (640, 320, 1, 1)
# (4, 5) -> channels 3
assert src_state["body.4.in_conv.weight"].shape == (1280, 640, 1, 1)
# (6, 7) -> channels 4
assert src_state["body.6.block1.weight"].shape == (1280, 1280, 3, 3)
res_state = {
"adapter.conv_in.weight": src_state.pop("conv_in.weight"),
"adapter.conv_in.bias": src_state.pop("conv_in.bias"),
# 0.resnets.0
"adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.block1.weight"),
"adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.block1.bias"),
"adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.block2.weight"),
"adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.block2.bias"),
# 0.resnets.1
"adapter.body.0.resnets.1.block1.weight": src_state.pop("body.1.block1.weight"),
"adapter.body.0.resnets.1.block1.bias": src_state.pop("body.1.block1.bias"),
"adapter.body.0.resnets.1.block2.weight": src_state.pop("body.1.block2.weight"),
"adapter.body.0.resnets.1.block2.bias": src_state.pop("body.1.block2.bias"),
# 1
"adapter.body.1.in_conv.weight": src_state.pop("body.2.in_conv.weight"),
"adapter.body.1.in_conv.bias": src_state.pop("body.2.in_conv.bias"),
# 1.resnets.0
"adapter.body.1.resnets.0.block1.weight": src_state.pop("body.2.block1.weight"),
"adapter.body.1.resnets.0.block1.bias": src_state.pop("body.2.block1.bias"),
"adapter.body.1.resnets.0.block2.weight": src_state.pop("body.2.block2.weight"),
"adapter.body.1.resnets.0.block2.bias": src_state.pop("body.2.block2.bias"),
# 1.resnets.1
"adapter.body.1.resnets.1.block1.weight": src_state.pop("body.3.block1.weight"),
"adapter.body.1.resnets.1.block1.bias": src_state.pop("body.3.block1.bias"),
"adapter.body.1.resnets.1.block2.weight": src_state.pop("body.3.block2.weight"),
"adapter.body.1.resnets.1.block2.bias": src_state.pop("body.3.block2.bias"),
# 2
"adapter.body.2.in_conv.weight": src_state.pop("body.4.in_conv.weight"),
"adapter.body.2.in_conv.bias": src_state.pop("body.4.in_conv.bias"),
# 2.resnets.0
"adapter.body.2.resnets.0.block1.weight": src_state.pop("body.4.block1.weight"),
"adapter.body.2.resnets.0.block1.bias": src_state.pop("body.4.block1.bias"),
"adapter.body.2.resnets.0.block2.weight": src_state.pop("body.4.block2.weight"),
"adapter.body.2.resnets.0.block2.bias": src_state.pop("body.4.block2.bias"),
# 2.resnets.1
"adapter.body.2.resnets.1.block1.weight": src_state.pop("body.5.block1.weight"),
"adapter.body.2.resnets.1.block1.bias": src_state.pop("body.5.block1.bias"),
"adapter.body.2.resnets.1.block2.weight": src_state.pop("body.5.block2.weight"),
"adapter.body.2.resnets.1.block2.bias": src_state.pop("body.5.block2.bias"),
# 3.resnets.0
"adapter.body.3.resnets.0.block1.weight": src_state.pop("body.6.block1.weight"),
"adapter.body.3.resnets.0.block1.bias": src_state.pop("body.6.block1.bias"),
"adapter.body.3.resnets.0.block2.weight": src_state.pop("body.6.block2.weight"),
"adapter.body.3.resnets.0.block2.bias": src_state.pop("body.6.block2.bias"),
# 3.resnets.1
"adapter.body.3.resnets.1.block1.weight": src_state.pop("body.7.block1.weight"),
"adapter.body.3.resnets.1.block1.bias": src_state.pop("body.7.block1.bias"),
"adapter.body.3.resnets.1.block2.weight": src_state.pop("body.7.block2.weight"),
"adapter.body.3.resnets.1.block2.bias": src_state.pop("body.7.block2.bias"),
}
assert len(src_state) == 0
adapter = T2IAdapter(in_channels=in_channels, adapter_type="full_adapter")
adapter.load_state_dict(res_state)
return adapter
def convert_light_adapter(src_state):
original_body_length = max([int(x.split(".")[1]) for x in src_state.keys() if "body." in x]) + 1
assert original_body_length == 4
res_state = {
# body.0.in_conv
"adapter.body.0.in_conv.weight": src_state.pop("body.0.in_conv.weight"),
"adapter.body.0.in_conv.bias": src_state.pop("body.0.in_conv.bias"),
# body.0.resnets.0
"adapter.body.0.resnets.0.block1.weight": src_state.pop("body.0.body.0.block1.weight"),
"adapter.body.0.resnets.0.block1.bias": src_state.pop("body.0.body.0.block1.bias"),
"adapter.body.0.resnets.0.block2.weight": src_state.pop("body.0.body.0.block2.weight"),
"adapter.body.0.resnets.0.block2.bias": src_state.pop("body.0.body.0.block2.bias"),
# body.0.resnets.1
"adapter.body.0.resnets.1.block1.weight": src_state.pop("body.0.body.1.block1.weight"),
"adapter.body.0.resnets.1.block1.bias": src_state.pop("body.0.body.1.block1.bias"),
"adapter.body.0.resnets.1.block2.weight": src_state.pop("body.0.body.1.block2.weight"),
"adapter.body.0.resnets.1.block2.bias": src_state.pop("body.0.body.1.block2.bias"),
# body.0.resnets.2
"adapter.body.0.resnets.2.block1.weight": src_state.pop("body.0.body.2.block1.weight"),
"adapter.body.0.resnets.2.block1.bias": src_state.pop("body.0.body.2.block1.bias"),
"adapter.body.0.resnets.2.block2.weight": src_state.pop("body.0.body.2.block2.weight"),
"adapter.body.0.resnets.2.block2.bias": src_state.pop("body.0.body.2.block2.bias"),
# body.0.resnets.3
"adapter.body.0.resnets.3.block1.weight": src_state.pop("body.0.body.3.block1.weight"),
"adapter.body.0.resnets.3.block1.bias": src_state.pop("body.0.body.3.block1.bias"),
"adapter.body.0.resnets.3.block2.weight": src_state.pop("body.0.body.3.block2.weight"),
"adapter.body.0.resnets.3.block2.bias": src_state.pop("body.0.body.3.block2.bias"),
# body.0.out_conv
"adapter.body.0.out_conv.weight": src_state.pop("body.0.out_conv.weight"),
"adapter.body.0.out_conv.bias": src_state.pop("body.0.out_conv.bias"),
# body.1.in_conv
"adapter.body.1.in_conv.weight": src_state.pop("body.1.in_conv.weight"),
"adapter.body.1.in_conv.bias": src_state.pop("body.1.in_conv.bias"),
# body.1.resnets.0
"adapter.body.1.resnets.0.block1.weight": src_state.pop("body.1.body.0.block1.weight"),
"adapter.body.1.resnets.0.block1.bias": src_state.pop("body.1.body.0.block1.bias"),
"adapter.body.1.resnets.0.block2.weight": src_state.pop("body.1.body.0.block2.weight"),
"adapter.body.1.resnets.0.block2.bias": src_state.pop("body.1.body.0.block2.bias"),
# body.1.resnets.1
"adapter.body.1.resnets.1.block1.weight": src_state.pop("body.1.body.1.block1.weight"),
"adapter.body.1.resnets.1.block1.bias": src_state.pop("body.1.body.1.block1.bias"),
"adapter.body.1.resnets.1.block2.weight": src_state.pop("body.1.body.1.block2.weight"),
"adapter.body.1.resnets.1.block2.bias": src_state.pop("body.1.body.1.block2.bias"),
# body.1.body.2
"adapter.body.1.resnets.2.block1.weight": src_state.pop("body.1.body.2.block1.weight"),
"adapter.body.1.resnets.2.block1.bias": src_state.pop("body.1.body.2.block1.bias"),
"adapter.body.1.resnets.2.block2.weight": src_state.pop("body.1.body.2.block2.weight"),
"adapter.body.1.resnets.2.block2.bias": src_state.pop("body.1.body.2.block2.bias"),
# body.1.body.3
"adapter.body.1.resnets.3.block1.weight": src_state.pop("body.1.body.3.block1.weight"),
"adapter.body.1.resnets.3.block1.bias": src_state.pop("body.1.body.3.block1.bias"),
"adapter.body.1.resnets.3.block2.weight": src_state.pop("body.1.body.3.block2.weight"),
"adapter.body.1.resnets.3.block2.bias": src_state.pop("body.1.body.3.block2.bias"),
# body.1.out_conv
"adapter.body.1.out_conv.weight": src_state.pop("body.1.out_conv.weight"),
"adapter.body.1.out_conv.bias": src_state.pop("body.1.out_conv.bias"),
# body.2.in_conv
"adapter.body.2.in_conv.weight": src_state.pop("body.2.in_conv.weight"),
"adapter.body.2.in_conv.bias": src_state.pop("body.2.in_conv.bias"),
# body.2.body.0
"adapter.body.2.resnets.0.block1.weight": src_state.pop("body.2.body.0.block1.weight"),
"adapter.body.2.resnets.0.block1.bias": src_state.pop("body.2.body.0.block1.bias"),
"adapter.body.2.resnets.0.block2.weight": src_state.pop("body.2.body.0.block2.weight"),
"adapter.body.2.resnets.0.block2.bias": src_state.pop("body.2.body.0.block2.bias"),
# body.2.body.1
"adapter.body.2.resnets.1.block1.weight": src_state.pop("body.2.body.1.block1.weight"),
"adapter.body.2.resnets.1.block1.bias": src_state.pop("body.2.body.1.block1.bias"),
"adapter.body.2.resnets.1.block2.weight": src_state.pop("body.2.body.1.block2.weight"),
"adapter.body.2.resnets.1.block2.bias": src_state.pop("body.2.body.1.block2.bias"),
# body.2.body.2
"adapter.body.2.resnets.2.block1.weight": src_state.pop("body.2.body.2.block1.weight"),
"adapter.body.2.resnets.2.block1.bias": src_state.pop("body.2.body.2.block1.bias"),
"adapter.body.2.resnets.2.block2.weight": src_state.pop("body.2.body.2.block2.weight"),
"adapter.body.2.resnets.2.block2.bias": src_state.pop("body.2.body.2.block2.bias"),
# body.2.body.3
"adapter.body.2.resnets.3.block1.weight": src_state.pop("body.2.body.3.block1.weight"),
"adapter.body.2.resnets.3.block1.bias": src_state.pop("body.2.body.3.block1.bias"),
"adapter.body.2.resnets.3.block2.weight": src_state.pop("body.2.body.3.block2.weight"),
"adapter.body.2.resnets.3.block2.bias": src_state.pop("body.2.body.3.block2.bias"),
# body.2.out_conv
"adapter.body.2.out_conv.weight": src_state.pop("body.2.out_conv.weight"),
"adapter.body.2.out_conv.bias": src_state.pop("body.2.out_conv.bias"),
# body.3.in_conv
"adapter.body.3.in_conv.weight": src_state.pop("body.3.in_conv.weight"),
"adapter.body.3.in_conv.bias": src_state.pop("body.3.in_conv.bias"),
# body.3.body.0
"adapter.body.3.resnets.0.block1.weight": src_state.pop("body.3.body.0.block1.weight"),
"adapter.body.3.resnets.0.block1.bias": src_state.pop("body.3.body.0.block1.bias"),
"adapter.body.3.resnets.0.block2.weight": src_state.pop("body.3.body.0.block2.weight"),
"adapter.body.3.resnets.0.block2.bias": src_state.pop("body.3.body.0.block2.bias"),
# body.3.body.1
"adapter.body.3.resnets.1.block1.weight": src_state.pop("body.3.body.1.block1.weight"),
"adapter.body.3.resnets.1.block1.bias": src_state.pop("body.3.body.1.block1.bias"),
"adapter.body.3.resnets.1.block2.weight": src_state.pop("body.3.body.1.block2.weight"),
"adapter.body.3.resnets.1.block2.bias": src_state.pop("body.3.body.1.block2.bias"),
# body.3.body.2
"adapter.body.3.resnets.2.block1.weight": src_state.pop("body.3.body.2.block1.weight"),
"adapter.body.3.resnets.2.block1.bias": src_state.pop("body.3.body.2.block1.bias"),
"adapter.body.3.resnets.2.block2.weight": src_state.pop("body.3.body.2.block2.weight"),
"adapter.body.3.resnets.2.block2.bias": src_state.pop("body.3.body.2.block2.bias"),
# body.3.body.3
"adapter.body.3.resnets.3.block1.weight": src_state.pop("body.3.body.3.block1.weight"),
"adapter.body.3.resnets.3.block1.bias": src_state.pop("body.3.body.3.block1.bias"),
"adapter.body.3.resnets.3.block2.weight": src_state.pop("body.3.body.3.block2.weight"),
"adapter.body.3.resnets.3.block2.bias": src_state.pop("body.3.body.3.block2.bias"),
# body.3.out_conv
"adapter.body.3.out_conv.weight": src_state.pop("body.3.out_conv.weight"),
"adapter.body.3.out_conv.bias": src_state.pop("body.3.out_conv.bias"),
}
assert len(src_state) == 0
adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280], num_res_blocks=4, adapter_type="light_adapter")
adapter.load_state_dict(res_state)
return adapter
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--output_path", default=None, type=str, required=True, help="Path to the store the result checkpoint."
)
parser.add_argument(
"--is_adapter_light",
action="store_true",
help="Is checkpoint come from Adapter-Light architecture. ex: color-adapter",
)
parser.add_argument("--in_channels", required=False, type=int, help="Input channels for non-light adapter")
args = parser.parse_args()
src_state = torch.load(args.checkpoint_path)
if args.is_adapter_light:
adapter = convert_light_adapter(src_state)
else:
if args.in_channels is None:
raise ValueError("`--in_channels` is required for non-light adapters, e.g. `--in_channels=3`")
adapter = convert_adapter(src_state, args.in_channels)
adapter.save_pretrained(args.output_path)
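# Example invocation (a sketch; the script filename and checkpoint paths are illustrative):
#
#   python convert_t2i_adapter.py \
#       --checkpoint_path t2iadapter_canny_sd14v1.pth \
#       --in_channels 1 \
#       --output_path ./t2iadapter_canny_sd14v1
#
# For checkpoints using the Adapter-Light architecture (e.g. the color adapter):
#
#   python convert_t2i_adapter.py \
#       --checkpoint_path t2iadapter_color_sd14v1.pth \
#       --is_adapter_light \
#       --output_path ./t2iadapter_color_sd14v1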
@@ -39,7 +39,9 @@ else:
AutoencoderKL,
ControlNetModel,
ModelMixin,
MultiAdapter,
PriorTransformer,
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
UNet1DModel,
@@ -151,6 +153,7 @@ else:
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
......
@@ -16,6 +16,7 @@ from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_kl import AutoencoderKL
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from .modeling_utils import ModelMixin
from .resnet import Downsample2D
class MultiAdapter(ModelMixin):
r"""
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
user-assigned weighting.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all models (such as downloading or saving).
Parameters:
adapters (`List[T2IAdapter]`):
A list of `T2IAdapter` model instances.
"""
def __init__(self, adapters: List["T2IAdapter"]):
super().__init__()
self.num_adapter = len(adapters)
self.adapters = nn.ModuleList(adapters)
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
Args:
xs (`torch.Tensor`):
(batch, channel, height, width) input images for the multiple adapter models, concatenated along dimension 1;
`channel` should equal `num_adapter` times the number of channels per image.
adapter_weights (`List[float]`, *optional*, defaults to None):
List of floats representing the weight by which each adapter's output is multiplied before the outputs are
summed together.
"""
if adapter_weights is None:
adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
else:
adapter_weights = torch.tensor(adapter_weights)
if xs.shape[1] % self.num_adapter != 0:
raise ValueError(
f"Expecting multi-adapter's input have number of channel that cab be evenly divisible "
f"by num_adapter: {xs.shape[1]} % {self.num_adapter} != 0"
)
x_list = torch.chunk(xs, self.num_adapter, dim=1)
accume_state = None
for x, w, adapter in zip(x_list, adapter_weights, self.adapters):
features = adapter(x)
if accume_state is None:
# weight the first adapter's features as well, so every adapter is scaled by its weight
accume_state = [w * feature for feature in features]
else:
for i in range(len(features)):
accume_state[i] += w * features[i]
return accume_state
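# Example (sketch): two adapters fed two 3-channel control images concatenated
# into `xs` of shape (batch, 6, H, W); `torch.chunk(xs, 2, dim=1)` recovers the
# two (batch, 3, H, W) control images, and with the default weights each
# adapter's feature maps are scaled by 1/2 before being summed.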
class T2IAdapter(ModelMixin, ConfigMixin):
r"""
A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
architecture follows the original implementation of
[Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
and
[AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all models (such as downloading or saving).
Parameters:
in_channels (`int`, *optional*, defaults to 3):
Number of channels in the Adapter's input (the *control image*). Set this parameter to 1 if you're using a
grayscale image as the *control image*.
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The number of channels in each downsample block's output hidden state. `len(channels)` also determines the
number of downsample blocks in the Adapter.
num_res_blocks (`int`, *optional*, defaults to 2):
Number of ResNet blocks in each downsample block.
downscale_factor (`int`, *optional*, defaults to 8):
The pixel-unshuffle factor applied to the input before the first convolution; it sets the base of the
Adapter's total downscale factor.
adapter_type (`str`, *optional*, defaults to `"full_adapter"`):
The Adapter architecture to use: either `"full_adapter"` or `"light_adapter"`.
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 8,
adapter_type: str = "full_adapter",
):
super().__init__()
if adapter_type == "full_adapter":
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
elif adapter_type == "light_adapter":
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
else:
raise ValueError(f"unknown adapter_type: {type}. Choose either 'full_adapter' or 'simple_adapter'")
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
return self.adapter(x)
@property
def total_downscale_factor(self):
return self.adapter.total_downscale_factor
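# Example (sketch): with the defaults above, a (1, 3, 512, 512) control image is
# pixel-unshuffled to 64x64 resolution and produces four feature maps of shapes
# (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16) and (1, 1280, 8, 8),
# matching the down-block resolutions of the Stable Diffusion UNet
# (total_downscale_factor = 8 * 2**3 = 64).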
# full adapter
class FullAdapter(nn.Module):
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 8,
):
super().__init__()
in_channels = in_channels * downscale_factor**2
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
self.body = nn.ModuleList(
[
AdapterBlock(channels[0], channels[0], num_res_blocks),
*[
AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
for i in range(1, len(channels))
],
]
)
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.unshuffle(x)
x = self.conv_in(x)
features = []
for block in self.body:
x = block(x)
features.append(x)
return features
class AdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
super().__init__()
self.downsample = None
if down:
self.downsample = Downsample2D(in_channels)
self.in_conv = None
if in_channels != out_channels:
self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.resnets = nn.Sequential(
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
)
def forward(self, x):
if self.downsample is not None:
x = self.downsample(x)
if self.in_conv is not None:
x = self.in_conv(x)
x = self.resnets(x)
return x
class AdapterResnetBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
def forward(self, x):
h = x
h = self.block1(h)
h = self.act(h)
h = self.block2(h)
return h + x
# light adapter
class LightAdapter(nn.Module):
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280],
num_res_blocks: int = 4,
downscale_factor: int = 8,
):
super().__init__()
in_channels = in_channels * downscale_factor**2
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
self.body = nn.ModuleList(
[
LightAdapterBlock(in_channels, channels[0], num_res_blocks),
*[
LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
for i in range(len(channels) - 1)
],
LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
]
)
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
def forward(self, x):
x = self.unshuffle(x)
features = []
for block in self.body:
x = block(x)
features.append(x)
return features
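# Example (sketch): with the defaults above, a (1, 3, 512, 512) control image
# yields feature maps of shapes (1, 320, 64, 64), (1, 640, 32, 32),
# (1, 1280, 16, 16) and (1, 1280, 8, 8); unlike FullAdapter, the final block
# downsamples once more while keeping channels[-1] channels.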
class LightAdapterBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
super().__init__()
mid_channels = out_channels // 4
self.downsample = None
if down:
self.downsample = Downsample2D(in_channels)
self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
def forward(self, x):
if self.downsample is not None:
x = self.downsample(x)
x = self.in_conv(x)
x = self.resnets(x)
x = self.out_conv(x)
return x
class LightAdapterResnetBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x):
h = x
h = self.block1(h)
h = self.act(h)
h = self.block2(h)
return h + x
@@ -955,10 +955,13 @@ class CrossAttnDownBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals=None,
):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -999,6 +1002,10 @@
return_dict=False,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
......
@@ -899,9 +899,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample = self.conv_in(sample)
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_block_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@@ -909,13 +918,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
down_block_res_samples += res_samples
if down_block_additional_residuals is not None:
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
@@ -937,7 +950,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
encoder_attention_mask=encoder_attention_mask,
)
if mid_block_additional_residual is not None:
if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up
......
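The hunks above infer the conditioning mode purely from which residual arguments are present. A minimal sketch of the two calling conventions (variable names are illustrative, not the actual pipeline internals):

```py
# ControlNet: residuals for the down blocks *and* the mid block, recomputed every step
noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=controlnet_down_res,
    mid_block_additional_residual=controlnet_mid_res,
).sample

# T2I-Adapter: only down-block residuals, precomputed once from the control image;
# a fresh list is passed because the UNet pops entries as it consumes them
noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=[s.clone() for s in adapter_state],
).sample
```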
@@ -101,6 +101,7 @@ else:
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .t2i_adapter import StableDiffusionAdapterPipeline
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
......
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
)
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline
@@ -1012,9 +1012,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
sample = self.conv_in(sample)
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlockFlat
additional_residuals = {}
if is_adapter and len(down_block_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@@ -1022,13 +1031,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
down_block_res_samples += res_samples
if down_block_additional_residuals is not None:
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
@@ -1050,7 +1063,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_attention_mask=encoder_attention_mask,
)
if mid_block_additional_residual is not None:
if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up
@@ -1390,10 +1403,13 @@ class CrossAttnDownBlockFlat(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals=None,
):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -1434,6 +1450,10 @@
return_dict=False,
)[0]
# apply additional residuals to the output of the last pair of resnet and attention blocks
if i == len(blocks) - 1 and additional_residuals is not None:
hidden_states = hidden_states + additional_residuals
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
......
@@ -47,6 +47,21 @@ class ModelMixin(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class MultiAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PriorTransformer(metaclass=DummyObject):
_backends = ["torch"]
@@ -62,6 +77,21 @@ class PriorTransformer(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class T2IAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class T5FilmDecoder(metaclass=DummyObject):
_backends = ["torch"]
......
@@ -407,6 +407,21 @@ class ShapEPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionAdapterPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
PNDMScheduler,
StableDiffusionAdapterPipeline,
T2IAdapter,
UNet2DConditionModel,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class AdapterTests:
pipeline_class = StableDiffusionAdapterPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self, adapter_type):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
scheduler = PNDMScheduler(skip_prk_steps=True)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
adapter = T2IAdapter(
in_channels=3,
channels=[32, 64],
num_res_blocks=2,
downscale_factor=2,
adapter_type=adapter_type,
)
components = {
"adapter": adapter,
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
return super().get_dummy_components("full_adapter")
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionAdapterPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4858, 0.5500, 0.4278, 0.4669, 0.6184, 0.4322, 0.5010, 0.5033, 0.4746])
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
return super().get_dummy_components("light_adapter")
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionAdapterPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4965, 0.5548, 0.4330, 0.4771, 0.6226, 0.4382, 0.5037, 0.5071, 0.4782])
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
@slow
@require_torch_gpu
class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion_adapter(self):
test_cases = [
(
"TencentARC/t2iadapter_color_sd14v1",
"CompVis/stable-diffusion-v1-4",
"snail",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy",
),
(
"TencentARC/t2iadapter_depth_sd14v1",
"CompVis/stable-diffusion-v1-4",
"desk",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd14v1.npy",
),
(
"TencentARC/t2iadapter_depth_sd15v2",
"runwayml/stable-diffusion-v1-5",
"desk",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd15v2.npy",
),
(
"TencentARC/t2iadapter_keypose_sd14v1",
"CompVis/stable-diffusion-v1-4",
"person",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/person_keypose.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_keypose_sd14v1.npy",
),
(
"TencentARC/t2iadapter_openpose_sd14v1",
"CompVis/stable-diffusion-v1-4",
"person",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/iron_man_pose.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_openpose_sd14v1.npy",
),
(
"TencentARC/t2iadapter_seg_sd14v1",
"CompVis/stable-diffusion-v1-4",
"motorcycle",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_seg_sd14v1.npy",
),
(
"TencentARC/t2iadapter_zoedepth_sd15v1",
"runwayml/stable-diffusion-v1-5",
"motorcycle",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png",
3,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_zoedepth_sd15v1.npy",
),
(
"TencentARC/t2iadapter_canny_sd14v1",
"CompVis/stable-diffusion-v1-4",
"toy",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
1,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy",
),
(
"TencentARC/t2iadapter_canny_sd15v2",
"runwayml/stable-diffusion-v1-5",
"toy",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
1,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd15v2.npy",
),
(
"TencentARC/t2iadapter_sketch_sd14v1",
"CompVis/stable-diffusion-v1-4",
"cat",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png",
1,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd14v1.npy",
),
(
"TencentARC/t2iadapter_sketch_sd15v2",
"runwayml/stable-diffusion-v1-5",
"cat",
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png",
1,
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd15v2.npy",
),
]
for adapter_model, sd_model, prompt, image_url, input_channels, out_url in test_cases:
image = load_image(image_url)
expected_out = load_numpy(out_url)
if input_channels == 1:
image = image.convert("L")
adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
generator = torch.Generator(device="cpu").manual_seed(0)
out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
self.assertTrue(np.allclose(out, expected_out))
def test_stable_diffusion_adapter_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_seg_sd14v1")
pipe = StableDiffusionAdapterPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", adapter=adapter, safety_checker=None
)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png"
)
pipe(prompt="foo", image=image, num_inference_steps=2)
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9