Commit 3d1b9667 authored by wangwei990215's avatar wangwei990215
Browse files

更新diffusers文件夹

parent b6a53272
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
if is_transformers_available():
from transformers import PreTrainedModel, PreTrainedTokenizer
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@validate_hf_hub_args
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "text_inversion",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path in pretrained_model_name_or_paths:
if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
# 3.1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
raise e
model_file = None
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path
state_dicts.append(state_dict)
return state_dicts
class TextualInversionLoaderMixin:
r"""
Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.
"""
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
inversion token or if the textual inversion token is a single vector, the input prompt is returned.
Parameters:
prompt (`str` or list of `str`):
The prompt or prompts to guide the image generation.
tokenizer (`PreTrainedTokenizer`):
The tokenizer responsible for encoding the prompt into input tokens.
Returns:
`str` or list of `str`: The converted prompt
"""
if not isinstance(prompt, List):
prompts = [prompt]
else:
prompts = prompt
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
if not isinstance(prompt, List):
return prompts[0]
return prompts
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
Parameters:
prompt (`str`):
The prompt to guide the image generation.
tokenizer (`PreTrainedTokenizer`):
The tokenizer responsible for encoding the prompt into input tokens.
Returns:
`str`: The converted prompt
"""
tokens = tokenizer.tokenize(prompt)
unique_tokens = set(tokens)
for token in unique_tokens:
if token in tokenizer.added_tokens_encoder:
replacement = token
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
replacement += f" {token}_{i}"
i += 1
prompt = prompt.replace(token, replacement)
return prompt
def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
if tokenizer is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
if text_encoder is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
raise ValueError(
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
f"Make sure both lists have the same length."
)
valid_tokens = [t for t in tokens if t is not None]
if len(set(valid_tokens)) < len(valid_tokens):
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
@staticmethod
def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
all_tokens = []
all_embeddings = []
for state_dict, token in zip(state_dicts, tokens):
if isinstance(state_dict, torch.Tensor):
if token is None:
raise ValueError(
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
)
loaded_token = token
embedding = state_dict
elif len(state_dict) == 1:
# diffusers
loaded_token, embedding = next(iter(state_dict.items()))
elif "string_to_param" in state_dict:
# A1111
loaded_token = state_dict["name"]
embedding = state_dict["string_to_param"]["*"]
else:
raise ValueError(
f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
"Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
" input key."
)
if token is not None and loaded_token != token:
logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
else:
token = loaded_token
if token in tokenizer.get_vocab():
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
all_tokens.append(token)
all_embeddings.append(embedding)
return all_tokens, all_embeddings
@staticmethod
def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
all_tokens = []
all_embeddings = []
for embedding, token in zip(embeddings, tokens):
if f"{token}_1" in tokenizer.get_vocab():
multi_vector_tokens = [token]
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}")
i += 1
raise ValueError(
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
)
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
if is_multi_vector:
all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
all_embeddings += [e for e in embedding] # noqa: C416
else:
all_tokens += [token]
all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
return all_tokens, all_embeddings
@validate_hf_hub_args
def load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
**kwargs,
):
r"""
Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
Automatic1111 formats are supported).
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
Can be either one of the following or a list of them:
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
pretrained model hosted on the Hub.
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
inversion weights.
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
token (`str` or `List[str]`, *optional*):
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
list, then `token` must also be a list of equal length.
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
If not specified, function will take self.tokenizer.
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
weight_name (`str`, *optional*):
Name of a custom weight file. This should be used when:
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
name such as `text_inv.bin`.
- The saved textual inversion file is in the Automatic1111 format.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
Example:
To load a Textual Inversion embedding vector in 🤗 Diffusers format:
```py
from diffusers import StableDiffusionPipeline
import torch
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")
```
To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first
(for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
locally:
```py
from diffusers import StableDiffusionPipeline
import torch
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
image = pipe(prompt, num_inference_steps=50).images[0]
image.save("character.png")
```
"""
# 1. Set correct tokenizer and text encoder
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
# 2. Normalize inputs
pretrained_model_name_or_paths = (
[pretrained_model_name_or_path]
if not isinstance(pretrained_model_name_or_path, list)
else pretrained_model_name_or_path
)
tokens = [token] if not isinstance(token, list) else token
if tokens[0] is None:
tokens = tokens * len(pretrained_model_name_or_paths)
# 3. Check inputs
self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
# 4. Load state dicts of textual embeddings
state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
# 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
if len(tokens) > 1 and len(state_dicts) == 1:
if isinstance(state_dicts[0], torch.Tensor):
state_dicts = list(state_dicts[0])
if len(tokens) != len(state_dicts):
raise ValueError(
f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
f"Make sure both have the same length."
)
# 4. Retrieve tokens and embeddings
tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
# 5. Extend tokens and embeddings for multi vector
tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
# 6. Make sure all embeddings have the correct size
expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
raise ValueError(
"Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
"to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} "
)
# 7. Now we can be sure that loading the embedding matrix works
# < Unsafe code:
# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
is_model_cpu_offload = False
is_sequential_cpu_offload = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# 7.2 save expected device and dtype
device = text_encoder.device
dtype = text_encoder.dtype
# 7.3 Increase token embedding matrix
text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
input_embeddings = text_encoder.get_input_embeddings().weight
# 7.4 Load token and embedding
for token, embedding in zip(tokens, embeddings):
# add tokens and get ids
tokenizer.add_tokens(token)
token_id = tokenizer.convert_tokens_to_ids(token)
input_embeddings.data[token_id] = embedding
logger.info(f"Loaded textual inversion embedding for {token}.")
input_embeddings.to(dtype=dtype, device=device)
# 7.5 Offload the model again
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
# / Unsafe Code >
def unload_textual_inversion(
self,
tokens: Optional[Union[str, List[str]]] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None,
text_encoder: Optional["PreTrainedModel"] = None,
):
r"""
Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
# Example 1
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
# Remove all token embeddings
pipeline.unload_textual_inversion()
# Example 2
pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
# Remove just one token
pipeline.unload_textual_inversion("<moe-bius>")
# Example 3: unload from SDXL
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
# load embeddings to the text encoders
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
# Unload explicitly from both text encoders abd tokenizers
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
```
"""
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
# Get textual inversion tokens and ids
token_ids = []
last_special_token_id = None
if tokens:
if isinstance(tokens, str):
tokens = [tokens]
for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
if not added_token.special:
if added_token.content in tokens:
token_ids.append(added_token_id)
else:
last_special_token_id = added_token_id
if len(token_ids) == 0:
raise ValueError("No tokens to remove found")
else:
tokens = []
for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
if not added_token.special:
token_ids.append(added_token_id)
tokens.append(added_token.content)
else:
last_special_token_id = added_token_id
# Delete from tokenizer
for token_id, token_to_remove in zip(token_ids, tokens):
del tokenizer._added_tokens_decoder[token_id]
del tokenizer._added_tokens_encoder[token_to_remove]
# Make all token ids sequential in tokenizer
key_id = 1
for token_id in tokenizer.added_tokens_decoder:
if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
token = tokenizer._added_tokens_decoder[token_id]
tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
del tokenizer._added_tokens_decoder[token_id]
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1
tokenizer._update_trie()
# Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
to_append = []
for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
if i not in token_ids:
to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
if len(to_append) > 0:
to_append = torch.cat(to_append, dim=0)
text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
text_embeddings_filtered.weight.data = text_embedding_weights
text_encoder.set_input_embeddings(text_embeddings_filtered)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import (
ImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
delete_adapter_layers,
is_accelerate_available,
is_torch_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .utils import AttnProcsLayers
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
class UNet2DConditionLoadersMixin:
"""
Load LoRA layers into a [`UNet2DCondtionModel`].
"""
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
[`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
and be a `torch.nn.Module` class.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a directory (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.unet.load_attn_procs(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
```
"""
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
is_network_alphas_none = network_alphas is None
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except IOError as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
# fill attn processors
lora_layers_list = []
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
if is_lora:
# correct keys
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
if network_alphas is not None:
network_alphas_keys = list(network_alphas.keys())
used_network_alphas_keys = set()
lora_grouped_dict = defaultdict(dict)
mapped_network_alphas = {}
all_keys = list(state_dict.keys())
for key in all_keys:
value = state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
for k in network_alphas_keys:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
used_network_alphas_keys.add(k)
if not is_network_alphas_none:
if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
if len(state_dict) > 0:
raise ValueError(
f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
)
for key, value_dict in lora_grouped_dict.items():
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
rank = value_dict["lora.down.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora_layers_list.append((attn_processor, lora))
if low_cpu_mem_usage:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
if len(value) == 0:
custom_diffusion_grouped_dict[key] = {}
else:
if "to_out" in key:
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
else:
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in custom_diffusion_grouped_dict.items():
if len(value_dict) == 0:
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
)
else:
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=True,
train_q_out=train_q_out,
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
elif USE_PEFT_BACKEND:
# In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
# on the Unet
pass
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
if not USE_PEFT_BACKEND:
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
)
if is_new_lora_format:
# Strip the `"unet"` prefix.
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
logger.warning(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
# change processor format to 'pure' LoRACompatibleLinear format
if any("processor" in k.split(".") for k in state_dict.keys()):
def format_to_lora_compatible(key):
if "processor" not in key.split("."):
return key
return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
if network_alphas is not None:
network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
return state_dict, network_alphas
def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
**kwargs,
):
r"""
Save attention processor layers to a directory so that it can be reloaded with the
[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save an attention processor to (will be created if it doesn't exist).
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or with `pickle`.
Example:
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
torch_dtype=torch.float16,
).to("cuda")
pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
```
"""
from ..models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
)
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
is_custom_diffusion = any(
isinstance(
x,
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
)
for (_, x) in self.attn_processors.items()
)
if is_custom_diffusion:
model_to_save = AttnProcsLayers(
{
y: x
for (y, x) in self.attn_processors.items()
if isinstance(
x,
(
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
),
)
}
)
state_dict = model_to_save.state_dict()
for name, attn in self.attn_processors.items():
if len(attn.state_dict()) == 0:
state_dict[name] = {}
else:
model_to_save = AttnProcsLayers(self.attn_processors)
state_dict = model_to_save.state_dict()
if weight_name is None:
if safe_serialization:
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
else:
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
# Save the model
save_path = Path(save_directory, weight_name).as_posix()
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
self.lora_scale = lora_scale
self._safe_fusing = safe_fusing
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
def _fuse_lora_apply(self, module, adapter_names=None):
if not USE_PEFT_BACKEND:
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale, self._safe_fusing)
if adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported in your environment. Please switch"
" to PEFT backend to use this argument by installing latest PEFT and transformers."
" `pip install -U peft transformers`"
)
else:
from peft.tuners.tuners_utils import BaseTunerLayer
merge_kwargs = {"safe_merge": self._safe_fusing}
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
# For BC with prevous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
" to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
def _unfuse_lora_apply(self, module):
if not USE_PEFT_BACKEND:
if hasattr(module, "_unfuse_lora"):
module._unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
module.unmerge()
def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[List[float], float]] = None,
):
"""
Set the currently active adapters for use in the UNet.
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `set_adapters()`.")
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)
set_weights_and_activate_adapters(self, adapter_names, weights)
def disable_lora(self):
"""
Disable the UNet's active LoRA layers.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.disable_lora()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=False)
def enable_lora(self):
"""
Enable the UNet's active LoRA layers.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.enable_lora()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=True)
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the UNet.
Args:
adapter_names (`Union[List[str], str]`):
The names (single string or list of strings) of the adapter to delete.
Example:
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
)
pipeline.delete_adapters("cinematic")
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if isinstance(adapter_names, str):
adapter_names = [adapter_names]
for adapter_name in adapter_names:
delete_adapter_layers(self, adapter_name)
# Pop also the corresponding adapter from the config
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
elif "proj.3.weight" in state_dict:
# IP-Adapter Full
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
with init_context():
image_projection = IPAdapterFullImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
diffusers_name = diffusers_name.replace("proj.3", "norm")
updated_state_dict[diffusers_name] = value
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["latents"].shape[1]
embed_dims = state_dict["proj_in.weight"].shape[1]
output_dims = state_dict["proj_out.weight"].shape[0]
hidden_dims = state_dict["latents"].shape[2]
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
with init_context():
image_projection = IPAdapterPlusImageProjection(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("0.to", "2.to")
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
elif "norm2" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
elif "to_kv" in diffusers_name:
v_chunk = value.chunk(2, dim=0)
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in diffusers_name:
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
else:
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds += [4]
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
else:
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
key_id += 2
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
# convert IP-Adapter Image Projection layers to diffusers
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
self.to(dtype=self.dtype, device=self.device)
class FromOriginalUNetMixin:
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config: (`dict`, *optional*):
Dictionary containing the configuration of the model:
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables of the model.
"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
return model
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import torch
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
def map_to(module, state_dict, *args, **kwargs):
new_state_dict = {}
for key, value in state_dict.items():
num = int(key.split(".")[1]) # 0 is always "layers"
new_key = key.replace(f"layers.{num}", module.mapping[num])
new_state_dict[new_key] = value
return new_state_dict
def remap_key(key, state_dict):
for k in self.split_keys:
if k in key:
return key.split(k)[0] + k
raise ValueError(
f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
)
def map_from(module, state_dict, *args, **kwargs):
all_keys = list(state_dict.keys())
for key in all_keys:
replace_key = remap_key(key, state_dict)
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
state_dict[new_key] = state_dict[key]
del state_dict[key]
self._register_state_dict_hook(map_to)
self._register_load_state_dict_pre_hook(map_from, with_module=True)
# Models
For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
\ No newline at end of file
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)
_import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"]
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
DualTransformer2DModel,
PriorTransformer,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
)
from .unets import (
I2VGenXLUNet,
Kandinsky3UNet,
MotionAdapter,
StableCascadeUNet,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
)
from .vq_model import VQModel
if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate
ACTIVATION_FUNCTIONS = {
"swish": nn.SiLU(),
"silu": nn.SiLU(),
"mish": nn.Mish(),
"gelu": nn.GELU(),
"relu": nn.ReLU(),
}
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
Args:
act_fn (str): Name of activation function.
Returns:
nn.Module: Activation function.
"""
act_fn = act_fn.lower()
if act_fn in ACTIVATION_FUNCTIONS:
return ACTIVATION_FUNCTIONS[act_fn]
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class GEGLU(nn.Module):
r"""
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states, *args, **kwargs):
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
r"""
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
[paper](https://arxiv.org/abs/1606.08415).
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable, List, Optional, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .modeling_utils import ModelMixin
logger = logging.get_logger(__name__)
class MultiAdapter(ModelMixin):
r"""
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
user-assigned weighting.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
A list of `T2IAdapter` model instances.
"""
def __init__(self, adapters: List["T2IAdapter"]):
super(MultiAdapter, self).__init__()
self.num_adapter = len(adapters)
self.adapters = nn.ModuleList(adapters)
if len(adapters) == 0:
raise ValueError("Expecting at least one adapter")
if len(adapters) == 1:
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
# The outputs from each adapter are added together with a weight.
# This means that the change in dimensions from downsampling must
# be the same for all adapters. Inductively, it also means the
# downscale_factor and total_downscale_factor must be the same for all
# adapters.
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
first_adapter_downscale_factor = adapters[0].downscale_factor
for idx in range(1, len(adapters)):
if (
adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
or adapters[idx].downscale_factor != first_adapter_downscale_factor
):
raise ValueError(
f"Expecting all adapters to have the same downscaling behavior, but got:\n"
f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
)
self.total_downscale_factor = first_adapter_total_downscale_factor
self.downscale_factor = first_adapter_downscale_factor
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
Args:
xs (`torch.Tensor`):
(batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
`channel` should equal to `num_adapter` * "number of channel of image".
adapter_weights (`List[float]`, *optional*, defaults to None):
List of floats representing the weight which will be multiply to each adapter's output before adding
them together.
"""
if adapter_weights is None:
adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
else:
adapter_weights = torch.tensor(adapter_weights)
accume_state = None
for x, w, adapter in zip(xs, adapter_weights, self.adapters):
features = adapter(x)
if accume_state is None:
accume_state = features
for i in range(len(accume_state)):
accume_state[i] = w * accume_state[i]
else:
for i in range(len(features)):
accume_state[i] += w * features[i]
return accume_state
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
idx = 0
model_path_to_save = save_directory
for adapter in self.adapters:
adapter.save_pretrained(
model_path_to_save,
is_main_process=is_main_process,
save_function=save_function,
safe_serialization=safe_serialization,
variant=variant,
)
idx += 1
model_path_to_save = model_path_to_save + f"_{idx}"
@classmethod
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you should first set it back in training mode with `model.train()`.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
max_memory (`Dict`, *optional*):
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
GPU and the available CPU RAM if unset.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
variant (`str`, *optional*):
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
ignored when using `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
"""
idx = 0
adapters = []
# load adapter and append to list until no adapter directory exists anymore
# first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
# second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
model_path_to_load = pretrained_model_path
while os.path.isdir(model_path_to_load):
adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
adapters.append(adapter)
idx += 1
model_path_to_load = pretrained_model_path + f"_{idx}"
logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
if len(adapters) == 0:
raise ValueError(
f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
)
return cls(adapters)
class T2IAdapter(ModelMixin, ConfigMixin):
r"""
A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
architecture follows the original implementation of
[Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
and
[AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (`int`, *optional*, defaults to 3):
Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
image as *control image*.
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
also determine the number of downsample blocks in the Adapter.
num_res_blocks (`int`, *optional*, defaults to 2):
Number of ResNet blocks in each downsample block.
downscale_factor (`int`, *optional*, defaults to 8):
A factor that determines the total downscale factor of the Adapter.
adapter_type (`str`, *optional*, defaults to `full_adapter`):
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 8,
adapter_type: str = "full_adapter",
):
super().__init__()
if adapter_type == "full_adapter":
self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
elif adapter_type == "full_adapter_xl":
self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
elif adapter_type == "light_adapter":
self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
else:
raise ValueError(
f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
"'full_adapter_xl' or 'light_adapter'."
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
each representing information extracted at a different scale from the input. The length of the list is
determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
`num_res_blocks` parameters during initialization.
"""
return self.adapter(x)
@property
def total_downscale_factor(self):
return self.adapter.total_downscale_factor
@property
def downscale_factor(self):
"""The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
not evenly divisible by the downscale_factor then an exception will be raised.
"""
return self.adapter.unshuffle.downscale_factor
# full adapter
class FullAdapter(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 8,
):
super().__init__()
in_channels = in_channels * downscale_factor**2
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
self.body = nn.ModuleList(
[
AdapterBlock(channels[0], channels[0], num_res_blocks),
*[
AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
for i in range(1, len(channels))
],
]
)
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method processes the input tensor `x` through the FullAdapter model and performs operations including
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
capturing information at a different stage of processing within the FullAdapter model. The number of feature
tensors in the list is determined by the number of downsample blocks specified during initialization.
"""
x = self.unshuffle(x)
x = self.conv_in(x)
features = []
for block in self.body:
x = block(x)
features.append(x)
return features
class FullAdapterXL(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280, 1280],
num_res_blocks: int = 2,
downscale_factor: int = 16,
):
super().__init__()
in_channels = in_channels * downscale_factor**2
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
self.body = []
# blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
for i in range(len(channels)):
if i == 1:
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
elif i == 2:
self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
else:
self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
self.body = nn.ModuleList(self.body)
# XL has only one downsampling AdapterBlock.
self.total_downscale_factor = downscale_factor * 2
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
"""
x = self.unshuffle(x)
x = self.conv_in(x)
features = []
for block in self.body:
x = block(x)
features.append(x)
return features
class AdapterBlock(nn.Module):
r"""
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
`FullAdapterXL` models.
Parameters:
in_channels (`int`):
Number of channels of AdapterBlock's input.
out_channels (`int`):
Number of channels of AdapterBlock's output.
num_res_blocks (`int`):
Number of ResNet blocks in the AdapterBlock.
down (`bool`, *optional*, defaults to `False`):
Whether to perform downsampling on AdapterBlock's input.
"""
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
super().__init__()
self.downsample = None
if down:
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.in_conv = None
if in_channels != out_channels:
self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.resnets = nn.Sequential(
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes tensor x as input and performs operations downsampling and convolutional layers if the
self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
residual blocks to the input tensor.
"""
if self.downsample is not None:
x = self.downsample(x)
if self.in_conv is not None:
x = self.in_conv(x)
x = self.resnets(x)
return x
class AdapterResnetBlock(nn.Module):
r"""
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
Parameters:
channels (`int`):
Number of channels of AdapterResnetBlock's input and output.
"""
def __init__(self, channels: int):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
layer on the input tensor. It returns addition with the input tensor.
"""
h = self.act(self.block1(x))
h = self.block2(h)
return h + x
# light adapter
class LightAdapter(nn.Module):
r"""
See [`T2IAdapter`] for more information.
"""
def __init__(
self,
in_channels: int = 3,
channels: List[int] = [320, 640, 1280],
num_res_blocks: int = 4,
downscale_factor: int = 8,
):
super().__init__()
in_channels = in_channels * downscale_factor**2
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
self.body = nn.ModuleList(
[
LightAdapterBlock(in_channels, channels[0], num_res_blocks),
*[
LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
for i in range(len(channels) - 1)
],
LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
]
)
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
feature tensor corresponds to a different level of processing within the LightAdapter.
"""
x = self.unshuffle(x)
features = []
for block in self.body:
x = block(x)
features.append(x)
return features
class LightAdapterBlock(nn.Module):
r"""
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
`LightAdapter` model.
Parameters:
in_channels (`int`):
Number of channels of LightAdapterBlock's input.
out_channels (`int`):
Number of channels of LightAdapterBlock's output.
num_res_blocks (`int`):
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
down (`bool`, *optional*, defaults to `False`):
Whether to perform downsampling on LightAdapterBlock's input.
"""
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
super().__init__()
mid_channels = out_channels // 4
self.downsample = None
if down:
self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
layer, a sequence of residual blocks, and out convolutional layer.
"""
if self.downsample is not None:
x = self.downsample(x)
x = self.in_conv(x)
x = self.resnets(x)
x = self.out_conv(x)
return x
class LightAdapterResnetBlock(nn.Module):
"""
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
architecture than `AdapterResnetBlock`.
Parameters:
channels (`int`):
Number of channels of LightAdapterResnetBlock's input and output.
"""
def __init__(self, channels: int):
super().__init__()
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.act = nn.ReLU()
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
another convolutional layer and adds it to input tensor.
"""
h = self.act(self.block1(x))
h = self.block2(h)
return h + x
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
logger = logging.get_logger(__name__)
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
A gated self-attention dense layer that combines visual features and object features.
Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()
# we need a linear projection since we need cat visual feature and obj feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.FloatTensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
)
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
import flax.linen as nn
import jax
import jax.numpy as jnp
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
num_kv, num_heads, k_features = key.shape[-3:]
v_features = value.shape[-1]
key_chunk_size = min(key_chunk_size, num_kv)
query = query / jnp.sqrt(k_features)
@functools.partial(jax.checkpoint, prevent_cse=False)
def summarize_chunk(query, key, value):
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
max_score = jax.lax.stop_gradient(max_score)
exp_weights = jnp.exp(attn_weights - max_score)
exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
max_score = jnp.einsum("...qhk->...qh", max_score)
return (exp_values, exp_weights.sum(axis=-1), max_score)
def chunk_scanner(chunk_idx):
# julienne key array
key_chunk = jax.lax.dynamic_slice(
operand=key,
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
)
# julienne value array
value_chunk = jax.lax.dynamic_slice(
operand=value,
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
)
return summarize_chunk(query, key_chunk, value_chunk)
chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
global_max = jnp.max(chunk_max, axis=0, keepdims=True)
max_diffs = jnp.exp(chunk_max - global_max)
chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
chunk_weights *= max_diffs
all_values = chunk_values.sum(axis=0)
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
return all_values / all_weights
def jax_memory_efficient_attention(
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
r"""
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
https://github.com/AminRezaei0x443/memory-efficient-attention
Args:
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
numerical precision for computation
query_chunk_size (`int`, *optional*, defaults to 1024):
chunk size to divide query array value must divide query_length equally without remainder
key_chunk_size (`int`, *optional*, defaults to 4096):
chunk size to divide key and value array value must divide key_value_length equally without remainder
Returns:
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
"""
num_q, num_heads, q_features = query.shape[-3:]
def chunk_scanner(chunk_idx, _):
# julienne query array
query_chunk = jax.lax.dynamic_slice(
operand=query,
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
)
return (
chunk_idx + query_chunk_size, # unused ignore it
_query_chunk_attention(
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
),
)
_, res = jax.lax.scan(
f=chunk_scanner,
init=0,
xs=None,
length=math.ceil(num_q / query_chunk_size), # start counter # stop counter
)
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
class FlaxAttention(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
Parameters:
query_dim (:obj:`int`):
Input hidden states dimension
heads (:obj:`int`, *optional*, defaults to 8):
Number of heads
dim_head (:obj:`int`, *optional*, defaults to 64):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
# Weights were exported with old names {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
self.dropout_layer = nn.Dropout(rate=self.dropout)
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def __call__(self, hidden_states, context=None, deterministic=True):
context = hidden_states if context is None else context
query_proj = self.query(hidden_states)
key_proj = self.key(context)
value_proj = self.value(context)
if self.split_head_dim:
b = hidden_states.shape[0]
query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
else:
query_states = self.reshape_heads_to_batch_dim(query_proj)
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)
if self.use_memory_efficient_attention:
query_states = query_states.transpose(1, 0, 2)
key_states = key_states.transpose(1, 0, 2)
value_states = value_states.transpose(1, 0, 2)
# this if statement create a chunk size for each layer of the unet
# the chunk size is equal to the query_length dimension of the deepest layer of the unet
flatten_latent_dim = query_states.shape[-3]
if flatten_latent_dim % 64 == 0:
query_chunk_size = int(flatten_latent_dim / 64)
elif flatten_latent_dim % 16 == 0:
query_chunk_size = int(flatten_latent_dim / 16)
elif flatten_latent_dim % 4 == 0:
query_chunk_size = int(flatten_latent_dim / 4)
else:
query_chunk_size = int(flatten_latent_dim)
hidden_states = jax_memory_efficient_attention(
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
)
hidden_states = hidden_states.transpose(1, 0, 2)
else:
# compute attentions
if self.split_head_dim:
attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
else:
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
# attend to values
if self.split_head_dim:
hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
b = hidden_states.shape[0]
hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
else:
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxBasicTransformerBlock(nn.Module):
r"""
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
https://arxiv.org/abs/1706.03762
Parameters:
dim (:obj:`int`):
Inner hidden states dimension
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
only_cross_attention (`bool`, defaults to `False`):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
self.n_heads,
self.d_head,
self.dropout,
self.use_memory_efficient_attention,
self.split_head_dim,
dtype=self.dtype,
)
# cross attention
self.attn2 = FlaxAttention(
self.dim,
self.n_heads,
self.d_head,
self.dropout,
self.use_memory_efficient_attention,
self.split_head_dim,
dtype=self.dtype,
)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
if self.only_cross_attention:
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
else:
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
residual = hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = hidden_states + residual
# feed forward
residual = hidden_states
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxTransformer2DModel(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf
Parameters:
in_channels (:obj:`int`):
Input number of channels
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
depth (:obj:`int`, *optional*, defaults to 1):
Number of transformers block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_linear_projection (`bool`, defaults to `False`): tbd
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
in_channels: int
n_heads: int
d_head: int
depth: int = 1
dropout: float = 0.0
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
if self.use_linear_projection:
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.transformer_blocks = [
FlaxBasicTransformerBlock(
inner_dim,
self.n_heads,
self.d_head,
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
use_memory_efficient_attention=self.use_memory_efficient_attention,
split_head_dim=self.split_head_dim,
)
for _ in range(self.depth)
]
if self.use_linear_projection:
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.dropout)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
else:
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return self.dropout_layer(hidden_states, deterministic=deterministic)
class FlaxFeedForward(nn.Module):
r"""
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
[`FeedForward`] class, with the following simplifications:
- The activation function is currently hardcoded to a gated linear unit from:
https://arxiv.org/abs/2002.05202
- `dim_out` is equal to `dim`.
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
Parameters:
dim (:obj:`int`):
Inner hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.net_0(hidden_states, deterministic=deterministic)
hidden_states = self.net_2(hidden_states)
return hidden_states
class FlaxGEGLU(nn.Module):
r"""
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
https://arxiv.org/abs/2002.05202.
Parameters:
dim (:obj:`int`):
Input hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from importlib import import_module
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
@maybe_allow_in_graph
class Attention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`):
The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8):
The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
upcast_attention (`bool`, *optional*, defaults to False):
Set to `True` to upcast the attention computation to `float32`.
upcast_softmax (`bool`, *optional*, defaults to False):
Set to `True` to upcast the softmax computation to `float32`.
cross_attention_norm (`str`, *optional*, defaults to `None`):
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the group norm in the cross attention.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
norm_num_groups (`int`, *optional*, defaults to `None`):
The number of groups to use for the group norm in the attention.
spatial_norm_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the spatial normalization.
out_bias (`bool`, *optional*, defaults to `True`):
Set to `True` to use a bias in the output linear layer.
scale_qk (`bool`, *optional*, defaults to `True`):
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
only_cross_attention (`bool`, *optional*, defaults to `False`):
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
`added_kv_proj_dim` is not `None`.
eps (`float`, *optional*, defaults to 1e-5):
An additional value added to the denominator in group normalization that is used for numerical stability.
rescale_output_factor (`float`, *optional*, defaults to 1.0):
A factor to rescale the output by dividing it with this value.
residual_connection (`bool`, *optional*, defaults to `False`):
Set to `True` to add the residual connection to the output.
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
Set to `True` if the attention block is loaded from a deprecated state dict.
processor (`AttnProcessor`, *optional*, defaults to `None`):
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
`AttnProcessor` otherwise.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.only_cross_attention = only_cross_attention
if self.added_kv_proj_dim is None and self.only_cross_attention:
raise ValueError(
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
)
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
else:
self.group_norm = None
if spatial_norm_dim is not None:
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
else:
self.spatial_norm = None
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
elif cross_attention_norm == "group_norm":
if self.added_kv_proj_dim is not None:
# The given `encoder_hidden_states` are initially of shape
# (batch_size, seq_len, added_kv_proj_dim) before being projected
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
# before the projection, so we need to use `added_kv_proj_dim` as
# the number of channels for the group norm.
norm_cross_num_channels = added_kv_proj_dim
else:
norm_cross_num_channels = self.cross_attention_dim
self.norm_cross = nn.GroupNorm(
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
)
else:
raise ValueError(
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
linear_cls = nn.Linear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
r"""
Set whether to use memory efficient attention from `xformers` or not.
Args:
use_memory_efficient_attention_xformers (`bool`):
Whether to use memory efficient attention from `xformers` or not.
attention_op (`Callable`, *optional*):
The attention operation to use. Defaults to `None` which uses the default attention operation from
`xformers`.
"""
is_lora = hasattr(self, "processor") and isinstance(
self.processor,
LORA_ATTENTION_PROCESSORS,
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor,
(CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
)
is_added_kv_processor = hasattr(self, "processor") and isinstance(
self.processor,
(
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
),
)
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
)
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
if is_lora:
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
processor = LoRAXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
processor = CustomDiffusionXFormersAttnProcessor(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
elif is_added_kv_processor:
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
# throw warning
logger.info(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
if is_lora:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
processor = attn_processor_class(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
attn_processor_class = (
CustomDiffusionAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else CustomDiffusionAttnProcessor
)
processor = attn_processor_class(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessor2_0()
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
else AttnProcessor()
)
self.set_processor(processor)
def set_attention_slice(self, slice_size: int) -> None:
r"""
Set the slice size for attention computation.
Args:
slice_size (`int`):
The slice size for attention computation.
"""
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
if slice_size is not None and self.added_kv_proj_dim is not None:
processor = SlicedAttnAddedKVProcessor(slice_size)
elif slice_size is not None:
processor = SlicedAttnProcessor(slice_size)
elif self.added_kv_proj_dim is not None:
processor = AttnAddedKVProcessor()
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor") -> None:
r"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
r"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
# serialization format for LoRA Attention Processors. It should be deleted once the integration
# with PEFT is completed.
is_lora_activated = {
name: module.lora_layer is not None
for name, module in self.named_modules()
if hasattr(module, "lora_layer")
}
# 1. if no layer has a LoRA activated we can return the processor as usual
if not any(is_lora_activated.values()):
return self.processor
# If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
is_lora_activated.pop("add_k_proj", None)
is_lora_activated.pop("add_v_proj", None)
# 2. else it is not posssible that only some layers have LoRA activated
if not all(is_lora_activated.values()):
raise ValueError(
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
)
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
non_lora_processor_cls_name = self.processor.__class__.__name__
lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
hidden_size = self.inner_dim
# now create a LoRA attention processor from the LoRA layers
if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
kwargs = {
"cross_attention_dim": self.cross_attention_dim,
"rank": self.to_q.lora_layer.rank,
"network_alpha": self.to_q.lora_layer.network_alpha,
"q_rank": self.to_q.lora_layer.rank,
"q_hidden_size": self.to_q.lora_layer.out_features,
"k_rank": self.to_k.lora_layer.rank,
"k_hidden_size": self.to_k.lora_layer.out_features,
"v_rank": self.to_v.lora_layer.rank,
"v_hidden_size": self.to_v.lora_layer.out_features,
"out_rank": self.to_out[0].lora_layer.rank,
"out_hidden_size": self.to_out[0].lora_layer.out_features,
}
if hasattr(self.processor, "attention_op"):
kwargs["attention_op"] = self.processor.attention_op
lora_processor = lora_processor_cls(hidden_size, **kwargs)
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
elif lora_processor_cls == LoRAAttnAddedKVProcessor:
lora_processor = lora_processor_cls(
hidden_size,
cross_attention_dim=self.add_k_proj.weight.shape[0],
rank=self.to_q.lora_layer.rank,
network_alpha=self.to_q.lora_layer.network_alpha,
)
lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
# only save if used
if self.add_k_proj.lora_layer is not None:
lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
else:
lora_processor.add_k_proj_lora = None
lora_processor.add_v_proj_lora = None
else:
raise ValueError(f"{lora_processor_cls} does not exist.")
return lora_processor
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
r"""
The forward method of the `Attention` class.
Args:
hidden_states (`torch.Tensor`):
The hidden states of the query.
encoder_hidden_states (`torch.Tensor`, *optional*):
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.
Returns:
`torch.Tensor`: The output of the attention layer.
"""
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
r"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
is the number of heads initialized while constructing the `Attention` class.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
r"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
the number of heads initialized while constructing the `Attention` class.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
Returns:
`torch.Tensor`: The reshaped tensor.
"""
head_size = self.heads
if tensor.ndim == 3:
batch_size, seq_len, dim = tensor.shape
extra_dim = 1
else:
batch_size, extra_dim, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if out_dim == 3:
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
return tensor
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
) -> torch.Tensor:
r"""
Compute the attention scores.
Args:
query (`torch.Tensor`): The query tensor.
key (`torch.Tensor`): The key tensor.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
Returns:
`torch.Tensor`: The attention probabilities/scores.
"""
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
del baddbmm_input
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
del attention_scores
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
r"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`):
The attention mask to prepare.
target_length (`int`):
The target length of the attention mask. This is the length of the attention mask after padding.
batch_size (`int`):
The batch size, which is used to repeat the attention mask.
out_dim (`int`, *optional*, defaults to `3`):
The output dimension of the attention mask. Can be either `3` or `4`.
Returns:
`torch.Tensor`: The prepared attention mask.
"""
head_size = self.heads
if attention_mask is None:
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if out_dim == 3:
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
elif out_dim == 4:
attention_mask = attention_mask.unsqueeze(1)
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
`Attention` class.
Args:
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
Returns:
`torch.Tensor`: The normalized encoder hidden states.
"""
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
# Group norm norms along the channels dimension and expects
# input to be in the shape of (N, C, *). In this case, we want
# to norm along the hidden dimension, so we need to move
# (batch_size, sequence_length, hidden_size) ->
# (batch_size, hidden_size, sequence_length)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
else:
assert False
return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not self.is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CustomDiffusionAttnProcessor(nn.Module):
r"""
Processor for implementing attention for the Custom Diffusion method.
Args:
train_kv (`bool`, defaults to `True`):
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`):
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`):
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
"""
def __init__(
self,
train_kv: bool = True,
train_q_out: bool = True,
hidden_size: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
dropout: float = 0.0,
):
super().__init__()
self.train_kv = train_kv
self.train_q_out = train_q_out
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
# `_custom_diffusion` id for easy serialization and loading.
if self.train_kv:
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
if self.train_q_out:
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
self.to_out_custom_diffusion = nn.ModuleList([])
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else:
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None:
crossattn = False
encoder_hidden_states = hidden_states
else:
crossattn = True
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if crossattn:
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
if self.train_q_out:
# linear proj
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class AttnAddedKVProcessor:
r"""
Processor for performing attention-related computations with extra learnable key and value matrices for the text
encoder.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class AttnAddedKVProcessor2_0:
r"""
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
learnable key and value matrices for the text encoder.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class XFormersAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental in nature and can change in future.
</Tip>
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states)
kv = attn.to_kv(encoder_hidden_states)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
Args:
train_kv (`bool`, defaults to `True`):
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`):
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`):
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
"""
def __init__(
self,
train_kv: bool = True,
train_q_out: bool = False,
hidden_size: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
dropout: float = 0.0,
attention_op: Optional[Callable] = None,
):
super().__init__()
self.train_kv = train_kv
self.train_q_out = train_q_out
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.attention_op = attention_op
# `_custom_diffusion` id for easy serialization and loading.
if self.train_kv:
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
if self.train_q_out:
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
self.to_out_custom_diffusion = nn.ModuleList([])
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
else:
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
if encoder_hidden_states is None:
crossattn = False
encoder_hidden_states = hidden_states
else:
crossattn = True
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if crossattn:
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
if self.train_q_out:
# linear proj
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CustomDiffusionAttnProcessor2_0(nn.Module):
r"""
Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
dot-product attention.
Args:
train_kv (`bool`, defaults to `True`):
Whether to newly train the key and value matrices corresponding to the text features.
train_q_out (`bool`, defaults to `True`):
Whether to newly train query matrices corresponding to the latent image features.
hidden_size (`int`, *optional*, defaults to `None`):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
out_bias (`bool`, defaults to `True`):
Whether to include the bias parameter in `train_q_out`.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
"""
def __init__(
self,
train_kv: bool = True,
train_q_out: bool = True,
hidden_size: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
dropout: float = 0.0,
):
super().__init__()
self.train_kv = train_kv
self.train_q_out = train_q_out
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
# `_custom_diffusion` id for easy serialization and loading.
if self.train_kv:
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
if self.train_q_out:
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
self.to_out_custom_diffusion = nn.ModuleList([])
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
query = self.to_q_custom_diffusion(hidden_states)
else:
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
crossattn = False
encoder_hidden_states = hidden_states
else:
crossattn = True
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
key = key.to(attn.to_q.weight.dtype)
value = value.to(attn.to_q.weight.dtype)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if crossattn:
detach = torch.ones_like(key)
detach[:, :1, :] = detach[:, :1, :] * 0.0
key = detach * key + (1 - detach) * key.detach()
value = detach * value + (1 - detach) * value.detach()
inner_dim = hidden_states.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if self.train_q_out:
# linear proj
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
# dropout
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
else:
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class SlicedAttnProcessor:
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size: int):
self.slice_size = slice_size
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class SlicedAttnAddedKVProcessor:
r"""
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(
self,
attn: "Attention",
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
else:
key = encoder_hidden_states_key_proj
value = encoder_hidden_states_value_proj
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class SpatialNorm(nn.Module):
"""
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
Args:
f_channels (`int`):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
f_size = f.shape[-2:]
zq = F.interpolate(zq, size=f_size, mode="nearest")
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class LoRAAttnProcessor(nn.Module):
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
**kwargs,
):
deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = AttnProcessor()
return attn.processor(attn, hidden_states, **kwargs)
class LoRAAttnProcessor2_0(nn.Module):
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
**kwargs,
):
deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = AttnProcessor2_0()
return attn.processor(attn, hidden_states, **kwargs)
class LoRAXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
Args:
hidden_size (`int`, *optional*):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
"""
def __init__(
self,
hidden_size: int,
cross_attention_dim: int,
rank: int = 4,
attention_op: Optional[Callable] = None,
network_alpha: Optional[int] = None,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.attention_op = attention_op
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = XFormersAttnProcessor()
return attn.processor(attn, hidden_states, **kwargs)
class LoRAAttnAddedKVProcessor(nn.Module):
r"""
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
encoder.
Args:
hidden_size (`int`, *optional*):
The hidden size of the attention layer.
cross_attention_dim (`int`, *optional*, defaults to `None`):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
"""
def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
"0.26.0",
(
f"Make sure use {self_cls_name[4:]} instead by setting"
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
" `LoraLoaderMixin.load_lora_weights`"
),
)
attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
attn._modules.pop("processor")
attn.processor = AttnAddedKVProcessor()
return attn.processor(attn, hidden_states, **kwargs)
class IPAdapterAttnProcessor(nn.Module):
r"""
Attention processor for Multiple IP-Adapater.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or List[`float`], defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
)
else:
ip_adapter_masks = [None] * len(self.scale)
# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
current_ip_hidden_states = current_ip_hidden_states * mask_downsample
hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAdapterAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
)
else:
ip_adapter_masks = [None] * len(self.scale)
# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
current_ip_hidden_states = current_ip_hidden_states * mask_downsample
hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
)
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
)
CROSS_ATTENTION_PROCESSORS = (
AttnProcessor,
AttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
XFormersAttnAddedKVProcessor,
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
# deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
]
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of down block output channels.
layers_per_down_block (`int`, *optional*, defaults to `1`):
Number layers for down block.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of up block output channels.
layers_per_up_block (`int`, *optional*, defaults to `1`):
Number layers for up block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
norm_num_groups (`int`, *optional*, defaults to `32`):
Number of groups to use for the first normalization layer in ResNet blocks.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
down_block_out_channels: Tuple[int, ...] = (64,),
layers_per_down_block: int = 1,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
up_block_out_channels: Tuple[int, ...] = (64,),
layers_per_up_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
scaling_factor: float = 0.18215,
) -> None:
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=down_block_out_channels,
layers_per_block=layers_per_down_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
# pass init params to Decoder
self.decoder = MaskConditionDecoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=up_block_out_channels,
layers_per_block=layers_per_up_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
decoded = self._decode(z, image, mask).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
mask: Optional[torch.FloatTensor] = None,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, sample, mask).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
scaling_factor: float = 0.18215,
latents_mean: Optional[Tuple[float]] = None,
latents_std: Optional[Tuple[float]] = None,
force_upcast: float = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.mid_block = MidBlockTemporalDecoder(
num_layers=self.layers_per_block,
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
attention_head_dim=block_out_channels[-1],
)
# up
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlockTemporalDecoder(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = torch.nn.Conv2d(
in_channels=block_out_channels[0],
out_channels=out_channels,
kernel_size=3,
padding=1,
)
conv_out_kernel_size = (3, 1, 1)
padding = [int(k // 2) for k in conv_out_kernel_size]
self.time_conv_out = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=conv_out_kernel_size,
padding=padding,
)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
num_frames: int = 1,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
)
else:
# middle
sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, image_only_indicator=image_only_indicator)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
batch_frames, channels, height, width = sample.shape
batch_size = batch_frames // num_frames
sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
sample = self.time_conv_out(sample)
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
scaling_factor: float = 0.18215,
force_upcast: float = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
double_z=True,
)
# pass init params to Decoder
self.decoder = TemporalDecoder(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
num_frames: int,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
batch_size = z.shape[0] // num_frames
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, num_frames=num_frames).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
@dataclass
class AutoencoderTinyOutput(BaseOutput):
"""
Output of AutoencoderTiny encoding method.
Args:
latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
"""
latents: torch.Tensor
class AutoencoderTiny(ModelMixin, ConfigMixin):
r"""
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
[`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
Tuple of integers representing the number of output channels for each encoder block. The length of the
tuple should be equal to the number of encoder blocks.
decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
Tuple of integers representing the number of output channels for each decoder block. The length of the
tuple should be equal to the number of decoder blocks.
act_fn (`str`, *optional*, defaults to `"relu"`):
Activation function to be used throughout the model.
latent_channels (`int`, *optional*, defaults to 4):
Number of channels in the latent representation. The latent space acts as a compressed representation of
the input image.
upsampling_scaling_factor (`int`, *optional*, defaults to 2):
Scaling factor for upsampling in the decoder. It determines the size of the output image during the
upsampling process.
num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
number of encoder blocks.
num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
number of decoder blocks.
latent_magnitude (`float`, *optional*, defaults to 3.0):
Magnitude of the latent representation. This parameter scales the latent representation values to control
the extent of information preservation.
latent_shift (float, *optional*, defaults to 0.5):
Shift applied to the latent representation. This parameter controls the center of the latent space.
scaling_factor (`float`, *optional*, defaults to 1.0):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
however, no such scaling factor was used, hence the value of 1.0 as the default.
force_upcast (`bool`, *optional*, default to `False`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without losing too much precision, in which case
`force_upcast` can be set to `False` (see this fp16-friendly
[AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
act_fn: str = "relu",
latent_channels: int = 4,
upsampling_scaling_factor: int = 2,
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
latent_magnitude: int = 3,
latent_shift: float = 0.5,
force_upcast: bool = False,
scaling_factor: float = 1.0,
):
super().__init__()
if len(encoder_block_out_channels) != len(num_encoder_blocks):
raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
if len(decoder_block_out_channels) != len(num_decoder_blocks):
raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
self.encoder = EncoderTiny(
in_channels=in_channels,
out_channels=latent_channels,
num_blocks=num_encoder_blocks,
block_out_channels=encoder_block_out_channels,
act_fn=act_fn,
)
self.decoder = DecoderTiny(
in_channels=latent_channels,
out_channels=out_channels,
num_blocks=num_decoder_blocks,
block_out_channels=decoder_block_out_channels,
upsampling_scaling_factor=upsampling_scaling_factor,
act_fn=act_fn,
)
self.latent_magnitude = latent_magnitude
self.latent_shift = latent_shift
self.scaling_factor = scaling_factor
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.spatial_scale_factor = 2**out_channels
self.tile_overlap_factor = 0.125
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
self.register_to_config(block_out_channels=decoder_block_out_channels)
self.register_to_config(force_upcast=False)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
"""raw latents -> [0, 1]"""
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_tiling(self, use_tiling: bool = True) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output.
Args:
x (`torch.FloatTensor`): Input batch of images.
Returns:
`torch.FloatTensor`: Encoded batch of images.
"""
# scale of encoder output relative to input
sf = self.spatial_scale_factor
tile_size = self.tile_sample_min_size
# number of pixels to blend and to traverse between tile
blend_size = int(tile_size * self.tile_overlap_factor)
traverse_size = tile_size - blend_size
# tiles index (up/left)
ti = range(0, x.shape[-2], traverse_size)
tj = range(0, x.shape[-1], traverse_size)
# mask for blending
blend_masks = torch.stack(
torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# output array
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
for i in ti:
for j in tj:
tile_in = x[..., i : i + tile_size, j : j + tile_size]
# tile result
tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
tile = self.encoder(tile_in)
h, w = tile.shape[-2], tile.shape[-1]
# blend tile result into output
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
blend_mask = blend_mask_i * blend_mask_j
tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
return out
def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output.
Args:
x (`torch.FloatTensor`): Input batch of images.
Returns:
`torch.FloatTensor`: Encoded batch of images.
"""
# scale of decoder output relative to input
sf = self.spatial_scale_factor
tile_size = self.tile_latent_min_size
# number of pixels to blend and to traverse between tiles
blend_size = int(tile_size * self.tile_overlap_factor)
traverse_size = tile_size - blend_size
# tiles index (up/left)
ti = range(0, x.shape[-2], traverse_size)
tj = range(0, x.shape[-1], traverse_size)
# mask for blending
blend_masks = torch.stack(
torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
)
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# output array
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
for i in ti:
for j in tj:
tile_in = x[..., i : i + tile_size, j : j + tile_size]
# tile result
tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
tile = self.decoder(tile_in)
h, w = tile.shape[-2], tile.shape[-1]
# blend tile result into output
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
return out
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
]
output = torch.cat(output)
else:
output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
if not return_dict:
return (output,)
return AutoencoderTinyOutput(latents=output)
@apply_forward_hook
def decode(
self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
output = torch.cat(output)
else:
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
if not return_dict:
return (output,)
return DecoderOutput(sample=output)
def forward(
self,
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
enc = self.encode(sample).latents
# scale latents to be in [0, 1], then quantize latents to a byte tensor,
# as if we were storing the latents in an RGBA uint8 image.
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
# unquantize latents back into [0, 1], then unscale latents back to their original range,
# as if we were loading the latents from an RGBA uint8 image.
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
dec = self.decode(unscaled_enc)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...schedulers import ConsistencyDecoderScheduler
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from ..modeling_utils import ModelMixin
from ..unets.unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
class ConsistencyDecoderVAEOutput(BaseOutput):
"""
Output of encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
r"""
The consistency decoder used with DALL-E 3.
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
>>> pipe = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
... ).to("cuda")
>>> pipe("horse", generator=torch.manual_seed(0)).images
```
"""
@register_to_config
def __init__(
self,
scaling_factor: float = 0.18215,
latent_channels: int = 4,
encoder_act_fn: str = "silu",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
encoder_double_z: bool = True,
encoder_down_block_types: Tuple[str, ...] = (
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
),
encoder_in_channels: int = 3,
encoder_layers_per_block: int = 2,
encoder_norm_num_groups: int = 32,
encoder_out_channels: int = 4,
decoder_add_attention: bool = False,
decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
decoder_down_block_types: Tuple[str, ...] = (
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
),
decoder_downsample_padding: int = 1,
decoder_in_channels: int = 7,
decoder_layers_per_block: int = 3,
decoder_norm_eps: float = 1e-05,
decoder_norm_num_groups: int = 32,
decoder_num_train_timesteps: int = 1024,
decoder_out_channels: int = 6,
decoder_resnet_time_scale_shift: str = "scale_shift",
decoder_time_embedding_type: str = "learned",
decoder_up_block_types: Tuple[str, ...] = (
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
),
):
super().__init__()
self.encoder = Encoder(
act_fn=encoder_act_fn,
block_out_channels=encoder_block_out_channels,
double_z=encoder_double_z,
down_block_types=encoder_down_block_types,
in_channels=encoder_in_channels,
layers_per_block=encoder_layers_per_block,
norm_num_groups=encoder_norm_num_groups,
out_channels=encoder_out_channels,
)
self.decoder_unet = UNet2DModel(
add_attention=decoder_add_attention,
block_out_channels=decoder_block_out_channels,
down_block_types=decoder_down_block_types,
downsample_padding=decoder_downsample_padding,
in_channels=decoder_in_channels,
layers_per_block=decoder_layers_per_block,
norm_eps=decoder_norm_eps,
norm_num_groups=decoder_norm_num_groups,
num_train_timesteps=decoder_num_train_timesteps,
out_channels=decoder_out_channels,
resnet_time_scale_shift=decoder_resnet_time_scale_shift,
time_embedding_type=decoder_time_embedding_type,
up_block_types=decoder_up_block_types,
)
self.decoder_scheduler = ConsistencyDecoderScheduler()
self.register_to_config(block_out_channels=encoder_block_out_channels)
self.register_to_config(force_upcast=False)
self.register_buffer(
"means",
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
persistent=False,
)
self.register_buffer(
"stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
num_inference_steps: int = 2,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
z = (z * self.config.scaling_factor - self.means) / self.stds
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
batch_size, _, height, width = z.shape
self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
(batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
)
for t in self.decoder_scheduler.timesteps:
model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
x_t = prev_sample
x_0 = x_t
if not return_dict:
return (x_0,)
return DecoderOutput(sample=x_0)
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
plain tuple.
Returns:
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
otherwise a plain `tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*, defaults to `None`):
Generator to use for sampling.
Returns:
[`DecoderOutput`] or `tuple`:
If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, generator=generator).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unets.unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
get_up_block,
)
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.FloatTensor
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
if is_torch_version(">=", "1.11.0"):
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample, use_reentrant=False
)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, use_reentrant=False
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
else:
# down
for down_block in self.down_blocks:
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class UpSample(nn.Module):
r"""
The `UpSample` layer of a variational autoencoder that upsamples its input.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `UpSample` class."""
x = torch.relu(x)
x = self.deconv(x)
return x
class MaskConditionEncoder(nn.Module):
"""
used in AsymmetricAutoencoderKL
"""
def __init__(
self,
in_ch: int,
out_ch: int = 192,
res_ch: int = 768,
stride: int = 16,
) -> None:
super().__init__()
channels = []
while stride > 1:
stride = stride // 2
in_ch_ = out_ch * 2
if out_ch > res_ch:
out_ch = res_ch
if stride == 1:
in_ch_ = res_ch
channels.append((in_ch_, out_ch))
out_ch *= 2
out_channels = []
for _in_ch, _out_ch in channels:
out_channels.append(_out_ch)
out_channels.append(channels[-1][0])
layers = []
in_ch_ = in_ch
for l in range(len(out_channels)):
out_ch_ = out_channels[l]
if l == 0 or l == 1:
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
else:
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
in_ch_ = out_ch_
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionEncoder` class."""
out = {}
for l in range(len(self.layers)):
layer = self.layers[l]
x = layer(x)
out[str(tuple(x.shape))] = x
x = torch.relu(x)
return out
class MaskConditionDecoder(nn.Module):
r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
decoder with a conditioner on the mask and masked image.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# condition encoder
self.condition_encoder = MaskConditionEncoder(
in_ch=out_channels,
out_ch=block_out_channels[0],
res_ch=block_out_channels[-1],
)
# out
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionDecoder` class."""
sample = z
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# condition encoder
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder),
masked_image,
mask,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
sample = sample.to(upscale_dtype)
# condition encoder
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder),
masked_image,
mask,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
# condition encoder
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = self.condition_encoder(masked_image, mask)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = up_block(sample, latent_embeds)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class VectorQuantizer(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
self,
n_e: int,
vq_embed_dim: int,
beta: float,
remap=None,
unknown_index: str = "random",
sane_index_shape: bool = False,
legacy: bool = True,
):
super().__init__()
self.n_e = n_e
self.vq_embed_dim = vq_embed_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.used: torch.Tensor
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q: torch.FloatTensor = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q: torch.FloatTensor = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
class EncoderTiny(nn.Module):
r"""
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
Args:
in_channels (`int`):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`Tuple[int, ...]`):
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
use.
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_blocks: Tuple[int, ...],
block_out_channels: Tuple[int, ...],
act_fn: str,
):
super().__init__()
layers = []
for i, num_block in enumerate(num_blocks):
num_channels = block_out_channels[i]
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
else:
layers.append(
nn.Conv2d(
num_channels,
num_channels,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `EncoderTiny` class."""
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
else:
# scale image from [-1, 1] to [0, 1] to match TAESD convention
x = self.layers(x.add(1).div(2))
return x
class DecoderTiny(nn.Module):
r"""
The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
Args:
in_channels (`int`):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`Tuple[int, ...]`):
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
use.
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
upsampling_scaling_factor (`int`):
The scaling factor to use for upsampling.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_blocks: Tuple[int, ...],
block_out_channels: Tuple[int, ...],
upsampling_scaling_factor: int,
act_fn: str,
):
super().__init__()
layers = [
nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
get_activation(act_fn),
]
for i, num_block in enumerate(num_blocks):
is_final_block = i == (len(num_blocks) - 1)
num_channels = block_out_channels[i]
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
if not is_final_block:
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
layers.append(
nn.Conv2d(
num_channels,
conv_out_channel,
kernel_size=3,
padding=1,
bias=is_final_block,
)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `DecoderTiny` class."""
# Clamp.
x = torch.tanh(x / 3) * 3
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
else:
x = self.layers(x)
# scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ControlNetOutput(BaseOutput):
"""
The output of [`ControlNetModel`].
Args:
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_down_block_re_sample (`torch.Tensor`):
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class ControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
channel_in = block_out_channels[i]
channel_out = block_out_channels[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = zero_module(
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
)
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
"""
A ControlNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter.
addition_embed_type_num_heads (`int`, defaults to 64):
The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# control net conditioning embedding
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
self.down_blocks = nn.ModuleList([])
self.controlnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_down_blocks.append(controlnet_block)
# mid
mid_block_channel = block_out_channels[-1]
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
elif mid_block_type == "UNetMidBlock2D":
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
num_layers=0,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
conditioning_channels: int = 3,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
controlnet = cls(
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block,
in_channels=unet.config.in_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
down_block_types=unet.config.down_block_types,
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
mid_block_type=unet.config.mid_block_type,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
conditioning_channels=conditioning_channels,
)
if load_weights_from_unet:
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if controlnet.class_embedding:
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
return controlnet
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
"""
The [`ControlNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.addition_embed_type is not None:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(sample, emb)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unets.unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
)
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
The output of [`FlaxControlNetModel`].
Args:
down_block_res_samples (`jnp.ndarray`):
mid_block_res_sample (`jnp.ndarray`):
"""
down_block_res_samples: jnp.ndarray
mid_block_res_sample: jnp.ndarray
class FlaxControlNetConditioningEmbedding(nn.Module):
conditioning_embedding_channels: int
block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
blocks = []
for i in range(len(self.block_out_channels) - 1):
channel_in = self.block_out_channels[i]
channel_out = self.block_out_channels[i + 1]
conv1 = nn.Conv(
channel_in,
kernel_size=(3, 3),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
blocks.append(conv1)
conv2 = nn.Conv(
channel_out,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
blocks.append(conv2)
self.blocks = blocks
self.conv_out = nn.Conv(
self.conditioning_embedding_channels,
kernel_size=(3, 3),
padding=((1, 1), (1, 1)),
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
dtype=self.dtype,
)
def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
embedding = self.conv_in(conditioning)
embedding = nn.silu(embedding)
for block in self.blocks:
embedding = block(embedding)
embedding = nn.silu(embedding)
embedding = self.conv_out(embedding)
return embedding
@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
A ControlNet model.
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
implemented for all models (such as downloading or saving).
This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
general usage and behavior.
Inherent JAX features such as the following are supported:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
sample_size (`int`, *optional*):
The size of the input sample.
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
The tuple of downsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
"""
sample_size: int = 32
in_channels: int = 4
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
only_cross_attention: Union[bool, Tuple[bool, ...]] = False
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int, ...]] = 8
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
flip_sin_to_cos: bool = True
freq_shift: int = 0
controlnet_conditioning_channel_order: str = "rgb"
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self) -> None:
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = self.num_attention_heads or self.attention_head_dim
# input
self.conv_in = nn.Conv(
block_out_channels[0],
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
# time
self.time_proj = FlaxTimesteps(
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
block_out_channels=self.conditioning_embedding_out_channels,
)
only_cross_attention = self.only_cross_attention
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
# down
down_blocks = []
controlnet_down_blocks = []
output_channel = block_out_channels[0]
controlnet_block = nn.Conv(
output_channel,
kernel_size=(1, 1),
padding="VALID",
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
dtype=self.dtype,
)
controlnet_down_blocks.append(controlnet_block)
for i, down_block_type in enumerate(self.down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if down_block_type == "CrossAttnDownBlock2D":
down_block = FlaxCrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
num_attention_heads=num_attention_heads[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
down_block = FlaxDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
add_downsample=not is_final_block,
dtype=self.dtype,
)
down_blocks.append(down_block)
for _ in range(self.layers_per_block):
controlnet_block = nn.Conv(
output_channel,
kernel_size=(1, 1),
padding="VALID",
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
dtype=self.dtype,
)
controlnet_down_blocks.append(controlnet_block)
if not is_final_block:
controlnet_block = nn.Conv(
output_channel,
kernel_size=(1, 1),
padding="VALID",
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
dtype=self.dtype,
)
controlnet_down_blocks.append(controlnet_block)
self.down_blocks = down_blocks
self.controlnet_down_blocks = controlnet_down_blocks
# mid
mid_block_channel = block_out_channels[-1]
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=mid_block_channel,
dropout=self.dropout,
num_attention_heads=num_attention_heads[-1],
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
self.controlnet_mid_block = nn.Conv(
mid_block_channel,
kernel_size=(1, 1),
padding="VALID",
kernel_init=nn.initializers.zeros_init(),
bias_init=nn.initializers.zeros_init(),
dtype=self.dtype,
)
def __call__(
self,
sample: jnp.ndarray,
timesteps: Union[jnp.ndarray, float, int],
encoder_hidden_states: jnp.ndarray,
controlnet_cond: jnp.ndarray,
conditioning_scale: float = 1.0,
return_dict: bool = True,
train: bool = False,
) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
r"""
Args:
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order
if channel_order == "bgr":
controlnet_cond = jnp.flip(controlnet_cond, axis=1)
# 1. time
if not isinstance(timesteps, jnp.ndarray):
timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
timesteps = timesteps.astype(dtype=jnp.float32)
timesteps = jnp.expand_dims(timesteps, 0)
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample += controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
else:
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
# 5. contronet blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return FlaxControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import deprecate
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native
class Downsample1D(nn.Module):
"""A 1D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 1D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
return self.conv(inputs)
class Downsample2D(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
kernel_size=3,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = conv_cls(
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class FirDownsample2D(nn.Module):
"""A 2D FIR downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
kernel for the FIR filter.
"""
def __init__(
self,
channels: Optional[int] = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def _downsample_2d(
self,
hidden_states: torch.FloatTensor,
weight: Optional[torch.FloatTensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
hidden_states (`torch.FloatTensor`):
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
weight (`torch.FloatTensor`, *optional*):
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
# setup kernel
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
if self.use_conv:
_, _, convH, convW = weight.shape
pad_value = (kernel.shape[0] - factor) + (convW - 1)
stride_value = [factor, factor]
upfirdn_input = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
pad=((pad_value + 1) // 2, pad_value // 2),
)
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
else:
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
torch.tensor(kernel, device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
return hidden_states
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
class KDownsample2D(nn.Module):
r"""A 2D K-downsampling layer.
Parameters:
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
"""
def __init__(self, pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv2d(inputs, weight, stride=2)
def downsample_2d(
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
Args:
hidden_states (`torch.FloatTensor`)
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
kernel (`torch.FloatTensor`, *optional*):
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
corresponds to average pooling.
factor (`int`, *optional*, default to `2`):
Integer downsampling factor.
gain (`float`, *optional*, default to `1.0`):
Scaling factor for signal magnitude.
Returns:
output (`torch.FloatTensor`):
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
kernel = [1] * factor
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment