".github/vscode:/vscode.git/clone" did not exist on "677ff400ae581e4058136bdc8fe9fc644c38fc51"
Unverified commit 03b7a84c authored by YiYi Xu, committed by GitHub

Add Kandinsky 2.1 (#3308)



add kandinsky2.1

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com>
Co-authored-by: ayushmangal <ayushmangal@microsoft.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent f19f1287
@@ -166,6 +166,8 @@
       title: DiT
     - local: api/pipelines/if
       title: IF
+    - local: api/pipelines/kandinsky
+      title: Kandinsky
     - local: api/pipelines/latent_diffusion
       title: Latent Diffusion
     - local: api/pipelines/paint_by_example
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky
## Overview
Kandinsky 2.1 inherits best practices from [DALL-E 2](https://arxiv.org/abs/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.
It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.
The Kandinsky model was created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov). The original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2).
## Available Pipelines:
| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* | - |
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Text-Guided Image Inpainting* | - |
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Text-Guided Image-to-Image Generation* | - |
## Usage example
In the following, we walk through a few examples of using the Kandinsky pipelines to create visually striking artwork.
### Text-to-Image Generation
For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings, as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf). Let's throw a fun prompt at Kandinsky to see what it comes up with :)
```python
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
negative_prompt = "low quality, bad quality"
```
We will pass both the `prompt` and `negative_prompt` to our prior diffusion pipeline. In contrast to other diffusion pipelines such as Stable Diffusion, the `prompt` and `negative_prompt` should be passed separately so that we can retrieve a CLIP image embedding for each of them. You can use the `guidance_scale` and `num_inference_steps` arguments to guide this process, just as you would with any other pipeline in Diffusers.
```python
from diffusers import KandinskyPriorPipeline
import torch

# create prior
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(12)
image_emb = pipe_prior(
    prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
).images
zero_image_emb = pipe_prior(
    negative_prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
).images
```
Once we have created the image embeddings, we can use [`KandinskyPipeline`] to generate images.
```python
from PIL import Image
from diffusers import KandinskyPipeline


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


# create diffuser pipeline
pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
pipe.to("cuda")

images = pipe(
    prompt,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    num_images_per_prompt=2,
    height=768,
    width=768,
    num_inference_steps=100,
    guidance_scale=4.0,
    generator=generator,
).images
```
One cheeseburger monster coming up! Enjoy!
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/cheeseburger.png)
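Since we asked for two images per prompt, you can tile the outputs with the `image_grid` helper defined above (the output filename is just an example):

```python
# arrange the two generated samples side by side and save the result
grid = image_grid(images, rows=1, cols=2)
grid.save("cheeseburger_monster_grid.png")
```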
The Kandinsky model works extremely well with creative prompts. Here is some of the amazing art that can be created using the exact same process but with different prompts.
```python
prompt = "bird eye view shot of a full body woman with cyan light orange magenta makeup, digital art, long braided hair her face separated by makeup in the style of yin Yang surrealism, symmetrical face, real image, contrasting tone, pastel gradient background"
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/hair.png)
```python
prompt = "A car exploding into colorful dust"
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/dusts.png)
```python
prompt = "editorial photography of an organic, almost liquid smoke style armchair"
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/smokechair.png)
```python
prompt = "birds eye view of a quilted paper style alien planet landscape, vibrant colours, Cinematic lighting"
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/alienplanet.png)
### Text-Guided Image-to-Image Generation
The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline.
**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
without loading them twice by making use of the [`~DiffusionPipeline.components`] property, as explained [here](#converting-between-different-pipelines) and sketched below.
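For instance, a minimal sketch of that reuse pattern could look like this (assuming `pipe_t2i` is an already-loaded [`KandinskyPipeline`]; both pipelines share the same set of components at the time of writing):

```python
from diffusers import KandinskyImg2ImgPipeline

# reuse the already-loaded text encoder, tokenizer, unet, scheduler and movq
# instead of loading the checkpoint a second time
pipe_img2img = KandinskyImg2ImgPipeline(**pipe_t2i.components)
```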
Let's download an image.
```python
from PIL import Image
import requests
from io import BytesIO
# download image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
original_image = Image.open(BytesIO(response.content)).convert("RGB")
original_image = original_image.resize((768, 512))
```
![img](https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg)
```python
import torch
from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline

# create prior
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

# create img2img pipeline
pipe = KandinskyImg2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A fantasy landscape, Cinematic lighting"
negative_prompt = "low quality, bad quality"

generator = torch.Generator(device="cuda").manual_seed(30)
image_emb = pipe_prior(
    prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
).images
zero_image_emb = pipe_prior(
    negative_prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
).images

out = pipe(
    prompt,
    image=original_image,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=500,
    strength=0.3,
)

out.images[0].save("fantasy_land.png")
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/img2img_fantasyland.png)
### Text-Guided Image Inpainting
You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat.
```python
from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
from diffusers.utils import load_image
import torch
import numpy as np

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

prompt = "a hat"
image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

pipe = KandinskyInpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)

mask = np.ones((768, 768), dtype=np.float32)
# Let's mask out an area above the cat's head
mask[:250, 250:-250] = 0

out = pipe(
    prompt,
    image=init_image,
    mask_image=mask,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    num_inference_steps=150,
)

image = out.images[0]
image.save("cat_with_hat.png")
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/inpaint_cat_hat.png)
### Interpolate
The [`KandinskyPriorPipeline`] also comes with a handy utility function that lets you interpolate in the latent space between different images and text prompts. Here is an example of how you can create an Impressionist-style portrait of your pet based on "The Starry Night".
Note that you can interpolate between texts and images: in the example below, we pass a text prompt "a cat" and two images to the `interpolate` function, along with a `weights` list containing the corresponding weight for each condition being interpolated.
```python
from diffusers import KandinskyPriorPipeline, KandinskyPipeline
from diffusers.utils import load_image
import PIL

import torch
from torchvision import transforms

pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

img1 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
)

img2 = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/starry_night.jpeg"
)

# add all the conditions we want to interpolate, can be either text or image
images_texts = ["a cat", img1, img2]
# specify the weights for each condition in images_texts
weights = [0.3, 0.3, 0.4]
image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)

pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(
    "", image_embeds=image_emb, negative_image_embeds=zero_image_emb, height=768, width=768, num_inference_steps=150
).images[0]

image.save("starry_cat.png")
```
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png)
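Since the conditions can be any mix of text and images, you can, for example, interpolate between the two images alone by dropping the text prompt and re-weighting accordingly; this is a small variation on the snippet above:

```python
# interpolate only between the two images, weighting them equally
image_emb, zero_image_emb = pipe_prior.interpolate([img1, img2], [0.5, 0.5])

image = pipe(
    "", image_embeds=image_emb, negative_image_embeds=zero_image_emb, height=768, width=768, num_inference_steps=150
).images[0]
```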
## KandinskyPriorPipeline
[[autodoc]] KandinskyPriorPipeline
- all
- __call__
- interpolate
## KandinskyPipeline
[[autodoc]] KandinskyPipeline
- all
- __call__
## KandinskyInpaintPipeline
[[autodoc]] KandinskyInpaintPipeline
- all
- __call__
## KandinskyImg2ImgPipeline
[[autodoc]] KandinskyImg2ImgPipeline
- all
- __call__
@@ -129,6 +129,10 @@ else:
         IFInpaintingSuperResolutionPipeline,
         IFPipeline,
         IFSuperResolutionPipeline,
+        KandinskyImg2ImgPipeline,
+        KandinskyInpaintPipeline,
+        KandinskyPipeline,
+        KandinskyPriorPipeline,
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
......
@@ -62,6 +62,7 @@ class Attention(nn.Module):
         cross_attention_norm_num_groups: int = 32,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
         out_bias: bool = True,
         scale_qk: bool = True,
         only_cross_attention: bool = False,
@@ -105,6 +106,11 @@ class Attention(nn.Module):
         else:
             self.group_norm = None
 
+        if spatial_norm_dim is not None:
+            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+        else:
+            self.spatial_norm = None
+
         if cross_attention_norm is None:
             self.norm_cross = None
         elif cross_attention_norm == "layer_norm":
@@ -431,9 +437,13 @@ class AttnProcessor:
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
+        temb=None,
     ):
         residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim
 
         if input_ndim == 4:
@@ -899,9 +909,19 @@ class AttnProcessor2_0:
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        temb=None,
+    ):
         residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
         input_ndim = hidden_states.ndim
 
         if input_ndim == 4:
@@ -1271,3 +1291,26 @@ AttentionProcessor = Union[
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
 ]
+
+
+class SpatialNorm(nn.Module):
+    """
+    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
+    """
+
+    def __init__(
+        self,
+        f_channels,
+        zq_channels,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, f, zq):
+        f_size = f.shape[-2:]
+        zq = F.interpolate(zq, size=f_size, mode="nearest")
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
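`SpatialNorm` group-normalizes a feature map `f` and then modulates it spatially with a conditioning map `zq` that is resized to the feature resolution. A minimal standalone sketch of how it behaves, with illustrative shapes only (not taken from the actual model configs):

```python
import torch

# hypothetical sizes: 512-channel decoder features conditioned on 4-channel latents
norm = SpatialNorm(f_channels=512, zq_channels=4)
f = torch.randn(1, 512, 32, 32)  # feature map to normalize
zq = torch.randn(1, 4, 8, 8)     # conditioning latents, upsampled internally to 32x32
out = norm(f, zq)
print(out.shape)  # torch.Size([1, 512, 32, 32])
```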
@@ -360,6 +360,33 @@ class LabelEmbedding(nn.Module):
         return embeddings
 
 
+class TextImageProjection(nn.Module):
+    def __init__(
+        self,
+        text_embed_dim: int = 1024,
+        image_embed_dim: int = 768,
+        cross_attention_dim: int = 768,
+        num_image_text_embeds: int = 10,
+    ):
+        super().__init__()
+
+        self.num_image_text_embeds = num_image_text_embeds
+        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
+
+    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+        batch_size = text_embeds.shape[0]
+
+        # image
+        image_text_embeds = self.image_embeds(image_embeds)
+        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+
+        # text
+        text_embeds = self.text_proj(text_embeds)
+
+        return torch.cat([image_text_embeds, text_embeds], dim=1)
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
@@ -395,6 +422,24 @@ class TextTimeEmbedding(nn.Module):
         return hidden_states
 
 
+class TextImageTimeEmbedding(nn.Module):
+    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+        super().__init__()
+        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
+        self.text_norm = nn.LayerNorm(time_embed_dim)
+        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+
+    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+        # text
+        time_text_embeds = self.text_proj(text_embeds)
+        time_text_embeds = self.text_norm(time_text_embeds)
+
+        # image
+        time_image_embeds = self.image_proj(image_embeds)
+
+        return time_image_embeds + time_text_embeds
+
+
 class AttentionPooling(nn.Module):
     # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
......
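These two embedding modules are what let the Kandinsky UNet consume CLIP image embeddings next to text: `TextImageProjection` turns a pooled image embedding into a short sequence of extra cross-attention tokens that is concatenated with the projected text tokens, while `TextImageTimeEmbedding` folds text and image embeddings into an additional time embedding. A rough shape check with illustrative sizes (not the actual model configuration):

```python
import torch

proj = TextImageProjection(text_embed_dim=1024, image_embed_dim=768, cross_attention_dim=768)
text_embeds = torch.randn(2, 77, 1024)  # per-token text hidden states
image_embeds = torch.randn(2, 768)      # pooled CLIP image embedding
tokens = proj(text_embeds, image_embeds)
print(tokens.shape)  # torch.Size([2, 87, 768]) -> 10 image tokens + 77 text tokens

add_emb = TextImageTimeEmbedding(text_embed_dim=768, image_embed_dim=768, time_embed_dim=1536)
aug = add_emb(torch.randn(2, 768), torch.randn(2, 768))
print(aug.shape)  # torch.Size([2, 1536])
```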
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .attention import AdaGroupNorm
+from .attention_processor import SpatialNorm
 
 
 class Upsample1D(nn.Module):
@@ -500,7 +501,7 @@ class ResnetBlock2D(nn.Module):
         eps=1e-6,
         non_linearity="swish",
         skip_time_act=False,
-        time_embedding_norm="default",  # default, scale_shift, ada_group
+        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
         kernel=None,
         output_scale_factor=1.0,
         use_in_shortcut=None,
@@ -527,6 +528,8 @@ class ResnetBlock2D(nn.Module):
         if self.time_embedding_norm == "ada_group":
             self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm1 = SpatialNorm(in_channels, temb_channels)
         else:
             self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
@@ -537,7 +540,7 @@ class ResnetBlock2D(nn.Module):
                 self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
                 self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
-            elif self.time_embedding_norm == "ada_group":
+            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                 self.time_emb_proj = None
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
@@ -546,6 +549,8 @@ class ResnetBlock2D(nn.Module):
         if self.time_embedding_norm == "ada_group":
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
+        elif self.time_embedding_norm == "spatial":
+            self.norm2 = SpatialNorm(out_channels, temb_channels)
         else:
             self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
@@ -591,7 +596,7 @@ class ResnetBlock2D(nn.Module):
     def forward(self, input_tensor, temb):
         hidden_states = input_tensor
 
-        if self.time_embedding_norm == "ada_group":
+        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
             hidden_states = self.norm1(hidden_states, temb)
         else:
             hidden_states = self.norm1(hidden_states)
@@ -619,7 +624,7 @@ class ResnetBlock2D(nn.Module):
         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
 
-        if self.time_embedding_norm == "ada_group":
+        if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
             hidden_states = self.norm2(hidden_states, temb)
         else:
             hidden_states = self.norm2(hidden_states)
......
@@ -349,6 +349,7 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            temb_channels=temb_channels,
         )
     elif up_block_type == "AttnUpDecoderBlock2D":
         return AttnUpDecoderBlock2D(
@@ -361,6 +362,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             attn_num_head_channels=attn_num_head_channels,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            temb_channels=temb_channels,
         )
     elif up_block_type == "KUpBlock2D":
         return KUpBlock2D(
@@ -396,7 +398,7 @@ class UNetMidBlock2D(nn.Module):
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
+        resnet_time_scale_shift: str = "default",  # default, spatial
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
@@ -434,7 +436,8 @@ class UNetMidBlock2D(nn.Module):
                     dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
-                    norm_num_groups=resnet_groups,
+                    norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
@@ -466,7 +469,7 @@ class UNetMidBlock2D(nn.Module):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
-                hidden_states = attn(hidden_states)
+                hidden_states = attn(hidden_states, temb=temb)
             hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
@@ -2116,12 +2119,13 @@ class UpDecoderBlock2D(nn.Module):
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",
+        resnet_time_scale_shift: str = "default",  # default, spatial
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
         output_scale_factor=1.0,
         add_upsample=True,
+        temb_channels=None,
     ):
         super().__init__()
         resnets = []
@@ -2133,7 +2137,7 @@ class UpDecoderBlock2D(nn.Module):
                 ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
-                    temb_channels=None,
+                    temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
                     dropout=dropout,
@@ -2151,9 +2155,9 @@ class UpDecoderBlock2D(nn.Module):
         else:
             self.upsamplers = None
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, temb=None):
         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = resnet(hidden_states, temb=temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2177,6 +2181,7 @@ class AttnUpDecoderBlock2D(nn.Module):
         attn_num_head_channels=1,
         output_scale_factor=1.0,
         add_upsample=True,
+        temb_channels=None,
     ):
         super().__init__()
         resnets = []
@@ -2189,7 +2194,7 @@ class AttnUpDecoderBlock2D(nn.Module):
                 ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
-                    temb_channels=None,
+                    temb_channels=temb_channels,
                     eps=resnet_eps,
                     groups=resnet_groups,
                     dropout=dropout,
@@ -2206,7 +2211,8 @@ class AttnUpDecoderBlock2D(nn.Module):
                     dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
                     rescale_output_factor=output_scale_factor,
                     eps=resnet_eps,
-                    norm_num_groups=resnet_groups,
+                    norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
+                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
@@ -2222,10 +2228,10 @@ class AttnUpDecoderBlock2D(nn.Module):
         else:
             self.upsamplers = None
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, temb=None):
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb=None)
-            hidden_states = attn(hidden_states)
+            hidden_states = resnet(hidden_states, temb=temb)
+            hidden_states = attn(hidden_states, temb=temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
......
@@ -23,7 +23,14 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
-from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from .embeddings import (
+    GaussianFourierProjection,
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -90,7 +97,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
         encoder_hid_dim (`int`, *optional*, defaults to None):
-            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to None):
+            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
+            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
@@ -156,6 +167,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
         encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -247,8 +259,31 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             cond_proj_dim=time_cond_proj_dim,
         )
 
-        if encoder_hid_dim is not None:
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
             self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
+            # it is set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
         else:
             self.encoder_hid_proj = None
@@ -290,8 +325,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             self.add_embedding = TextTimeEmbedding(
                 text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
             )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
+            # too much, they are set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
@@ -616,6 +658,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -636,6 +679,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified includes additional conditions that can be used for additional
+                time embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type`
+                and `addition_embed_type` for more information.
 
         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -728,12 +775,33 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
             emb = emb + aug_emb
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+
+            aug_emb = self.add_embedding(text_embs, image_embs)
+            emb = emb + aug_emb
 
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
-        if self.encoder_hid_proj is not None:
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
 
         # 2. pre-process
         sample = self.conv_in(sample)
......
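In practice these changes mean that a UNet configured with `encoder_hid_dim_type="text_image_proj"` and `addition_embed_type="text_image"` (the Kandinsky 2.1 decoder setup) expects the CLIP image embedding to arrive through `added_cond_kwargs`. A hedged sketch of the call pattern, with placeholder tensor names:

```python
# `unet` is a UNet2DConditionModel configured as described above;
# latents, t, text_hidden_states, prompt_embeds and image_embeds are placeholders
noise_pred = unet(
    sample=latents,
    timestep=t,
    encoder_hidden_states=text_hidden_states,
    added_cond_kwargs={"text_embeds": prompt_embeds, "image_embeds": image_embeds},
).sample
```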
@@ -19,6 +19,7 @@ import torch
 import torch.nn as nn
 
 from ..utils import BaseOutput, is_torch_version, randn_tensor
+from .attention_processor import SpatialNorm
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
@@ -158,6 +159,7 @@ class Decoder(nn.Module):
         layers_per_block=2,
         norm_num_groups=32,
         act_fn="silu",
+        norm_type="group",  # group, spatial
     ):
         super().__init__()
         self.layers_per_block = layers_per_block
@@ -173,16 +175,18 @@ class Decoder(nn.Module):
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
+        temb_channels = in_channels if norm_type == "spatial" else None
+
         # mid
         self.mid_block = UNetMidBlock2D(
             in_channels=block_out_channels[-1],
             resnet_eps=1e-6,
             resnet_act_fn=act_fn,
             output_scale_factor=1,
-            resnet_time_scale_shift="default",
+            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
             attn_num_head_channels=None,
             resnet_groups=norm_num_groups,
-            temb_channels=None,
+            temb_channels=temb_channels,
         )
 
         # up
@@ -205,19 +209,23 @@ class Decoder(nn.Module):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 attn_num_head_channels=None,
-                temb_channels=None,
+                temb_channels=temb_channels,
+                resnet_time_scale_shift=norm_type,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
 
         # out
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+        if norm_type == "spatial":
+            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+        else:
+            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
         self.gradient_checkpointing = False
 
-    def forward(self, z):
+    def forward(self, z, latent_embeds=None):
         sample = z
         sample = self.conv_in(sample)
 
@@ -233,34 +241,39 @@ class Decoder(nn.Module):
             if is_torch_version(">=", "1.11.0"):
                 # middle
                 sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.mid_block), sample, use_reentrant=False
+                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
                 )
                 sample = sample.to(upscale_dtype)
 
                 # up
                 for up_block in self.up_blocks:
                     sample = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(up_block), sample, use_reentrant=False
+                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
                     )
             else:
                 # middle
-                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, latent_embeds
+                )
                 sample = sample.to(upscale_dtype)
 
                 # up
                 for up_block in self.up_blocks:
-                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
         else:
             # middle
-            sample = self.mid_block(sample)
+            sample = self.mid_block(sample, latent_embeds)
             sample = sample.to(upscale_dtype)
 
             # up
            for up_block in self.up_blocks:
-                sample = up_block(sample)
+                sample = up_block(sample, latent_embeds)
 
         # post-process
-        sample = self.conv_norm_out(sample)
+        if latent_embeds is None:
+            sample = self.conv_norm_out(sample)
+        else:
+            sample = self.conv_norm_out(sample, latent_embeds)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
......
@@ -82,6 +82,7 @@ class VQModel(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         vq_embed_dim: Optional[int] = None,
         scaling_factor: float = 0.18215,
+        norm_type: str = "group",  # group, spatial
     ):
         super().__init__()
 
@@ -112,6 +113,7 @@ class VQModel(ModelMixin, ConfigMixin):
             layers_per_block=layers_per_block,
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
+            norm_type=norm_type,
         )
 
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -131,8 +133,8 @@ class VQModel(ModelMixin, ConfigMixin):
             quant, emb_loss, info = self.quantize(h)
         else:
             quant = h
-        quant = self.post_quant_conv(quant)
-        dec = self.decoder(quant)
+        quant2 = self.post_quant_conv(quant)
+        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
 
         if not return_dict:
             return (dec,)
......
@@ -57,6 +57,12 @@ else:
         IFPipeline,
         IFSuperResolutionPipeline,
     )
+    from .kandinsky import (
+        KandinskyImg2ImgPipeline,
+        KandinskyInpaintPipeline,
+        KandinskyPipeline,
+        KandinskyPriorPipeline,
+    )
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
......
from ...utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import KandinskyPipeline, KandinskyPriorPipeline
else:
    from .pipeline_kandinsky import KandinskyPipeline
    from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
    from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
    from .pipeline_kandinsky_prior import KandinskyPriorPipeline
    from .text_encoder import MultilingualCLIP
import torch
from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel


class MCLIPConfig(XLMRobertaConfig):
    model_type = "M-CLIP"

    def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)


class MultilingualCLIP(PreTrainedModel):
    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions, out_features=config.numDims
        )

    def forward(self, input_ids, attention_mask):
        embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
        embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        return self.LinearTransformation(embs2), embs
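`MultilingualCLIP` mean-pools the XLM-Roberta token states under the attention mask and projects the pooled vector to `imageDimSize`, returning both the projection and the raw token states. A rough smoke test with a tiny, randomly initialized config (hypothetical sizes chosen only so the shapes line up):

```python
import torch

config = MCLIPConfig(
    transformerDimSize=1024,
    imageDimSize=768,
    hidden_size=1024,       # must match transformerDimSize for the projection to apply
    num_attention_heads=16,
    num_hidden_layers=2,    # keep the random model small
    intermediate_size=2048,
    vocab_size=250002,
)
model = MultilingualCLIP(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
pooled, token_states = model(input_ids, attention_mask)
print(pooled.shape, token_states.shape)  # torch.Size([2, 768]) torch.Size([2, 16, 1024])
```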
@@ -15,7 +15,14 @@ from ...models.attention_processor import (
     AttnProcessor,
 )
 from ...models.dual_transformer_2d import DualTransformer2DModel
-from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from ...models.embeddings import (
+    GaussianFourierProjection,
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
 from ...utils import is_torch_version, logging
@@ -182,7 +189,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
         encoder_hid_dim (`int`, *optional*, defaults to None):
-            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to None):
+            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
+            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
@@ -253,6 +264,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
         encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -350,8 +362,31 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             cond_proj_dim=time_cond_proj_dim,
         )
 
-        if encoder_hid_dim is not None:
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
             self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
+            # it is set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+            )
         else:
             self.encoder_hid_proj = None
@@ -393,8 +428,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             self.add_embedding = TextTimeEmbedding(
                 text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
+            # too much, they are set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
@@ -719,6 +761,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -739,6 +782,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified includes additional conditions that can be used for additional
+                time embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type`
+                and `addition_embed_type` for more information.
 
         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -831,12 +878,35 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
             emb = emb + aug_emb
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
+                    " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+
+            aug_emb = self.add_embedding(text_embs, image_embs)
+            emb = emb + aug_emb
 
         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
 
-        if self.encoder_hid_proj is not None:
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
+                    " requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
 
         # 2. pre-process
         sample = self.conv_in(sample)
......
@@ -152,6 +152,66 @@ class IFSuperResolutionPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class KandinskyImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyInpaintPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class KandinskyPriorPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class LDMTextToImagePipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
......