Unverified Commit 74621567 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

Kandinsky_v22_yiyi (#3936)



* Kandinsky2_2

* fix init kandinsky2_2

* kandinsky2_2 fix inpainting

* rename pipelines: remove decoder + 2_2 -> V22

* Update scheduling_unclip.py

* remove text_encoder and tokenizer arguments from doc string

* add test for text2img

* add tests for text2img & img2img

* fix

* add test for inpaint

* add prior tests

* style

* copies

* add controlnet test

* style

* add a test for controlnet_img2img

* update prior_emb2emb api to accept image_embedding or image

* add a test for prior_emb2emb

* style

* remove try except

* example

* fix

* add doc string examples to all kandinsky pipelines

* style

* update doc

* style

* add a top about 2.2

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* vae -> movq

* vae -> movq

* style

* fix the #copied from

* remove decoder from file name

* update doc: add a section for kandinsky 2.2

* fix

* fix-copies

* add coped from

* add copies from for prior

* add copies from for prior emb2emb

* copy from for img2img

* copied from for inpaint

* more copied from

* more copies from

* more copies

* remove the yiyi comments

* Apply suggestions from code review

* Self-contained example, pipeline order

* Import prior output instead of redefining.

* Style

* Make VQModel compatible with model offload.

* Fix copies

---------
Co-authored-by: default avatarShahmatov Arseniy <62886550+cene555@users.noreply.github.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent bc9a8cef
...@@ -11,19 +11,12 @@ specific language governing permissions and limitations under the License. ...@@ -11,19 +11,12 @@ specific language governing permissions and limitations under the License.
## Overview ## Overview
Kandinsky 2.1 inherits best practices from [DALL-E 2](https://arxiv.org/abs/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas. Kandinsky inherits best practices from [DALL-E 2](https://huggingface.co/papers/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas.
It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation. It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.
The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov) and the original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2) The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov). The original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2)
## Available Pipelines:
| Pipeline | Tasks |
|---|---|
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
## Usage example ## Usage example
...@@ -135,6 +128,7 @@ prompt = "birds eye view of a quilted paper style alien planet landscape, vibran ...@@ -135,6 +128,7 @@ prompt = "birds eye view of a quilted paper style alien planet landscape, vibran
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/alienplanet.png) ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/alienplanet.png)
### Text Guided Image-to-Image Generation ### Text Guided Image-to-Image Generation
The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline. The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline.
...@@ -283,6 +277,207 @@ image.save("starry_cat.png") ...@@ -283,6 +277,207 @@ image.save("starry_cat.png")
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png) ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png)
### Text-to-Image Generation with ControlNet Conditioning
In the following, we give a simple example of how to use [`KandinskyV22ControlnetPipeline`] to add control to the text-to-image generation with a depth image.
First, let's take an image and extract its depth map.
```python
from diffusers.utils import load_image
img = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png"
).resize((768, 768))
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png)
We can use the `depth-estimation` pipeline from transformers to process the image and retrieve its depth map.
```python
import torch
import numpy as np
from transformers import pipeline
from diffusers.utils import load_image
def make_hint(image, depth_estimator):
image = depth_estimator(image)["depth"]
image = np.array(image)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
detected_map = torch.from_numpy(image).float() / 255.0
hint = detected_map.permute(2, 0, 1)
return hint
depth_estimator = pipeline("depth-estimation")
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
```
Now, we load the prior pipeline and the text-to-image controlnet pipeline
```python
from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior = pipe_prior.to("cuda")
pipe = KandinskyV22ControlnetPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
```
We pass the prompt and negative prompt through the prior to generate image embeddings
```python
prompt = "A robot, 4k photo"
negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
generator = torch.Generator(device="cuda").manual_seed(43)
image_emb, zero_image_emb = pipe_prior(
prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
).to_tuple()
```
Now we can pass the image embeddings and the depth image we extracted to the controlnet pipeline. With Kandinsky 2.2, only prior pipelines accept `prompt` input. You do not need to pass the prompt to the controlnet pipeline.
```python
images = pipe(
image_embeds=image_emb,
negative_image_embeds=zero_image_emb,
hint=hint,
num_inference_steps=50,
generator=generator,
height=768,
width=768,
).images
images[0].save("robot_cat.png")
```
The output image looks as follow:
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat_text2img.png)
### Image-to-Image Generation with ControlNet Conditioning
Kandinsky 2.2 also includes a [`KandinskyV22ControlnetImg2ImgPipeline`] that will allow you to add control to the image generation process with both the image and its depth map. This pipeline works really well with [`KandinskyV22PriorEmb2EmbPipeline`], which generates image embeddings based on both a text prompt and an image.
For our robot cat example, we will pass the prompt and cat image together to the prior pipeline to generate an image embedding. We will then use that image embedding and the depth map of the cat to further control the image generation process.
We can use the same cat image and its depth map from the last example.
```python
import torch
import numpy as np
from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
from diffusers.utils import load_image
from transformers import pipeline
img = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinskyv22/cat.png"
).resize((768, 768))
def make_hint(image, depth_estimator):
image = depth_estimator(image)["depth"]
image = np.array(image)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
detected_map = torch.from_numpy(image).float() / 255.0
hint = detected_map.permute(2, 0, 1)
return hint
depth_estimator = pipeline("depth-estimation")
hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
)
pipe_prior = pipe_prior.to("cuda")
pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "A robot, 4k photo"
negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
generator = torch.Generator(device="cuda").manual_seed(43)
# run prior pipeline
img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)
# run controlnet img2img pipeline
images = pipe(
image=img,
strength=0.5,
image_embeds=img_emb.image_embeds,
negative_image_embeds=negative_emb.image_embeds,
hint=hint,
num_inference_steps=50,
generator=generator,
height=768,
width=768,
).images
images[0].save("robot_cat.png")
```
Here is the output. Compared with the output from our text-to-image controlnet example, it kept a lot more cat facial details from the original image and worked into the robot style we asked for.
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat.png)
## Kandinsky 2.2
The Kandinsky 2.2 release includes robust new text-to-image models that support text-to-image generation, image-to-image generation, image interpolation, and text-guided image inpainting. The general workflow to perform these tasks using Kandinsky 2.2 is the same as in Kandinsky 2.1. First, you will need to use a prior pipeline to generate image embeddings based on your text prompt, and then use one of the image decoding pipelines to generate the output image. The only difference is that in Kandinsky 2.2, all of the decoding pipelines no longer accept the `prompt` input, and the image generation process is conditioned with only `image_embeds` and `negative_image_embeds`.
Let's look at an example of how to perform text-to-image generation using Kandinsky 2.2.
First, let's create the prior pipeline and text-to-image pipeline with Kandinsky 2.2 checkpoints.
```python
from diffusers import DiffusionPipeline
import torch
pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
pipe_prior.to("cuda")
t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
t2i_pipe.to("cuda")
```
You can then use `pipe_prior` to generate image embeddings.
```python
prompt = "portrait of a women, blue eyes, cinematic"
negative_prompt = "low quality, bad quality"
image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
```
Now you can pass these embeddings to the text-to-image pipeline. When using Kandinsky 2.2 you don't need to pass the `prompt` (but you do with the previous version, Kandinsky 2.1).
```
image = t2i_pipe(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768).images[
0
]
image.save("portrait.png")
```
![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/%20blue%20eyes.png)
We used the text-to-image pipeline as an example, but the same process applies to all decoding pipelines in Kandinsky 2.2. For more information, please refer to our API section for each pipeline.
## Optimization ## Optimization
Running Kandinsky in inference requires running both a first prior pipeline: [`KandinskyPriorPipeline`] Running Kandinsky in inference requires running both a first prior pipeline: [`KandinskyPriorPipeline`]
...@@ -335,30 +530,84 @@ t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=T ...@@ -335,30 +530,84 @@ t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=T
After compilation you should see a very fast inference time. For more information, After compilation you should see a very fast inference time. For more information,
feel free to have a look at [Our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0). feel free to have a look at [Our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0).
## Available Pipelines:
| Pipeline | Tasks |
|---|---|
| [pipeline_kandinsky2_2.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py) | *Text-to-Image Generation* |
| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
| [pipeline_kandinsky2_2_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky2_2_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky2_2_controlnet.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py) | *Image-Guided Image Generation* |
| [pipeline_kandinsky2_2_controlnet_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py) | *Image-Guided Image Generation* |
### KandinskyV22Pipeline
[[autodoc]] KandinskyV22Pipeline
- all
- __call__
### KandinskyV22ControlnetPipeline
[[autodoc]] KandinskyV22ControlnetPipeline
- all
- __call__
### KandinskyV22ControlnetImg2ImgPipeline
[[autodoc]] KandinskyV22ControlnetImg2ImgPipeline
- all
- __call__
### KandinskyV22Img2ImgPipeline
[[autodoc]] KandinskyV22Img2ImgPipeline
- all
- __call__
### KandinskyV22InpaintPipeline
[[autodoc]] KandinskyV22InpaintPipeline
- all
- __call__
### KandinskyV22PriorPipeline
[[autodoc]] ## KandinskyV22PriorPipeline
- all
- __call__
- interpolate
### KandinskyV22PriorEmb2EmbPipeline
[[autodoc]] KandinskyV22PriorEmb2EmbPipeline
- all
- __call__
- interpolate
## KandinskyPriorPipeline ### KandinskyPriorPipeline
[[autodoc]] KandinskyPriorPipeline [[autodoc]] KandinskyPriorPipeline
- all - all
- __call__ - __call__
- interpolate - interpolate
## KandinskyPipeline ### KandinskyPipeline
[[autodoc]] KandinskyPipeline [[autodoc]] KandinskyPipeline
- all - all
- __call__ - __call__
## KandinskyImg2ImgPipeline ### KandinskyImg2ImgPipeline
[[autodoc]] KandinskyImg2ImgPipeline [[autodoc]] KandinskyImg2ImgPipeline
- all - all
- __call__ - __call__
## KandinskyInpaintPipeline ### KandinskyInpaintPipeline
[[autodoc]] KandinskyInpaintPipeline [[autodoc]] KandinskyInpaintPipeline
- all - all
......
...@@ -139,6 +139,13 @@ else: ...@@ -139,6 +139,13 @@ else:
KandinskyInpaintPipeline, KandinskyInpaintPipeline,
KandinskyPipeline, KandinskyPipeline,
KandinskyPriorPipeline, KandinskyPriorPipeline,
KandinskyV22ControlnetImg2ImgPipeline,
KandinskyV22ControlnetPipeline,
KandinskyV22Img2ImgPipeline,
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
LDMTextToImagePipeline, LDMTextToImagePipeline,
PaintByExamplePipeline, PaintByExamplePipeline,
SemanticStableDiffusionPipeline, SemanticStableDiffusionPipeline,
......
...@@ -376,6 +376,29 @@ class TextImageProjection(nn.Module): ...@@ -376,6 +376,29 @@ class TextImageProjection(nn.Module):
return torch.cat([image_text_embeds, text_embeds], dim=1) return torch.cat([image_text_embeds, text_embeds], dim=1)
class ImageProjection(nn.Module):
def __init__(
self,
image_embed_dim: int = 768,
cross_attention_dim: int = 768,
num_image_text_embeds: int = 32,
):
super().__init__()
self.num_image_text_embeds = num_image_text_embeds
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
batch_size = image_embeds.shape[0]
# image
image_embeds = self.image_embeds(image_embeds)
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
class CombinedTimestepLabelEmbeddings(nn.Module): class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__() super().__init__()
...@@ -429,6 +452,50 @@ class TextImageTimeEmbedding(nn.Module): ...@@ -429,6 +452,50 @@ class TextImageTimeEmbedding(nn.Module):
return time_image_embeds + time_text_embeds return time_image_embeds + time_text_embeds
class ImageTimeEmbedding(nn.Module):
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
self.image_norm = nn.LayerNorm(time_embed_dim)
def forward(self, image_embeds: torch.FloatTensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
return time_image_embeds
class ImageHintTimeEmbedding(nn.Module):
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
super().__init__()
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
self.image_norm = nn.LayerNorm(time_embed_dim)
self.input_hint_block = nn.Sequential(
nn.Conv2d(3, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 32, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(32, 32, 3, padding=1),
nn.SiLU(),
nn.Conv2d(32, 96, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(96, 96, 3, padding=1),
nn.SiLU(),
nn.Conv2d(96, 256, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(256, 4, 3, padding=1),
)
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
# image
time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds)
hint = self.input_hint_block(hint)
return time_image_embeds, hint
class AttentionPooling(nn.Module): class AttentionPooling(nn.Module):
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
......
...@@ -25,6 +25,9 @@ from .activations import get_activation ...@@ -25,6 +25,9 @@ from .activations import get_activation
from .attention_processor import AttentionProcessor, AttnProcessor from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import ( from .embeddings import (
GaussianFourierProjection, GaussianFourierProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
TextImageProjection, TextImageProjection,
TextImageTimeEmbedding, TextImageTimeEmbedding,
TextTimeEmbedding, TextTimeEmbedding,
...@@ -306,7 +309,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -306,7 +309,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
image_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
) )
elif encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None: elif encoder_hid_dim_type is not None:
raise ValueError( raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
...@@ -362,7 +370,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -362,7 +370,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
elif addition_embed_type == "text_time": elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type == "image":
# Kandinsky 2.2
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None: elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
...@@ -808,7 +821,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -808,7 +821,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if self.config.addition_embed_type == "text": if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states) aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image": elif self.config.addition_embed_type == "text_image":
# Kadinsky 2.1 - style # Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs: if "image_embeds" not in added_cond_kwargs:
raise ValueError( raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
...@@ -816,7 +829,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -816,7 +829,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
image_embs = added_cond_kwargs.get("image_embeds") image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs) aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time": elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs: if "text_embeds" not in added_cond_kwargs:
...@@ -835,6 +847,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -835,6 +847,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype) add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds) aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb, hint = self.add_embedding(image_embs, hint)
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb emb = emb + aug_emb if aug_emb is not None else emb
...@@ -852,7 +882,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -852,7 +882,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
......
...@@ -18,7 +18,7 @@ import torch ...@@ -18,7 +18,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
...@@ -116,6 +116,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -116,6 +116,7 @@ class VQModel(ModelMixin, ConfigMixin):
norm_type=norm_type, norm_type=norm_type,
) )
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
h = self.encoder(x) h = self.encoder(x)
h = self.quant_conv(h) h = self.quant_conv(h)
...@@ -125,6 +126,7 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -125,6 +126,7 @@ class VQModel(ModelMixin, ConfigMixin):
return VQEncoderOutput(latents=h) return VQEncoderOutput(latents=h)
@apply_forward_hook
def decode( def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]: ) -> Union[DecoderOutput, torch.FloatTensor]:
......
...@@ -65,6 +65,15 @@ else: ...@@ -65,6 +65,15 @@ else:
KandinskyPipeline, KandinskyPipeline,
KandinskyPriorPipeline, KandinskyPriorPipeline,
) )
from .kandinsky2_2 import (
KandinskyV22ControlnetImg2ImgPipeline,
KandinskyV22ControlnetPipeline,
KandinskyV22Img2ImgPipeline,
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
from .latent_diffusion import LDMTextToImagePipeline from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
......
...@@ -15,5 +15,5 @@ else: ...@@ -15,5 +15,5 @@ else:
from .pipeline_kandinsky import KandinskyPipeline from .pipeline_kandinsky import KandinskyPipeline
from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
from .pipeline_kandinsky_prior import KandinskyPriorPipeline from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput
from .text_encoder import MultilingualCLIP from .text_encoder import MultilingualCLIP
...@@ -115,6 +115,7 @@ class KandinskyPipeline(DiffusionPipeline): ...@@ -115,6 +115,7 @@ class KandinskyPipeline(DiffusionPipeline):
) )
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None: if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
......
...@@ -275,6 +275,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline): ...@@ -275,6 +275,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
) )
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None: if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
......
...@@ -274,6 +274,7 @@ class KandinskyPriorPipeline(DiffusionPipeline): ...@@ -274,6 +274,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None: if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
......
from .pipeline_kandinsky2_2 import KandinskyV22Pipeline
from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline
from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline
from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline
from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline
from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline
from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
>>> import torch
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
>>> pipe_prior.to("cuda")
>>> prompt = "red cat, 4k photo"
>>> out = pipe_prior(prompt)
>>> image_emb = out.image_embeds
>>> zero_image_emb = out.negative_image_embeds
>>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
>>> pipe.to("cuda")
>>> image = pipe(
... image_embeds=image_emb,
... negative_image_embeds=zero_image_emb,
... height=768,
... width=768,
... num_inference_steps=50,
... ).images
>>> image[0].save("cat.png")
```
"""
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
class KandinskyV22Pipeline(DiffusionPipeline):
"""
Pipeline for text-to-image generation using Kandinsky
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
movq ([`VQModel`]):
MoVQ Decoder to generate the image from the latents.
"""
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
unet=unet,
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [
self.unet,
self.movq,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.unet, self.movq]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 100,
guidance_scale: float = 4.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Args:
Function invoked when calling the pipeline for generation.
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for text prompt, that will be used to condition the image generation.
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for negative text prompt, will be used to condition the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
batch_size = image_embeds.shape[0] * num_images_per_prompt
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
num_channels_latents = self.unet.config.in_channels
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
# create initial latent
latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
image_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds}
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
hasattr(self.scheduler.config, "variance_type")
and self.scheduler.config.variance_type in ["learned", "learned_range"]
):
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
)[0]
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> import numpy as np
>>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline
>>> from transformers import pipeline
>>> from diffusers.utils import load_image
>>> def make_hint(image, depth_estimator):
... image = depth_estimator(image)["depth"]
... image = np.array(image)
... image = image[:, :, None]
... image = np.concatenate([image, image, image], axis=2)
... detected_map = torch.from_numpy(image).float() / 255.0
... hint = detected_map.permute(2, 0, 1)
... return hint
>>> depth_estimator = pipeline("depth-estimation")
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
... )
>>> pipe_prior = pipe_prior.to("cuda")
>>> pipe = KandinskyV22ControlnetPipeline.from_pretrained(
... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> img = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... ).resize((768, 768))
>>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
>>> prompt = "A robot, 4k photo"
>>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
>>> generator = torch.Generator(device="cuda").manual_seed(43)
>>> image_emb, zero_image_emb = pipe_prior(
... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator
... ).to_tuple()
>>> images = pipe(
... image_embeds=image_emb,
... negative_image_embeds=zero_image_emb,
... hint=hint,
... num_inference_steps=50,
... generator=generator,
... height=768,
... width=768,
... ).images
>>> images[0].save("robot_cat.png")
```
"""
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
class KandinskyV22ControlnetPipeline(DiffusionPipeline):
"""
Pipeline for text-to-image generation using Kandinsky
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
scheduler ([`DDIMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
movq ([`VQModel`]):
MoVQ Decoder to generate the image from the latents.
"""
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
unet=unet,
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [
self.unet,
self.movq,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.unet, self.movq]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
hint: torch.FloatTensor,
height: int = 512,
width: int = 512,
num_inference_steps: int = 100,
guidance_scale: float = 4.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
hint (`torch.FloatTensor`):
The controlnet condition.
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for text prompt, that will be used to condition the image generation.
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for negative text prompt, will be used to condition the image generation.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if isinstance(hint, list):
hint = torch.cat(hint, dim=0)
batch_size = image_embeds.shape[0] * num_images_per_prompt
if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
num_channels_latents = self.movq.config.latent_channels
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
# create initial latent
latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
image_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
hasattr(self.scheduler.config, "variance_type")
and self.scheduler.config.variance_type in ["learned", "learned_range"]
):
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
)[0]
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> import numpy as np
>>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline
>>> from transformers import pipeline
>>> from diffusers.utils import load_image
>>> def make_hint(image, depth_estimator):
... image = depth_estimator(image)["depth"]
... image = np.array(image)
... image = image[:, :, None]
... image = np.concatenate([image, image, image], axis=2)
... detected_map = torch.from_numpy(image).float() / 255.0
... hint = detected_map.permute(2, 0, 1)
... return hint
>>> depth_estimator = pipeline("depth-estimation")
>>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
... )
>>> pipe_prior = pipe_prior.to("cuda")
>>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> img = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... ).resize((768, 768))
>>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
>>> prompt = "A robot, 4k photo"
>>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
>>> generator = torch.Generator(device="cuda").manual_seed(43)
>>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator)
>>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)
>>> images = pipe(
... image=img,
... strength=0.5,
... image_embeds=img_emb.image_embeds,
... negative_image_embeds=negative_emb.image_embeds,
... hint=hint,
... num_inference_steps=50,
... generator=generator,
... height=768,
... width=768,
... ).images
>>> images[0].save("robot_cat.png")
```
"""
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
scheduler ([`DDIMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
movq ([`VQModel`]):
MoVQ Decoder to generate the image from the latents.
"""
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
unet=unet,
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.movq.encode(image).latent_dist.sample(generator)
init_latents = self.movq.config.scaling_factor * init_latents
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [
self.unet,
self.movq,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.unet, self.movq]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
hint: torch.FloatTensor,
height: int = 512,
width: int = 512,
num_inference_steps: int = 100,
guidance_scale: float = 4.0,
strength: float = 0.3,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for text prompt, that will be used to condition the image generation.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded
again.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
hint (`torch.FloatTensor`):
The controlnet condition.
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for negative text prompt, will be used to condition the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if isinstance(hint, list):
hint = torch.cat(hint, dim=0)
batch_size = image_embeds.shape[0]
if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
if not isinstance(image, list):
image = [image]
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
hasattr(self.scheduler.config, "variance_type")
and self.scheduler.config.variance_type in ["learned", "learned_range"]
):
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
)[0]
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from PIL import Image
from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline
>>> from diffusers.utils import load_image
>>> import torch
>>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
... )
>>> pipe_prior.to("cuda")
>>> prompt = "A red cartoon frog, 4k"
>>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
>>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained(
... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
>>> init_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/frog.png"
... )
>>> image = pipe(
... image=init_image,
... image_embeds=image_emb,
... negative_image_embeds=zero_image_emb,
... height=768,
... width=768,
... num_inference_steps=100,
... strength=0.2,
... ).images
>>> image[0].save("red_frog.png")
```
"""
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
scheduler ([`DDIMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
movq ([`VQModel`]):
MoVQ Decoder to generate the image from the latents.
"""
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
unet=unet,
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.movq.encode(image).latent_dist.sample(generator)
init_latents = self.movq.config.scaling_factor * init_latents
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
models = [
self.unet,
self.movq,
]
for cpu_offloaded_model in models:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.unet, self.movq]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 100,
guidance_scale: float = 4.0,
strength: float = 0.3,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for text prompt, that will be used to condition the image generation.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded
again.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
The clip image embeddings for negative text prompt, will be used to condition the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
batch_size = image_embeds.shape[0]
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
if not isinstance(image, list):
image = [image]
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds}
noise_pred = self.unet(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
hasattr(self.scheduler.config, "variance_type")
and self.scheduler.config.variance_type in ["learned", "learned_range"]
):
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
)[0]
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
...@@ -18,6 +18,9 @@ from ...models.attention_processor import ( ...@@ -18,6 +18,9 @@ from ...models.attention_processor import (
from ...models.dual_transformer_2d import DualTransformer2DModel from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import ( from ...models.embeddings import (
GaussianFourierProjection, GaussianFourierProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
TextImageProjection, TextImageProjection,
TextImageTimeEmbedding, TextImageTimeEmbedding,
TextTimeEmbedding, TextTimeEmbedding,
...@@ -409,7 +412,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -409,7 +412,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
image_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
) )
elif encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None: elif encoder_hid_dim_type is not None:
raise ValueError( raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
...@@ -465,7 +473,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -465,7 +473,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
elif addition_embed_type == "text_time": elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type == "image":
# Kandinsky 2.2
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None: elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
...@@ -911,7 +924,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -911,7 +924,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if self.config.addition_embed_type == "text": if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states) aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image": elif self.config.addition_embed_type == "text_image":
# Kadinsky 2.1 - style # Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs: if "image_embeds" not in added_cond_kwargs:
raise ValueError( raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires" f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
...@@ -920,7 +933,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -920,7 +933,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
image_embs = added_cond_kwargs.get("image_embeds") image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs) aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time": elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs: if "text_embeds" not in added_cond_kwargs:
...@@ -941,6 +953,26 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -941,6 +953,26 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype) add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds) aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
" keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
" the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb, hint = self.add_embedding(image_embs, hint)
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb emb = emb + aug_emb if aug_emb is not None else emb
...@@ -959,7 +991,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -959,7 +991,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
" the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
......
...@@ -307,3 +307,27 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -307,3 +307,27 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
return (pred_prev_sample,) return (pred_prev_sample,)
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment