Unverified Commit 352ca319 authored by dg845's avatar dg845 Committed by GitHub
Browse files

[WIP] Add UniDiffuser model and pipeline (#2963)



* Fix a bug of pano when not doing CFG (#3030)

* Fix a bug of pano when not doing CFG

* enhance code quality

* apply formatting.

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Text2video zero refinements (#3070)

* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward

* fix tensor loading in test_text_to_video_zero.py

* make style && make quality

* Release: v0.15.0

* [Tests] Speed up panorama tests (#3067)

* fix: norm group test for UNet3D.

* chore: speed up the panorama tests (fast).

* set default value of _test_inference_batch_single_identical.

* fix: batch_sizes default value.

* [Post release] v0.16.0dev (#3072)

* Adds profiling flags, computes train metrics average. (#3053)

* WIP controlnet training

- bugfix --streaming
- bugfix running report_to!='wandb'
- adds memory profile before validation

* Adds final logging statement.

* Sets tr...
parent 7948db81
......@@ -232,6 +232,8 @@
title: UnCLIP
- local: api/pipelines/latent_diffusion_uncond
title: Unconditional Latent Diffusion
- local: api/pipelines/unidiffuser
title: UniDiffuser
- local: api/pipelines/versatile_diffusion
title: Versatile Diffusion
- local: api/pipelines/vq_diffusion
......
......@@ -45,3 +45,8 @@ By default diffusion pipelines return an object of class
By default diffusion pipelines return an object of class
[[autodoc]] pipelines.AudioPipelineOutput
## ImageTextPipelineOutput
By default diffusion pipelines return an object of class
[[autodoc]] ImageTextPipelineOutput
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# UniDiffuser
The UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://arxiv.org/abs/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu.
The abstract of the [paper](https://arxiv.org/abs/2303.06555) is the following:
*This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is -- learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model -- perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead. In particular, UniDiffuser is able to produce perceptually realistic samples in all tasks and its quantitative results (e.g., the FID and CLIP score) are not only superior to existing general-purpose models but also comparable to the bespoken models (e.g., Stable Diffusion and DALL-E 2) in representative tasks (e.g., text-to-image generation).*
Resources:
* [Paper](https://arxiv.org/abs/2303.06555).
* [Original Code](https://github.com/thu-ml/unidiffuser).
Available Checkpoints are:
- *UniDiffuser-v0 (512x512 resolution)* [thu-ml/unidiffuser-v0](https://huggingface.co/thu-ml/unidiffuser-v0)
- *UniDiffuser-v1 (512x512 resolution)* [thu-ml/unidiffuser-v1](https://huggingface.co/thu-ml/unidiffuser-v1)
This pipeline was contributed by our community member [dg845](https://github.com/dg845).
## Available Pipelines:
| Pipeline | Tasks | Demo | Colab |
|:---:|:---:|:---:|:---:|
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*,<br> *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | [🤗 Spaces](https://huggingface.co/spaces/thu-ml/unidiffuser) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/unidiffuser.ipynb) |
## Usage Examples
Because the UniDiffuser model is trained to model the joint distribution of (image, text) pairs, it is capable of performing a diverse range of generation tasks.
### Unconditional Image and Text Generation
Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce a (image, text) pair:
```python
import torch
from diffusers import UniDiffuserPipeline
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Unconditional image and text generation. The generation task is automatically inferred.
sample = pipe(num_inference_steps=20, guidance_scale=8.0)
image = sample.images[0]
text = sample.text[0]
image.save("unidiffuser_joint_sample_image.png")
print(text)
```
This is also called "joint" generation in the UniDiffusers paper, since we are sampling from the joint image-text distribution.
Note that the generation task is inferred from the inputs used when calling the pipeline.
It is also possible to manually specify the unconditional generation task ("mode") manually with [`UniDiffuserPipeline.set_joint_mode`]:
```python
# Equivalent to the above.
pipe.set_joint_mode()
sample = pipe(num_inference_steps=20, guidance_scale=8.0)
```
When the mode is set manually, subsequent calls to the pipeline will use the set mode without attempting the infer the mode.
You can reset the mode with [`UniDiffuserPipeline.reset_mode`], after which the pipeline will once again infer the mode.
You can also generate only an image or only text (which the UniDiffuser paper calls "marginal" generation since we sample from the marginal distribution of images and text, respectively):
```python
# Unlike other generation tasks, image-only and text-only generation don't use classifier-free guidance
# Image-only generation
pipe.set_image_mode()
sample_image = pipe(num_inference_steps=20).images[0]
# Text-only generation
pipe.set_text_mode()
sample_text = pipe(num_inference_steps=20).text[0]
```
### Text-to-Image Generation
UniDiffuser is also capable of sampling from conditional distributions; that is, the distribution of images conditioned on a text prompt or the distribution of texts conditioned on an image.
Here is an example of sampling from the conditional image distribution (text-to-image generation or text-conditioned image generation):
```python
import torch
from diffusers import UniDiffuserPipeline
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Text-to-image generation
prompt = "an elephant under the sea"
sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)
t2i_image = sample.images[0]
t2i_image.save("unidiffuser_text2img_sample_image.png")
```
The `text2img` mode requires that either an input `prompt` or `prompt_embeds` be supplied. You can set the `text2img` mode manually with [`UniDiffuserPipeline.set_text_to_image_mode`].
### Image-to-Text Generation
Similarly, UniDiffuser can also produce text samples given an image (image-to-text or image-conditioned text generation):
```python
import torch
from diffusers import UniDiffuserPipeline
from diffusers.utils import load_image
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Image-to-text generation
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
init_image = load_image(image_url).resize((512, 512))
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(i2t_text)
```
The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`].
### Image Variation
The UniDiffuser authors suggest performing image variation through a "round-trip" generation method, where given an input image, we first perform an image-to-text generation, and the perform a text-to-image generation on the outputs of the first generation.
This produces a new image which is semantically similar to the input image:
```python
import torch
from diffusers import UniDiffuserPipeline
from diffusers.utils import load_image
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Image variation can be performed with a image-to-text generation followed by a text-to-image generation:
# 1. Image-to-text generation
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
init_image = load_image(image_url).resize((512, 512))
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)
i2t_text = sample.text[0]
print(i2t_text)
# 2. Text-to-image generation
sample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0)
final_image = sample.images[0]
final_image.save("unidiffuser_image_variation_sample.png")
```
### Text Variation
Similarly, text variation can be performed on an input prompt with a text-to-image generation followed by a image-to-text generation:
```python
import torch
from diffusers import UniDiffuserPipeline
device = "cuda"
model_id_or_path = "thu-ml/unidiffuser-v1"
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Text variation can be performed with a text-to-image generation followed by a image-to-text generation:
# 1. Text-to-image generation
prompt = "an elephant under the sea"
sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)
t2i_image = sample.images[0]
t2i_image.save("unidiffuser_text2img_sample_image.png")
# 2. Image-to-text generation
sample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0)
final_prompt = sample.text[0]
print(final_prompt)
```
## UniDiffuserPipeline
[[autodoc]] UniDiffuserPipeline
- all
- __call__
# Convert the original UniDiffuser checkpoints into diffusers equivalents.
import argparse
from argparse import Namespace
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
GPT2Tokenizer,
)
from diffusers import (
AutoencoderKL,
DPMSolverMultistepScheduler,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
SCHEDULER_CONFIG = Namespace(
**{
"beta_start": 0.00085,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"solver_order": 3,
}
)
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
# config.num_head_channels => num_head_channels
def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
num_head_channels=1,
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // num_head_channels // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_vae_diffusers_config(config_type):
# Hardcoded for now
if args.config_type == "test":
vae_config = create_vae_diffusers_config_test()
elif args.config_type == "big":
vae_config = create_vae_diffusers_config_big()
else:
raise NotImplementedError(
f"Config type {config_type} is not implemented, currently only config types"
" 'test' and 'big' are available."
)
return vae_config
def create_unidiffuser_unet_config(config_type, version):
# Hardcoded for now
if args.config_type == "test":
unet_config = create_unidiffuser_unet_config_test()
elif args.config_type == "big":
unet_config = create_unidiffuser_unet_config_big()
else:
raise NotImplementedError(
f"Config type {config_type} is not implemented, currently only config types"
" 'test' and 'big' are available."
)
# Unidiffuser-v1 uses data type embeddings
if version == 1:
unet_config["use_data_type_embedding"] = True
return unet_config
def create_text_decoder_config(config_type):
# Hardcoded for now
if args.config_type == "test":
text_decoder_config = create_text_decoder_config_test()
elif args.config_type == "big":
text_decoder_config = create_text_decoder_config_big()
else:
raise NotImplementedError(
f"Config type {config_type} is not implemented, currently only config types"
" 'test' and 'big' are available."
)
return text_decoder_config
# Hardcoded configs for test versions of the UniDiffuser models, corresponding to those in the fast default tests.
def create_vae_diffusers_config_test():
vae_config = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"block_out_channels": [32, 64],
"latent_channels": 4,
"layers_per_block": 1,
}
return vae_config
def create_unidiffuser_unet_config_test():
unet_config = {
"text_dim": 32,
"clip_img_dim": 32,
"num_text_tokens": 77,
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
"out_channels": 4,
"num_layers": 2,
"dropout": 0.0,
"norm_num_groups": 32,
"attention_bias": False,
"sample_size": 16,
"patch_size": 2,
"activation_fn": "gelu",
"num_embeds_ada_norm": 1000,
"norm_type": "layer_norm",
"block_type": "unidiffuser",
"pre_layer_norm": False,
"use_timestep_embedding": False,
"norm_elementwise_affine": True,
"use_patch_pos_embed": False,
"ff_final_dropout": True,
"use_data_type_embedding": False,
}
return unet_config
def create_text_decoder_config_test():
text_decoder_config = {
"prefix_length": 77,
"prefix_inner_dim": 32,
"prefix_hidden_dim": 32,
"vocab_size": 1025, # 1024 + 1 for new EOS token
"n_positions": 1024,
"n_embd": 32,
"n_layer": 5,
"n_head": 4,
"n_inner": 37,
"activation_function": "gelu",
"resid_pdrop": 0.1,
"embd_pdrop": 0.1,
"attn_pdrop": 0.1,
"layer_norm_epsilon": 1e-5,
"initializer_range": 0.02,
}
return text_decoder_config
# Hardcoded configs for the UniDiffuser V1 model at https://huggingface.co/thu-ml/unidiffuser-v1
# See also https://github.com/thu-ml/unidiffuser/blob/main/configs/sample_unidiffuser_v1.py
def create_vae_diffusers_config_big():
vae_config = {
"sample_size": 256,
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
"block_out_channels": [128, 256, 512, 512],
"latent_channels": 4,
"layers_per_block": 2,
}
return vae_config
def create_unidiffuser_unet_config_big():
unet_config = {
"text_dim": 64,
"clip_img_dim": 512,
"num_text_tokens": 77,
"num_attention_heads": 24,
"attention_head_dim": 64,
"in_channels": 4,
"out_channels": 4,
"num_layers": 30,
"dropout": 0.0,
"norm_num_groups": 32,
"attention_bias": False,
"sample_size": 64,
"patch_size": 2,
"activation_fn": "gelu",
"num_embeds_ada_norm": 1000,
"norm_type": "layer_norm",
"block_type": "unidiffuser",
"pre_layer_norm": False,
"use_timestep_embedding": False,
"norm_elementwise_affine": True,
"use_patch_pos_embed": False,
"ff_final_dropout": True,
"use_data_type_embedding": False,
}
return unet_config
# From https://huggingface.co/gpt2/blob/main/config.json, the GPT2 checkpoint used by UniDiffuser
def create_text_decoder_config_big():
text_decoder_config = {
"prefix_length": 77,
"prefix_inner_dim": 768,
"prefix_hidden_dim": 64,
"vocab_size": 50258, # 50257 + 1 for new EOS token
"n_positions": 1024,
"n_embd": 768,
"n_layer": 12,
"n_head": 12,
"n_inner": 3072,
"activation_function": "gelu",
"resid_pdrop": 0.1,
"embd_pdrop": 0.1,
"attn_pdrop": 0.1,
"layer_norm_epsilon": 1e-5,
"initializer_range": 0.02,
}
return text_decoder_config
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
"""
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
"""
# autoencoder_kl.pth ckpt is a torch state dict
vae_state_dict = torch.load(ckpt, map_location="cpu")
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
num_head_channels=num_head_channels, # not used in vae
)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
num_head_channels=num_head_channels, # not used in vae
)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
num_head_channels=num_head_channels, # not used in vae
)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
num_head_channels=num_head_channels, # not used in vae
)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
num_head_channels=num_head_channels, # not used in vae
)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
num_head_channels=num_head_channels, # not used in vae
)
conv_attn_to_linear(new_checkpoint)
missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_checkpoint)
for missing_key in missing_keys:
print(f"Missing key: {missing_key}")
for unexpected_key in unexpected_keys:
print(f"Unexpected key: {unexpected_key}")
return diffusers_model
def convert_uvit_block_to_diffusers_block(
uvit_state_dict,
new_state_dict,
block_prefix,
new_prefix="transformer.transformer_",
skip_connection=False,
):
"""
Maps the keys in a UniDiffuser transformer block (`Block`) to the keys in a diffusers transformer block
(`UTransformerBlock`/`UniDiffuserBlock`).
"""
prefix = new_prefix + block_prefix
if skip_connection:
new_state_dict[prefix + ".skip.skip_linear.weight"] = uvit_state_dict[block_prefix + ".skip_linear.weight"]
new_state_dict[prefix + ".skip.skip_linear.bias"] = uvit_state_dict[block_prefix + ".skip_linear.bias"]
new_state_dict[prefix + ".skip.norm.weight"] = uvit_state_dict[block_prefix + ".norm1.weight"]
new_state_dict[prefix + ".skip.norm.bias"] = uvit_state_dict[block_prefix + ".norm1.bias"]
# Create the prefix string for out_blocks.
prefix += ".block"
# Split up attention qkv.weight into to_q.weight, to_k.weight, to_v.weight
qkv = uvit_state_dict[block_prefix + ".attn.qkv.weight"]
new_attn_keys = [".attn1.to_q.weight", ".attn1.to_k.weight", ".attn1.to_v.weight"]
new_attn_keys = [prefix + key for key in new_attn_keys]
shape = qkv.shape[0] // len(new_attn_keys)
for i, attn_key in enumerate(new_attn_keys):
new_state_dict[attn_key] = qkv[i * shape : (i + 1) * shape]
new_state_dict[prefix + ".attn1.to_out.0.weight"] = uvit_state_dict[block_prefix + ".attn.proj.weight"]
new_state_dict[prefix + ".attn1.to_out.0.bias"] = uvit_state_dict[block_prefix + ".attn.proj.bias"]
new_state_dict[prefix + ".norm1.weight"] = uvit_state_dict[block_prefix + ".norm2.weight"]
new_state_dict[prefix + ".norm1.bias"] = uvit_state_dict[block_prefix + ".norm2.bias"]
new_state_dict[prefix + ".ff.net.0.proj.weight"] = uvit_state_dict[block_prefix + ".mlp.fc1.weight"]
new_state_dict[prefix + ".ff.net.0.proj.bias"] = uvit_state_dict[block_prefix + ".mlp.fc1.bias"]
new_state_dict[prefix + ".ff.net.2.weight"] = uvit_state_dict[block_prefix + ".mlp.fc2.weight"]
new_state_dict[prefix + ".ff.net.2.bias"] = uvit_state_dict[block_prefix + ".mlp.fc2.bias"]
new_state_dict[prefix + ".norm3.weight"] = uvit_state_dict[block_prefix + ".norm3.weight"]
new_state_dict[prefix + ".norm3.bias"] = uvit_state_dict[block_prefix + ".norm3.bias"]
return uvit_state_dict, new_state_dict
def convert_uvit_to_diffusers(ckpt, diffusers_model):
"""
Converts a UniDiffuser uvit_v*.pth checkpoint to a diffusers UniDiffusersModel.
"""
# uvit_v*.pth ckpt is a torch state dict
uvit_state_dict = torch.load(ckpt, map_location="cpu")
new_state_dict = {}
# Input layers
new_state_dict["vae_img_in.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"]
new_state_dict["vae_img_in.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"]
new_state_dict["clip_img_in.weight"] = uvit_state_dict["clip_img_embed.weight"]
new_state_dict["clip_img_in.bias"] = uvit_state_dict["clip_img_embed.bias"]
new_state_dict["text_in.weight"] = uvit_state_dict["text_embed.weight"]
new_state_dict["text_in.bias"] = uvit_state_dict["text_embed.bias"]
new_state_dict["pos_embed"] = uvit_state_dict["pos_embed"]
# Handle data type token embeddings for UniDiffuser-v1
if "token_embedding.weight" in uvit_state_dict and diffusers_model.use_data_type_embedding:
new_state_dict["data_type_pos_embed_token"] = uvit_state_dict["pos_embed_token"]
new_state_dict["data_type_token_embedding.weight"] = uvit_state_dict["token_embedding.weight"]
# Also initialize the PatchEmbedding in UTransformer2DModel with the PatchEmbedding from the checkpoint.
# This isn't used in the current implementation, so might want to remove.
new_state_dict["transformer.pos_embed.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"]
new_state_dict["transformer.pos_embed.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"]
# Output layers
new_state_dict["transformer.norm_out.weight"] = uvit_state_dict["norm.weight"]
new_state_dict["transformer.norm_out.bias"] = uvit_state_dict["norm.bias"]
new_state_dict["vae_img_out.weight"] = uvit_state_dict["decoder_pred.weight"]
new_state_dict["vae_img_out.bias"] = uvit_state_dict["decoder_pred.bias"]
new_state_dict["clip_img_out.weight"] = uvit_state_dict["clip_img_out.weight"]
new_state_dict["clip_img_out.bias"] = uvit_state_dict["clip_img_out.bias"]
new_state_dict["text_out.weight"] = uvit_state_dict["text_out.weight"]
new_state_dict["text_out.bias"] = uvit_state_dict["text_out.bias"]
# in_blocks
in_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "in_blocks" in layer}
for in_block_prefix in list(in_blocks_prefixes):
convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, in_block_prefix)
# mid_block
# Assume there's only one mid block
convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, "mid_block")
# out_blocks
out_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "out_blocks" in layer}
for out_block_prefix in list(out_blocks_prefixes):
convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, out_block_prefix, skip_connection=True)
missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict)
for missing_key in missing_keys:
print(f"Missing key: {missing_key}")
for unexpected_key in unexpected_keys:
print(f"Unexpected key: {unexpected_key}")
return diffusers_model
def convert_caption_decoder_to_diffusers(ckpt, diffusers_model):
"""
Converts a UniDiffuser caption_decoder.pth checkpoint to a diffusers UniDiffuserTextDecoder.
"""
# caption_decoder.pth ckpt is a torch state dict
checkpoint_state_dict = torch.load(ckpt, map_location="cpu")
decoder_state_dict = {}
# Remove the "module." prefix, if necessary
caption_decoder_key = "module."
for key in checkpoint_state_dict:
if key.startswith(caption_decoder_key):
decoder_state_dict[key.replace(caption_decoder_key, "")] = checkpoint_state_dict.get(key)
else:
decoder_state_dict[key] = checkpoint_state_dict.get(key)
new_state_dict = {}
# Encoder and Decoder
new_state_dict["encode_prefix.weight"] = decoder_state_dict["encode_prefix.weight"]
new_state_dict["encode_prefix.bias"] = decoder_state_dict["encode_prefix.bias"]
new_state_dict["decode_prefix.weight"] = decoder_state_dict["decode_prefix.weight"]
new_state_dict["decode_prefix.bias"] = decoder_state_dict["decode_prefix.bias"]
# Internal GPT2LMHeadModel transformer model
for key, val in decoder_state_dict.items():
if key.startswith("gpt"):
suffix = key[len("gpt") :]
new_state_dict["transformer" + suffix] = val
missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict)
for missing_key in missing_keys:
print(f"Missing key: {missing_key}")
for unexpected_key in unexpected_keys:
print(f"Unexpected key: {unexpected_key}")
return diffusers_model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--caption_decoder_checkpoint_path",
default=None,
type=str,
required=False,
help="Path to caption decoder checkpoint to convert.",
)
parser.add_argument(
"--uvit_checkpoint_path", default=None, type=str, required=False, help="Path to U-ViT checkpoint to convert."
)
parser.add_argument(
"--vae_checkpoint_path",
default=None,
type=str,
required=False,
help="Path to VAE checkpoint to convert.",
)
parser.add_argument(
"--pipeline_output_path",
default=None,
type=str,
required=True,
help="Path to save the output pipeline to.",
)
parser.add_argument(
"--config_type",
default="test",
type=str,
help=(
"Config type to use. Should be 'test' to create small models for testing or 'big' to convert a full"
" checkpoint."
),
)
parser.add_argument(
"--version",
default=0,
type=int,
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
)
args = parser.parse_args()
# Convert the VAE model.
if args.vae_checkpoint_path is not None:
vae_config = create_vae_diffusers_config(args.config_type)
vae = AutoencoderKL(**vae_config)
vae = convert_vae_to_diffusers(args.vae_checkpoint_path, vae)
# Convert the U-ViT ("unet") model.
if args.uvit_checkpoint_path is not None:
unet_config = create_unidiffuser_unet_config(args.config_type, args.version)
unet = UniDiffuserModel(**unet_config)
unet = convert_uvit_to_diffusers(args.uvit_checkpoint_path, unet)
# Convert the caption decoder ("text_decoder") model.
if args.caption_decoder_checkpoint_path is not None:
text_decoder_config = create_text_decoder_config(args.config_type)
text_decoder = UniDiffuserTextDecoder(**text_decoder_config)
text_decoder = convert_caption_decoder_to_diffusers(args.caption_decoder_checkpoint_path, text_decoder)
# Scheduler is the same for both the test and big models.
scheduler_config = SCHEDULER_CONFIG
scheduler = DPMSolverMultistepScheduler(
beta_start=scheduler_config.beta_start,
beta_end=scheduler_config.beta_end,
beta_schedule=scheduler_config.beta_schedule,
solver_order=scheduler_config.solver_order,
)
if args.config_type == "test":
# Make a small random CLIPTextModel
torch.manual_seed(0)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(clip_text_encoder_config)
clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# Make a small random CLIPVisionModel and accompanying CLIPImageProcessor
torch.manual_seed(0)
clip_image_encoder_config = CLIPVisionConfig(
image_size=32,
patch_size=2,
num_channels=3,
hidden_size=32,
projection_dim=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
)
image_encoder = CLIPVisionModelWithProjection(clip_image_encoder_config)
image_processor = CLIPImageProcessor(crop_size=32, size=32)
# Note that the text_decoder should already have its token embeddings resized.
text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
eos = "<|EOS|>"
special_tokens_dict = {"eos_token": eos}
text_tokenizer.add_special_tokens(special_tokens_dict)
elif args.config_type == "big":
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Note that the text_decoder should already have its token embeddings resized.
text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
eos = "<|EOS|>"
special_tokens_dict = {"eos_token": eos}
text_tokenizer.add_special_tokens(special_tokens_dict)
else:
raise NotImplementedError(
f"Config type {args.config_type} is not implemented, currently only config types"
" 'test' and 'big' are available."
)
pipeline = UniDiffuserPipeline(
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_processor=image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
unet=unet,
scheduler=scheduler,
)
pipeline.save_pretrained(args.pipeline_output_path)
......@@ -129,6 +129,7 @@ else:
IFInpaintingSuperResolutionPipeline,
IFPipeline,
IFSuperResolutionPipeline,
ImageTextPipelineOutput,
KandinskyImg2ImgPipeline,
KandinskyInpaintPipeline,
KandinskyPipeline,
......@@ -161,6 +162,9 @@ else:
TextToVideoZeroPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
......
......@@ -89,6 +89,7 @@ else:
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
......
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
ImageTextPipelineOutput,
UniDiffuserPipeline,
)
else:
from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel, UTransformer2DModel
from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline
from typing import Optional
import numpy as np
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.modeling_utils import ModuleUtilsMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
"""
Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
generate text from the UniDiffuser image-text embedding.
Parameters:
prefix_length (`int`):
Max number of prefix tokens that will be supplied to the model.
prefix_inner_dim (`int`):
The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
CLIP text encoder.
prefix_hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
vocab_size (`int`, *optional*, defaults to 50257):
Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
n_positions (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
n_embd (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
n_layer (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
n_head (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
n_inner (`int`, *optional*, defaults to None):
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
activation_function (`str`, *optional*, defaults to `"gelu"`):
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
resid_pdrop (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the embeddings.
attn_pdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon to use in the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_attn_weights (`bool`, *optional*, defaults to `True`):
Scale attention weights by dividing by sqrt(hidden_size)..
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
Whether to additionally scale attention weights by `1 / layer_idx + 1`.
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
dot-product/softmax to float() when training with mixed precision.
"""
@register_to_config
def __init__(
self,
prefix_length: int,
prefix_inner_dim: int,
prefix_hidden_dim: Optional[int] = None,
vocab_size: int = 50257, # Start of GPT2 config args
n_positions: int = 1024,
n_embd: int = 768,
n_layer: int = 12,
n_head: int = 12,
n_inner: Optional[int] = None,
activation_function: str = "gelu_new",
resid_pdrop: float = 0.1,
embd_pdrop: float = 0.1,
attn_pdrop: float = 0.1,
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
scale_attn_weights: bool = True,
use_cache: bool = True,
scale_attn_by_inverse_layer_idx: bool = False,
reorder_and_upcast_attn: bool = False,
):
super().__init__()
self.prefix_length = prefix_length
if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
raise ValueError(
f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and"
f" `n_embd`: {n_embd} are not equal."
)
self.prefix_inner_dim = prefix_inner_dim
self.prefix_hidden_dim = prefix_hidden_dim
self.encode_prefix = (
nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
if self.prefix_hidden_dim is not None
else nn.Identity()
)
self.decode_prefix = (
nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
)
gpt_config = GPT2Config(
vocab_size=vocab_size,
n_positions=n_positions,
n_embd=n_embd,
n_layer=n_layer,
n_head=n_head,
n_inner=n_inner,
activation_function=activation_function,
resid_pdrop=resid_pdrop,
embd_pdrop=embd_pdrop,
attn_pdrop=attn_pdrop,
layer_norm_epsilon=layer_norm_epsilon,
initializer_range=initializer_range,
scale_attn_weights=scale_attn_weights,
use_cache=use_cache,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
self.transformer = GPT2LMHeadModel(gpt_config)
def forward(
self,
input_ids: torch.Tensor,
prefix_embeds: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
):
"""
Args:
input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
Text tokens to use for inference.
prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`):
Prefix embedding to preprend to the embedded tokens.
attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*):
Attention mask for the prefix embedding.
labels (`torch.Tensor`, *optional*):
Labels to use for language modeling.
"""
embedding_text = self.transformer.transformer.wte(input_ids)
hidden = self.encode_prefix(prefix_embeds)
prefix_embeds = self.decode_prefix(hidden)
embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)
if labels is not None:
dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
labels = torch.cat((dummy_token, input_ids), dim=1)
out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
if self.prefix_hidden_dim is not None:
return out, hidden
else:
return out
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
def encode(self, prefix):
return self.encode_prefix(prefix)
@torch.no_grad()
def generate_captions(self, features, eos_token_id, device):
"""
Generate captions given text embedding features. Returns list[L].
Args:
features (`torch.Tensor` of shape `(B, L, D)`):
Text embedding features to generate captions from.
eos_token_id (`int`):
The token ID of the EOS token for the text decoder model.
device:
Device to perform text generation on.
Returns:
`List[str]`: A list of strings generated from the decoder model.
"""
features = torch.split(features, 1, dim=0)
generated_tokens = []
generated_seq_lengths = []
for feature in features:
feature = self.decode_prefix(feature.to(device)) # back to the clip feature
# Only support beam search for now
output_tokens, seq_lengths = self.generate_beam(
input_embeds=feature, device=device, eos_token_id=eos_token_id
)
generated_tokens.append(output_tokens[0])
generated_seq_lengths.append(seq_lengths[0])
generated_tokens = torch.stack(generated_tokens)
generated_seq_lengths = torch.stack(generated_seq_lengths)
return generated_tokens, generated_seq_lengths
@torch.no_grad()
def generate_beam(
self,
input_ids=None,
input_embeds=None,
device=None,
beam_size: int = 5,
entry_length: int = 67,
temperature: float = 1.0,
eos_token_id: Optional[int] = None,
):
"""
Generates text using the given tokenizer and text prompt or token embedding via beam search. This
implementation is based on the beam search implementation from the [original UniDiffuser
code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).
Args:
eos_token_id (`int`, *optional*):
The token ID of the EOS token for the text decoder model.
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
must be supplied.
input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
An embedded representation to directly pass to the transformer as a prefix for beam search. One of
`input_ids` and `input_embeds` must be supplied.
device:
The device to perform beam search on.
beam_size (`int`, *optional*, defaults to `5`):
The number of best states to store during beam search.
entry_length (`int`, *optional*, defaults to `67`):
The number of iterations to run beam search.
temperature (`float`, *optional*, defaults to 1.0):
The temperature to use when performing the softmax over logits from the decoding model.
Returns:
`Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
token sequences sorted by score in descending order, and the second element is the sequence lengths
corresponding to those sequences.
"""
# Generates text until stop_token is reached using beam search with the desired beam size.
stop_token_index = eos_token_id
tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
if input_embeds is not None:
generated = input_embeds
else:
generated = self.transformer.transformer.wte(input_ids)
for i in range(entry_length):
outputs = self.transformer(inputs_embeds=generated)
logits = outputs.logits
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
order = scores.argsort(descending=True)
# tokens tensors are already padded to max_seq_length
output_texts = [tokens[i] for i in order]
output_texts = torch.stack(output_texts, dim=0)
seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype)
return output_texts, seq_lengths
import math
from typing import Optional, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.attention import AdaLayerNorm, FeedForward
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ...models.transformer_2d import Transformer2DModelOutput
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
logger.warning(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean},
\text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for
generating the random values works best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
use_pos_embed=True,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.use_pos_embed = use_pos_embed
if self.use_pos_embed:
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
if self.use_pos_embed:
return latent + self.pos_embed
else:
return latent
class SkipBlock(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.skip_linear = nn.Linear(2 * dim, dim)
# Use torch.nn.LayerNorm for now, following the original code
self.norm = nn.LayerNorm(dim)
def forward(self, x, skip):
x = self.skip_linear(torch.cat([x, skip], dim=-1))
x = self.norm(x)
return x
# Modified to support both pre-LayerNorm and post-LayerNorm configurations
# Don't support AdaLayerNormZero for now
# Modified from diffusers.models.attention.BasicTransformerBlock
class UTransformerBlock(nn.Module):
r"""
A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (:obj: `int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:obj: `bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float32 when performing the attention calculation.
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The layer norm implementation to use.
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g.
`pre_layer_norm = True`.
final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
pre_layer_norm: bool = True,
final_dropout: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.pre_layer_norm = pre_layer_norm
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# 1. Self-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
else:
self.norm2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
):
# Pre-LayerNorm
if self.pre_layer_norm:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
else:
norm_hidden_states = self.norm1(hidden_states)
else:
norm_hidden_states = hidden_states
# 1. Self-Attention
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
# Post-LayerNorm
if not self.pre_layer_norm:
if self.use_ada_layer_norm:
attn_output = self.norm1(attn_output, timestep)
else:
attn_output = self.norm1(attn_output)
hidden_states = attn_output + hidden_states
if self.attn2 is not None:
# Pre-LayerNorm
if self.pre_layer_norm:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
else:
norm_hidden_states = hidden_states
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
# Post-LayerNorm
if not self.pre_layer_norm:
attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
# Pre-LayerNorm
if self.pre_layer_norm:
norm_hidden_states = self.norm3(hidden_states)
else:
norm_hidden_states = hidden_states
ff_output = self.ff(norm_hidden_states)
# Post-LayerNorm
if not self.pre_layer_norm:
ff_output = self.norm3(ff_output)
hidden_states = ff_output + hidden_states
return hidden_states
# Like UTransformerBlock except with LayerNorms on the residual backbone of the block
# Modified from diffusers.models.attention.BasicTransformerBlock
class UniDiffuserBlock(nn.Module):
r"""
A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the
LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser
implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (:obj: `int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:obj: `bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float() when performing the attention calculation.
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The layer norm implementation to use.
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
(`pre_layer_norm = False`).
final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
pre_layer_norm: bool = False,
final_dropout: bool = True,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.pre_layer_norm = pre_layer_norm
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# 1. Self-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
else:
self.norm2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
):
# Following the diffusers transformer block implementation, put the LayerNorm on the
# residual backbone
# Pre-LayerNorm
if self.pre_layer_norm:
if self.use_ada_layer_norm:
hidden_states = self.norm1(hidden_states, timestep)
else:
hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# Following the diffusers transformer block implementation, put the LayerNorm on the
# residual backbone
# Post-LayerNorm
if not self.pre_layer_norm:
if self.use_ada_layer_norm:
hidden_states = self.norm1(hidden_states, timestep)
else:
hidden_states = self.norm1(hidden_states)
if self.attn2 is not None:
# Pre-LayerNorm
if self.pre_layer_norm:
hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here
# 2. Cross-Attention
attn_output = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# Post-LayerNorm
if not self.pre_layer_norm:
hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# 3. Feed-forward
# Pre-LayerNorm
if self.pre_layer_norm:
hidden_states = self.norm3(hidden_states)
ff_output = self.ff(hidden_states)
hidden_states = ff_output + hidden_states
# Post-LayerNorm
if not self.pre_layer_norm:
hidden_states = self.norm3(hidden_states)
return hidden_states
# Modified from diffusers.models.transformer_2d.Transformer2DModel
# Modify the transformer block structure to be U-Net like following U-ViT
# Only supports patch-style input and torch.nn.LayerNorm currently
# https://github.com/baofff/U-ViT
class UTransformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`]
layer and then reshaped to (b, t, d).
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input.
out_channels (`int`, *optional*):
The number of output channels; if `None`, defaults to `in_channels`.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
norm_num_groups (`int`, *optional*, defaults to `32`):
The number of groups to use when performing Group Normalization.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
patch_size (`int`, *optional*, defaults to 2):
The patch size to use in the patch embedding.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
use_linear_projection (int, *optional*): TODO: Not used
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used in each
transformer block.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float() when performing the attention calculation.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
block_type (`str`, *optional*, defaults to `"unidiffuser"`):
The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
behavior in `diffusers`.)
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
(`pre_layer_norm = False`).
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
use_patch_pos_embed (`bool`, *optional*):
Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = 2,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
block_type: str = "unidiffuser",
pre_layer_norm: bool = False,
norm_elementwise_affine: bool = True,
use_patch_pos_embed=False,
ff_final_dropout: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Input
# Only support patch input of shape (batch_size, num_channels, height, width) for now
assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size."
assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size"
# 2. Define input layers
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
use_pos_embed=use_patch_pos_embed,
)
# 3. Define transformers blocks
# Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block,
# and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in
# a "U"-shaped fashion (e.g. first in_block to last out_block, etc.).
# Quick hack to make the transformer block type configurable
if block_type == "unidiffuser":
block_cls = UniDiffuserBlock
else:
block_cls = UTransformerBlock
self.transformer_in_blocks = nn.ModuleList(
[
block_cls(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
final_dropout=ff_final_dropout,
)
for d in range(num_layers // 2)
]
)
self.transformer_mid_block = block_cls(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
final_dropout=ff_final_dropout,
)
# For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs
# before each transformer out_block.
self.transformer_out_blocks = nn.ModuleList(
[
nn.ModuleDict(
{
"skip": SkipBlock(
inner_dim,
),
"block": block_cls(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
final_dropout=ff_final_dropout,
),
}
)
for d in range(num_layers // 2)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# Following the UniDiffuser U-ViT implementation, we process the transformer output with
# a LayerNorm layer with per-element affine params
self.norm_out = nn.LayerNorm(inner_dim)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
return_dict: bool = True,
hidden_states_is_embedding: bool = False,
unpatchify: bool = True,
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
transformer blocks.
unpatchify (`bool`, *optional*, defaults to `True`):
Whether to unpatchify the transformer output.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# 0. Check inputs
if not unpatchify and return_dict:
raise ValueError(
f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when"
f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)"
" rather than (batch_size, num_channels, height, width)."
)
# 1. Input
if not hidden_states_is_embedding:
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
# In ("downsample") blocks
skips = []
for in_block in self.transformer_in_blocks:
hidden_states = in_block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
skips.append(hidden_states)
# Mid block
hidden_states = self.transformer_mid_block(hidden_states)
# Out ("upsample") blocks
for out_block in self.transformer_out_blocks:
hidden_states = out_block["skip"](hidden_states, skips.pop())
hidden_states = out_block["block"](
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
# Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic
hidden_states = self.norm_out(hidden_states)
# hidden_states = self.proj_out(hidden_states)
if unpatchify:
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
else:
output = hidden_states
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
class UniDiffuserModel(ModelMixin, ConfigMixin):
"""
Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a
modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the
CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details).
Parameters:
text_dim (`int`): The hidden dimension of the CLIP text model used to embed images.
clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts.
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input.
out_channels (`int`, *optional*):
The number of output channels; if `None`, defaults to `in_channels`.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
norm_num_groups (`int`, *optional*, defaults to `32`):
The number of groups to use when performing Group Normalization.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
patch_size (`int`, *optional*, defaults to 2):
The patch size to use in the patch embedding.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
use_linear_projection (int, *optional*): TODO: Not used
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used in each
transformer block.
upcast_attention (`bool`, *optional*):
Whether to upcast the query and key to float32 when performing the attention calculation.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
block_type (`str`, *optional*, defaults to `"unidiffuser"`):
The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
behavior in `diffusers`.)
pre_layer_norm (`bool`, *optional*):
Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
(`pre_layer_norm = False`).
norm_elementwise_affine (`bool`, *optional*):
Whether to use learnable per-element affine parameters during layer normalization.
use_patch_pos_embed (`bool`, *optional*):
Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
ff_final_dropout (`bool`, *optional*):
Whether to use a final Dropout layer after the feedforward network.
use_data_type_embedding (`bool`, *optional*):
Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1
is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type`
argument, which can either be `1` to use the weights trained on non-publically-available data or `0`
otherwise. This argument is subsequently embedded by the data type embedding, if used.
"""
@register_to_config
def __init__(
self,
text_dim: int = 768,
clip_img_dim: int = 512,
num_text_tokens: int = 77,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
block_type: str = "unidiffuser",
pre_layer_norm: bool = False,
use_timestep_embedding=False,
norm_elementwise_affine: bool = True,
use_patch_pos_embed=False,
ff_final_dropout: bool = True,
use_data_type_embedding: bool = False,
):
super().__init__()
# 0. Handle dimensions
self.inner_dim = num_attention_heads * attention_head_dim
assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size"
self.sample_size = sample_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.patch_size = patch_size
# Assume image is square...
self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size)
# 1. Define input layers
# 1.1 Input layers for text and image input
# For now, only support patch input for VAE latent image input
self.vae_img_in = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
use_pos_embed=use_patch_pos_embed,
)
self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim)
self.text_in = nn.Linear(text_dim, self.inner_dim)
# 1.2. Timestep embeddings for t_img, t_text
self.timestep_img_proj = Timesteps(
self.inner_dim,
flip_sin_to_cos=True,
downscale_freq_shift=0,
)
self.timestep_img_embed = (
TimestepEmbedding(
self.inner_dim,
4 * self.inner_dim,
out_dim=self.inner_dim,
)
if use_timestep_embedding
else nn.Identity()
)
self.timestep_text_proj = Timesteps(
self.inner_dim,
flip_sin_to_cos=True,
downscale_freq_shift=0,
)
self.timestep_text_embed = (
TimestepEmbedding(
self.inner_dim,
4 * self.inner_dim,
out_dim=self.inner_dim,
)
if use_timestep_embedding
else nn.Identity()
)
# 1.3. Positional embedding
self.num_text_tokens = num_text_tokens
self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim))
self.pos_embed_drop = nn.Dropout(p=dropout)
trunc_normal_(self.pos_embed, std=0.02)
# 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary
self.use_data_type_embedding = use_data_type_embedding
if self.use_data_type_embedding:
self.data_type_token_embedding = nn.Embedding(2, self.inner_dim)
self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim))
# 2. Define transformer blocks
self.transformer = UTransformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
patch_size=patch_size,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
block_type=block_type,
pre_layer_norm=pre_layer_norm,
norm_elementwise_affine=norm_elementwise_affine,
use_patch_pos_embed=use_patch_pos_embed,
ff_final_dropout=ff_final_dropout,
)
# 3. Define output layers
patch_dim = (patch_size**2) * out_channels
self.vae_img_out = nn.Linear(self.inner_dim, patch_dim)
self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim)
self.text_out = nn.Linear(self.inner_dim, text_dim)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed"}
def forward(
self,
latent_image_embeds: torch.FloatTensor,
image_embeds: torch.FloatTensor,
prompt_embeds: torch.FloatTensor,
timestep_img: Union[torch.Tensor, float, int],
timestep_text: Union[torch.Tensor, float, int],
data_type: Optional[Union[torch.Tensor, float, int]] = 1,
encoder_hidden_states=None,
cross_attention_kwargs=None,
):
"""
Args:
latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`):
Latent image representation from the VAE encoder.
image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`):
CLIP-embedded image representation (unsqueezed in the first dimension).
prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`):
CLIP-embedded text representation.
timestep_img (`torch.long` or `float` or `int`):
Current denoising step for the image.
timestep_text (`torch.long` or `float` or `int`):
Current denoising step for the text.
data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
or `0` otherwise.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used.
Returns:
`tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE
image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
embedding.
"""
batch_size = latent_image_embeds.shape[0]
# 1. Input
# 1.1. Map inputs to shape (B, N, inner_dim)
vae_hidden_states = self.vae_img_in(latent_image_embeds)
clip_hidden_states = self.clip_img_in(image_embeds)
text_hidden_states = self.text_in(prompt_embeds)
num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1)
# 1.2. Encode image timesteps to single token (B, 1, inner_dim)
if not torch.is_tensor(timestep_img):
timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device)
timestep_img_token = self.timestep_img_proj(timestep_img)
# t_img_token does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timestep_img_token = timestep_img_token.to(dtype=self.dtype)
timestep_img_token = self.timestep_img_embed(timestep_img_token)
timestep_img_token = timestep_img_token.unsqueeze(dim=1)
# 1.3. Encode text timesteps to single token (B, 1, inner_dim)
if not torch.is_tensor(timestep_text):
timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device)
timestep_text_token = self.timestep_text_proj(timestep_text)
# t_text_token does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timestep_text_token = timestep_text_token.to(dtype=self.dtype)
timestep_text_token = self.timestep_text_embed(timestep_text_token)
timestep_text_token = timestep_text_token.unsqueeze(dim=1)
# 1.4. Concatenate all of the embeddings together.
if self.use_data_type_embedding:
assert data_type is not None, "data_type must be supplied if the model uses a data type embedding"
if not torch.is_tensor(data_type):
data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device)
data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1)
hidden_states = torch.cat(
[
timestep_img_token,
timestep_text_token,
data_type_token,
text_hidden_states,
clip_hidden_states,
vae_hidden_states,
],
dim=1,
)
else:
hidden_states = torch.cat(
[timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states],
dim=1,
)
# 1.5. Prepare the positional embeddings and add to hidden states
# Note: I think img_vae should always have the proper shape, so there's no need to interpolate
# the position embeddings.
if self.use_data_type_embedding:
pos_embed = torch.cat(
[self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1
)
else:
pos_embed = self.pos_embed
hidden_states = hidden_states + pos_embed
hidden_states = self.pos_embed_drop(hidden_states)
# 2. Blocks
hidden_states = self.transformer(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=None,
class_labels=None,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
hidden_states_is_embedding=True,
unpatchify=False,
)[0]
# 3. Output
# Split out the predicted noise representation.
if self.use_data_type_embedding:
(
t_img_token_out,
t_text_token_out,
data_type_token_out,
text_out,
img_clip_out,
img_vae_out,
) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
else:
t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split(
(1, 1, num_text_tokens, 1, num_img_tokens), dim=1
)
img_vae_out = self.vae_img_out(img_vae_out)
# unpatchify
height = width = int(img_vae_out.shape[1] ** 0.5)
img_vae_out = img_vae_out.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out)
img_vae_out = img_vae_out.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
img_clip_out = self.clip_img_out(img_clip_out)
text_out = self.text_out(text_out)
return img_vae_out, img_clip_out, text_out
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
GPT2Tokenizer,
)
from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
)
from ...utils.outputs import BaseOutput
from ..pipeline_utils import DiffusionPipeline
from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
# New BaseOutput child class for joint image-text output
@dataclass
class ImageTextPipelineOutput(BaseOutput):
"""
Output class for joint image-text pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
text (`List[str]` or `List[List[str]]`)
List of generated text strings of length `batch_size` or a list of list of strings whose outer list has
length `batch_size`. Text generated by the diffusion pipeline.
"""
images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
text: Optional[Union[List[str], List[List[str]]]]
class UniDiffuserPipeline(DiffusionPipeline):
r"""
Pipeline for a bimodal image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model, which supports
unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and
joint image-text generation.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. This
is part of the UniDiffuser image representation, along with the CLIP vision encoding.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text
prompts.
image_encoder ([`CLIPVisionModel`]):
UniDiffuser uses the vision portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode
images as part of its image representation, along with the VAE latent representation.
image_processor ([`CLIPImageProcessor`]):
CLIP image processor of class
[CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor),
used to preprocess the image before CLIP encoding it with `image_encoder`.
clip_tokenizer ([`CLIPTokenizer`]):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which
is used to tokenizer a prompt before encoding it with `text_encoder`.
text_decoder ([`UniDiffuserTextDecoder`]):
Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser
embedding.
text_tokenizer ([`GPT2Tokenizer`]):
Tokenizer of class
[GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which
is used along with the `text_decoder` to decode text for text generation.
unet ([`UniDiffuserModel`]):
UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a
[`Transformer2DModel`] with U-Net-style skip connections between transformer layers.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The
original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
image_encoder: CLIPVisionModelWithProjection,
image_processor: CLIPImageProcessor,
clip_tokenizer: CLIPTokenizer,
text_decoder: UniDiffuserTextDecoder,
text_tokenizer: GPT2Tokenizer,
unet: UniDiffuserModel,
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim:
raise ValueError(
f"The text encoder hidden size and text decoder prefix inner dim must be the same, but"
f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}"
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_processor=image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
unet=unet,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.num_channels_latents = vae.config.latent_channels
self.text_encoder_seq_len = text_encoder.config.max_position_embeddings
self.text_encoder_hidden_size = text_encoder.config.hidden_size
self.image_encoder_projection_dim = image_encoder.config.projection_dim
self.unet_resolution = unet.config.sample_size
self.text_intermediate_dim = self.text_encoder_hidden_size
if self.text_decoder.prefix_hidden_dim is not None:
self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim
self.mode = None
# TODO: handle safety checking?
self.safety_checker = None
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
# Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
Note that offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload
else:
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.image_encoder, self.text_decoder]:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
# Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
if self.safety_checker is not None:
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents):
r"""
Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set
mode will be used.
"""
prompt_available = (prompt is not None) or (prompt_embeds is not None)
image_available = image is not None
input_available = prompt_available or image_available
prompt_latents_available = prompt_latents is not None
vae_latents_available = vae_latents is not None
clip_latents_available = clip_latents is not None
full_latents_available = latents is not None
image_latents_available = vae_latents_available and clip_latents_available
all_indv_latents_available = prompt_latents_available and image_latents_available
if self.mode is not None:
# Preferentially use the mode set by the user
mode = self.mode
elif prompt_available:
mode = "text2img"
elif image_available:
mode = "img2text"
else:
# Neither prompt nor image supplied, infer based on availability of latents
if full_latents_available or all_indv_latents_available:
mode = "joint"
elif prompt_latents_available:
mode = "text"
elif image_latents_available:
mode = "img"
else:
# No inputs or latents available
mode = "joint"
# Give warnings for ambiguous cases
if self.mode is None and prompt_available and image_available:
logger.warning(
f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually,"
f" defaulting to mode '{mode}'."
)
if self.mode is None and not input_available:
if vae_latents_available != clip_latents_available:
# Exactly one of vae_latents and clip_latents is supplied
logger.warning(
f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none"
f" are expected to be supplied. Defaulting to mode '{mode}'."
)
elif not prompt_latents_available and not vae_latents_available and not clip_latents_available:
# No inputs or latents supplied
logger.warning(
f"No inputs or latents have been supplied, and mode has not been manually set,"
f" defaulting to mode '{mode}'."
)
return mode
# Functions to manually set the mode
def set_text_mode(self):
r"""Manually set the generation mode to unconditional ("marginal") text generation."""
self.mode = "text"
def set_image_mode(self):
r"""Manually set the generation mode to unconditional ("marginal") image generation."""
self.mode = "img"
def set_text_to_image_mode(self):
r"""Manually set the generation mode to text-conditioned image generation."""
self.mode = "text2img"
def set_image_to_text_mode(self):
r"""Manually set the generation mode to image-conditioned text generation."""
self.mode = "img2text"
def set_joint_mode(self):
r"""Manually set the generation mode to unconditional joint image-text generation."""
self.mode = "joint"
def reset_mode(self):
r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs."""
self.mode = None
def _infer_batch_size(
self,
mode,
prompt,
prompt_embeds,
image,
num_images_per_prompt,
num_prompts_per_image,
latents,
prompt_latents,
vae_latents,
clip_latents,
):
r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`."""
if num_images_per_prompt is None:
num_images_per_prompt = 1
if num_prompts_per_image is None:
num_prompts_per_image = 1
assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer"
assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer"
if mode in ["text2img"]:
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
# Either prompt or prompt_embeds must be present for text2img.
batch_size = prompt_embeds.shape[0]
multiplier = num_images_per_prompt
elif mode in ["img2text"]:
if isinstance(image, PIL.Image.Image):
batch_size = 1
else:
# Image must be available and type either PIL.Image.Image or torch.FloatTensor.
# Not currently supporting something like image_embeds.
batch_size = image.shape[0]
multiplier = num_prompts_per_image
elif mode in ["img"]:
if vae_latents is not None:
batch_size = vae_latents.shape[0]
elif clip_latents is not None:
batch_size = clip_latents.shape[0]
else:
batch_size = 1
multiplier = num_images_per_prompt
elif mode in ["text"]:
if prompt_latents is not None:
batch_size = prompt_latents.shape[0]
else:
batch_size = 1
multiplier = num_prompts_per_image
elif mode in ["joint"]:
if latents is not None:
batch_size = latents.shape[0]
elif prompt_latents is not None:
batch_size = prompt_latents.shape[0]
elif vae_latents is not None:
batch_size = vae_latents.shape[0]
elif clip_latents is not None:
batch_size = clip_latents.shape[0]
else:
batch_size = 1
if num_images_per_prompt == num_prompts_per_image:
multiplier = num_images_per_prompt
else:
multiplier = min(num_images_per_prompt, num_prompts_per_image)
logger.warning(
f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and"
f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to"
f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}."
)
return batch_size, multiplier
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# self.tokenizer => self.clip_tokenizer
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = self.clip_tokenizer(
prompt,
padding="max_length",
max_length=self.clip_tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.clip_tokenizer.batch_decode(
untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = self.clip_tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
# Add num_prompts_per_image argument, sample from autoencoder moment distribution
def encode_image_vae_latents(
self,
image,
batch_size,
num_prompts_per_image,
dtype,
device,
do_classifier_free_guidance,
generator=None,
):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_prompts_per_image
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
image_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
* self.vae.config.scaling_factor
for i in range(batch_size)
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
# Scale image_latents by the VAE's scaling factor
image_latents = image_latents * self.vae.config.scaling_factor
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand image_latents for batch_size
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many initial images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
if do_classifier_free_guidance:
uncond_image_latents = torch.zeros_like(image_latents)
image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)
return image_latents
def encode_image_clip_latents(
self,
image,
batch_size,
num_prompts_per_image,
dtype,
device,
generator=None,
):
# Map image to CLIP embedding.
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
preprocessed_image = self.image_processor.preprocess(
image,
return_tensors="pt",
)
preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)
batch_size = batch_size * num_prompts_per_image
if isinstance(generator, list):
image_latents = [
self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size)
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.image_encoder(**preprocessed_image).image_embeds
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand image_latents for batch_size
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many initial images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
return image_latents
# Note that the CLIP latents are not decoded for image generation.
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Rename: decode_latents -> decode_image_latents
def decode_image_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_text_latents(
self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
):
# Prepare latents for the CLIP embedded prompt.
shape = (batch_size * num_images_per_prompt, seq_len, hidden_size)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# latents is assumed to have shace (B, L, D)
latents = latents.repeat(num_images_per_prompt, 1, 1)
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
# Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
def prepare_image_vae_latents(
self,
batch_size,
num_prompts_per_image,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size * num_prompts_per_image,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# latents is assumed to have shape (B, C, H, W)
latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_image_clip_latents(
self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None
):
# Prepare latents for the CLIP embedded image.
shape = (batch_size * num_prompts_per_image, 1, clip_img_dim)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
# latents is assumed to have shape (B, L, D)
latents = latents.repeat(num_prompts_per_image, 1, 1)
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def _split(self, x, height, width):
r"""
Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
and (B, 1, clip_img_dim)
"""
batch_size = x.shape[0]
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
img_vae_dim = self.num_channels_latents * latent_height * latent_width
img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1)
img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
return img_vae, img_clip
def _combine(self, img_vae, img_clip):
r"""
Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
"""
img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
return torch.concat([img_vae, img_clip], dim=-1)
def _split_joint(self, x, height, width):
r"""
Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae,
img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is
of shape (B, text_seq_len, text_dim).
"""
batch_size = x.shape[0]
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
img_vae_dim = self.num_channels_latents * latent_height * latent_width
text_dim = self.text_encoder_seq_len * self.text_intermediate_dim
img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1)
img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim))
return img_vae, img_clip, text
def _combine_joint(self, img_vae, img_clip, text):
r"""
Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img,
clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B,
C * H * W + L_img * clip_img_dim + L_text * text_dim).
"""
img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
text = torch.reshape(text, (text.shape[0], -1))
return torch.concat([img_vae, img_clip, text], dim=-1)
def _get_noise_pred(
self,
mode,
latents,
t,
prompt_embeds,
img_vae,
img_clip,
max_timestep,
data_type,
guidance_scale,
generator,
device,
height,
width,
):
r"""
Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary.
"""
if mode == "joint":
# Joint text-image generation
img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width)
img_vae_out, img_clip_out, text_out = self.unet(
img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type
)
x_out = self._combine_joint(img_vae_out, img_clip_out, text_out)
if guidance_scale <= 1.0:
return x_out
# Classifier-free guidance
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
_, _, text_out_uncond = self.unet(
img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
)
img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
img_vae_latents,
img_clip_latents,
text_T,
timestep_img=t,
timestep_text=max_timestep,
data_type=data_type,
)
x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond
elif mode == "text2img":
# Text-conditioned image generation
img_vae_latents, img_clip_latents = self._split(latents, height, width)
img_vae_out, img_clip_out, text_out = self.unet(
img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type
)
img_out = self._combine(img_vae_out, img_clip_out)
if guidance_scale <= 1.0:
return img_out
# Classifier-free guidance
text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
img_vae_latents,
img_clip_latents,
text_T,
timestep_img=t,
timestep_text=max_timestep,
data_type=data_type,
)
img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)
return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
elif mode == "img2text":
# Image-conditioned text generation
img_vae_out, img_clip_out, text_out = self.unet(
img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type
)
if guidance_scale <= 1.0:
return text_out
# Classifier-free guidance
img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
)
return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
elif mode == "text":
# Unconditional ("marginal") text generation (no CFG)
img_vae_out, img_clip_out, text_out = self.unet(
img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
)
return text_out
elif mode == "img":
# Unconditional ("marginal") image generation (no CFG)
img_vae_latents, img_clip_latents = self._split(latents, height, width)
img_vae_out, img_clip_out, text_out = self.unet(
img_vae_latents,
img_clip_latents,
prompt_embeds,
timestep_img=t,
timestep_text=max_timestep,
data_type=data_type,
)
img_out = self._combine(img_vae_out, img_clip_out)
return img_out
def check_latents_shape(self, latents_name, latents, expected_shape):
latents_shape = latents.shape
expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension
expected_shape_str = ", ".join(str(dim) for dim in expected_shape)
if len(latents_shape) != expected_num_dims:
raise ValueError(
f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
f" {latents_shape} has {len(latents_shape)} dimensions."
)
for i in range(1, expected_num_dims):
if latents_shape[i] != expected_shape[i - 1]:
raise ValueError(
f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}."
)
def check_inputs(
self,
mode,
prompt,
image,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
latents=None,
prompt_latents=None,
vae_latents=None,
clip_latents=None,
):
# Check inputs before running the generative process.
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if mode == "text2img":
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if mode == "img2text":
if image is None:
raise ValueError("`img2text` mode requires an image to be provided.")
# Check provided latents
latent_height = height // self.vae_scale_factor
latent_width = width // self.vae_scale_factor
full_latents_available = latents is not None
prompt_latents_available = prompt_latents is not None
vae_latents_available = vae_latents is not None
clip_latents_available = clip_latents is not None
if full_latents_available:
individual_latents_available = (
prompt_latents is not None or vae_latents is not None or clip_latents is not None
)
if individual_latents_available:
logger.warning(
"You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and"
" `clip_latents`. The value of `latents` will override the value of any individually supplied latents."
)
# Check shape of full latents
img_vae_dim = self.num_channels_latents * latent_height * latent_width
text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim
latents_expected_shape = (latents_dim,)
self.check_latents_shape("latents", latents, latents_expected_shape)
# Check individual latent shapes, if present
if prompt_latents_available:
prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size)
self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape)
if vae_latents_available:
vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width)
self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape)
if clip_latents_available:
clip_latents_expected_shape = (1, self.image_encoder_projection_dim)
self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape)
if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available:
if vae_latents.shape[0] != clip_latents.shape[0]:
raise ValueError(
f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:"
f" {vae_latents.shape[0]} != {clip_latents.shape[0]}."
)
if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available:
if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]:
raise ValueError(
f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch"
f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}"
f" != {clip_latents.shape[0]}."
)
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
data_type: Optional[int] = 1,
num_inference_steps: int = 50,
guidance_scale: float = 8.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
num_prompts_per_image: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_latents: Optional[torch.FloatTensor] = None,
vae_latents: Optional[torch.FloatTensor] = None,
clip_latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead. Required for text-conditioned image generation (`text2img`) mode.
image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch. Required for image-conditioned text generation
(`img2text`) mode.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
data_type (`int`, *optional*, defaults to 1):
The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type
embedding; this is added for compatibility with the UniDiffuser-v1 checkpoint.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 8.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. Note that the original [UniDiffuser
paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`,
which satisfies `w = w' + 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`). Used in text-conditioned image generation (`text2img`) mode.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and
`img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated.
num_prompts_per_image (`int`, *optional*, defaults to 1):
The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and
`text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint
image-text generation. Can be used to tweak the same generation with different prompts. If not
provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note
that this is assumed to be a full set of VAE, CLIP, and text latents, if supplied, this will override
the value of `prompt_latents`, `vae_latents`, and `clip_latents`.
prompt_latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
vae_latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
clip_latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned
image generation (`text2img`) mode.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument. Used in text-conditioned image generation (`text2img`) mode.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`:
[`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images, and the second element is a list
of generated texts.
"""
# 0. Default height and width to unet
height = height or self.unet_resolution * self.vae_scale_factor
width = width or self.unet_resolution * self.vae_scale_factor
# 1. Check inputs
# Recalculate mode for each call to the pipeline.
mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents)
self.check_inputs(
mode,
prompt,
image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
latents,
prompt_latents,
vae_latents,
clip_latents,
)
# 2. Define call parameters
batch_size, multiplier = self._infer_batch_size(
mode,
prompt,
prompt_embeds,
image,
num_images_per_prompt,
num_prompts_per_image,
latents,
prompt_latents,
vae_latents,
clip_latents,
)
device = self._execution_device
reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img"
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Note that this differs from the formulation in the unidiffusers paper!
# do_classifier_free_guidance = guidance_scale > 1.0
# check if scheduler is in sigmas space
# scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
# 3. Encode input prompt, if available; otherwise prepare text latents
if latents is not None:
# Overwrite individual latents
vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width)
if mode in ["text2img"]:
# 3.1. Encode input prompt, if available
assert prompt is not None or prompt_embeds is not None
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=multiplier,
do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
else:
# 3.2. Prepare text latent variables, if input not available
prompt_embeds = self.prepare_text_latents(
batch_size=batch_size,
num_images_per_prompt=multiplier,
seq_len=self.text_encoder_seq_len,
hidden_size=self.text_encoder_hidden_size,
dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision
device=device,
generator=generator,
latents=prompt_latents,
)
if reduce_text_emb_dim:
prompt_embeds = self.text_decoder.encode(prompt_embeds)
# 4. Encode image, if available; otherwise prepare image latents
if mode in ["img2text"]:
# 4.1. Encode images, if available
assert image is not None, "`img2text` requires a conditioning image"
# Encode image using VAE
image_vae = preprocess(image)
height, width = image_vae.shape[-2:]
image_vae_latents = self.encode_image_vae_latents(
image=image_vae,
batch_size=batch_size,
num_prompts_per_image=multiplier,
dtype=prompt_embeds.dtype,
device=device,
do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG
generator=generator,
)
# Encode image using CLIP
image_clip_latents = self.encode_image_clip_latents(
image=image,
batch_size=batch_size,
num_prompts_per_image=multiplier,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
)
# (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
image_clip_latents = image_clip_latents.unsqueeze(1)
else:
# 4.2. Prepare image latent variables, if input not available
# Prepare image VAE latents in latent space
image_vae_latents = self.prepare_image_vae_latents(
batch_size=batch_size,
num_prompts_per_image=multiplier,
num_channels_latents=self.num_channels_latents,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=vae_latents,
)
# Prepare image CLIP latents
image_clip_latents = self.prepare_image_clip_latents(
batch_size=batch_size,
num_prompts_per_image=multiplier,
clip_img_dim=self.image_encoder_projection_dim,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=clip_latents,
)
# 5. Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# max_timestep = timesteps[0]
max_timestep = self.scheduler.config.num_train_timesteps
# 6. Prepare latent variables
if mode == "joint":
latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds)
elif mode in ["text2img", "img"]:
latents = self._combine(image_vae_latents, image_clip_latents)
elif mode in ["img2text", "text"]:
latents = prompt_embeds
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}")
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# predict the noise residual
# Also applies classifier-free guidance as described in the UniDiffuser paper
noise_pred = self._get_noise_pred(
mode,
latents,
t,
prompt_embeds,
image_vae_latents,
image_clip_latents,
max_timestep,
data_type,
guidance_scale,
generator,
device,
height,
width,
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
gen_image = None
gen_text = None
if mode == "joint":
image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)
# Map latent VAE image back to pixel space
gen_image = self.decode_image_latents(image_vae_latents)
# Generate text using the text decoder
output_token_list, seq_lengths = self.text_decoder.generate_captions(
text_latents, self.text_tokenizer.eos_token_id, device=device
)
output_list = output_token_list.cpu().numpy()
gen_text = [
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
]
elif mode in ["text2img", "img"]:
image_vae_latents, image_clip_latents = self._split(latents, height, width)
gen_image = self.decode_image_latents(image_vae_latents)
elif mode in ["img2text", "text"]:
text_latents = latents
output_token_list, seq_lengths = self.text_decoder.generate_captions(
text_latents, self.text_tokenizer.eos_token_id, device=device
)
output_list = output_token_list.cpu().numpy()
gen_text = [
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
]
# 10. Convert to PIL
if output_type == "pil" and gen_image is not None:
gen_image = self.numpy_to_pil(gen_image)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (gen_image, gen_text)
return ImageTextPipelineOutput(images=gen_image, text=gen_text)
......@@ -152,6 +152,21 @@ class IFSuperResolutionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ImageTextPipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class KandinskyImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......@@ -632,6 +647,51 @@ class UnCLIPPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class UniDiffuserModel(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class UniDiffuserPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class UniDiffuserTextDecoder(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
import gc
import random
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
GPT2Tokenizer,
)
from diffusers import (
AutoencoderKL,
DPMSolverMultistepScheduler,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from diffusers.utils import floats_tensor, load_image, randn_tensor, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = UniDiffuserPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self):
unet = UniDiffuserModel.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="unet",
)
scheduler = DPMSolverMultistepScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
solver_order=3,
)
vae = AutoencoderKL.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="vae",
)
text_encoder = CLIPTextModel.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="text_encoder",
)
clip_tokenizer = CLIPTokenizer.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="clip_tokenizer",
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="image_encoder",
)
# From the Stable Diffusion Image Variation pipeline tests
image_processor = CLIPImageProcessor(crop_size=32, size=32)
# image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
text_tokenizer = GPT2Tokenizer.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="text_tokenizer",
)
text_decoder = UniDiffuserTextDecoder.from_pretrained(
"hf-internal-testing/unidiffuser-diffusers-test",
subfolder="text_decoder",
)
components = {
"vae": vae,
"text_encoder": text_encoder,
"image_encoder": image_encoder,
"image_processor": image_processor,
"clip_tokenizer": clip_tokenizer,
"text_decoder": text_decoder,
"text_tokenizer": text_tokenizer,
"unet": unet,
"scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
image = Image.fromarray(np.uint8(image)).convert("RGB")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "an elephant under the sea",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
}
return inputs
def get_fixed_latents(self, device, seed=0):
if type(device) == str:
device = torch.device(device)
generator = torch.Generator(device=device).manual_seed(seed)
# Hardcode the shapes for now.
prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32)
vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32)
clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32)
latents = {
"prompt_latents": prompt_latents,
"vae_latents": vae_latents,
"clip_latents": clip_latents,
}
return latents
def get_dummy_inputs_with_latents(self, device, seed=0):
# image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
# image = image.cpu().permute(0, 2, 3, 1)[0]
# image = Image.fromarray(np.uint8(image)).convert("RGB")
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg",
)
image = image.resize((32, 32))
latents = self.get_fixed_latents(device, seed=seed)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "an elephant under the sea",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "numpy",
"prompt_latents": latents.get("prompt_latents"),
"vae_latents": latents.get("vae_latents"),
"clip_latents": latents.get("clip_latents"),
}
return inputs
def test_unidiffuser_default_joint_v0(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
# inputs = self.get_dummy_inputs(device)
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
sample = unidiffuser_pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_default_joint_no_cfg_v0(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
# inputs = self.get_dummy_inputs(device)
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
# Set guidance scale to 1.0 to turn off CFG
inputs["guidance_scale"] = 1.0
sample = unidiffuser_pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_default_text2img_v0(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete image for text-conditioned image generation
del inputs["image"]
image = unidiffuser_pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_unidiffuser_default_image_0(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img'
unidiffuser_pipe.set_image_mode()
assert unidiffuser_pipe.mode == "img"
inputs = self.get_dummy_inputs(device)
# Delete prompt and image for unconditional ("marginal") text generation.
del inputs["prompt"]
del inputs["image"]
image = unidiffuser_pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.5760, 0.6270, 0.6571, 0.4966, 0.4638, 0.5663, 0.5254, 0.5068, 0.5715])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_unidiffuser_default_text_v0(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img'
unidiffuser_pipe.set_text_mode()
assert unidiffuser_pipe.mode == "text"
inputs = self.get_dummy_inputs(device)
# Delete prompt and image for unconditional ("marginal") text generation.
del inputs["prompt"]
del inputs["image"]
text = unidiffuser_pipe(**inputs).text
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_default_img2text_v0(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete text for image-conditioned text generation
del inputs["prompt"]
text = unidiffuser_pipe(**inputs).text
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_default_joint_v1(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
# inputs = self.get_dummy_inputs(device)
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
inputs["data_type"] = 1
sample = unidiffuser_pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_default_text2img_v1(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete image for text-conditioned image generation
del inputs["image"]
image = unidiffuser_pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_unidiffuser_default_img2text_v1(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete text for image-conditioned text generation
del inputs["prompt"]
text = unidiffuser_pipe(**inputs).text
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_text2img_multiple_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs(device)
# Delete image for text-conditioned image generation
del inputs["image"]
inputs["num_images_per_prompt"] = 2
inputs["num_prompts_per_image"] = 3
image = unidiffuser_pipe(**inputs).images
assert image.shape == (2, 32, 32, 3)
def test_unidiffuser_img2text_multiple_prompts(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs(device)
# Delete text for image-conditioned text generation
del inputs["prompt"]
inputs["num_images_per_prompt"] = 2
inputs["num_prompts_per_image"] = 3
text = unidiffuser_pipe(**inputs).text
assert len(text) == 3
def test_unidiffuser_text2img_multiple_images_with_latents(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete image for text-conditioned image generation
del inputs["image"]
inputs["num_images_per_prompt"] = 2
inputs["num_prompts_per_image"] = 3
image = unidiffuser_pipe(**inputs).images
assert image.shape == (2, 32, 32, 3)
def test_unidiffuser_img2text_multiple_prompts_with_latents(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
unidiffuser_pipe = UniDiffuserPipeline(**components)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete text for image-conditioned text generation
del inputs["prompt"]
inputs["num_images_per_prompt"] = 2
inputs["num_prompts_per_image"] = 3
text = unidiffuser_pipe(**inputs).text
assert len(text) == 3
@require_torch_gpu
def test_unidiffuser_default_joint_v1_cuda_fp16(self):
device = "cuda"
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
"hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
inputs["data_type"] = 1
sample = unidiffuser_pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5049, 0.5498, 0.5854, 0.3052, 0.4460, 0.6489, 0.5122, 0.4810, 0.6138])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = '" This This'
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
@require_torch_gpu
def test_unidiffuser_default_text2img_v1_cuda_fp16(self):
device = "cuda"
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
"hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'text2img'
unidiffuser_pipe.set_text_to_image_mode()
assert unidiffuser_pipe.mode == "text2img"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["image"]
inputs["data_type"] = 1
sample = unidiffuser_pipe(**inputs)
image = sample.images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
@require_torch_gpu
def test_unidiffuser_default_img2text_v1_cuda_fp16(self):
device = "cuda"
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
"hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
)
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'img2text'
unidiffuser_pipe.set_image_to_text_mode()
assert unidiffuser_pipe.mode == "img2text"
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
inputs["data_type"] = 1
text = unidiffuser_pipe(**inputs).text
expected_text_prefix = '" This This'
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
@slow
@require_torch_gpu
class UniDiffuserPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0, generate_latents=False):
generator = torch.manual_seed(seed)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
)
inputs = {
"prompt": "an elephant under the sea",
"image": image,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 8.0,
"output_type": "numpy",
}
if generate_latents:
latents = self.get_fixed_latents(device, seed=seed)
for latent_name, latent_tensor in latents.items():
inputs[latent_name] = latent_tensor
return inputs
def get_fixed_latents(self, device, seed=0):
if type(device) == str:
device = torch.device(device)
latent_device = torch.device("cpu")
generator = torch.Generator(device=latent_device).manual_seed(seed)
# Hardcode the shapes for now.
prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
# Move latents onto desired device.
prompt_latents = prompt_latents.to(device)
vae_latents = vae_latents.to(device)
clip_latents = clip_latents.to(device)
latents = {
"prompt_latents": prompt_latents,
"vae_latents": vae_latents,
"clip_latents": clip_latents,
}
return latents
def test_unidiffuser_default_joint_v1(self):
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
# inputs = self.get_dummy_inputs(device)
inputs = self.get_inputs(device=torch_device, generate_latents=True)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
sample = pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1
expected_text_prefix = "A living room"
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_unidiffuser_default_text2img_v1(self):
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(device=torch_device, generate_latents=True)
del inputs["image"]
sample = pipe(**inputs)
image = sample.images
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
def test_unidiffuser_default_img2text_v1(self):
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(device=torch_device, generate_latents=True)
del inputs["prompt"]
sample = pipe(**inputs)
text = sample.text
expected_text_prefix = "An astronaut"
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_unidiffuser_default_joint_v1_fp16(self):
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
# inputs = self.get_dummy_inputs(device)
inputs = self.get_inputs(device=torch_device, generate_latents=True)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
sample = pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1
expected_text_prefix = "A living room"
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
def test_unidiffuser_default_text2img_v1_fp16(self):
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(device=torch_device, generate_latents=True)
del inputs["image"]
sample = pipe(**inputs)
image = sample.images
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
def test_unidiffuser_default_img2text_v1_fp16(self):
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs(device=torch_device, generate_latents=True)
del inputs["prompt"]
sample = pipe(**inputs)
text = sample.text
expected_text_prefix = "An astronaut"
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment