"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e72fe7945fce351f00c75027e898a7f8598c8fa9"
Unverified commit e8282327, authored by Patrick von Platen and committed by GitHub

Rename attention (#2691)

* rename file

* rename attention

* fix more

* rename more

* up

* more deprecation imports

* fixes
parent 588e50bc
@@ -50,7 +50,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ```Python
 import torch
 from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProcessor2_0
+from diffusers.models.attention_processor import AttnProcessor2_0
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
 pipe.unet.set_attn_processor(AttnProcessor2_0())
...
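After this commit the processor classes live in `diffusers.models.attention_processor`, while the old `diffusers.models.cross_attention` module becomes a re-export shim (last hunk below), so existing imports keep resolving. A minimal sketch of what that implies, assuming a diffusers build that includes this commit:

```python
# new canonical location
from diffusers.models.attention_processor import AttnProcessor2_0

# old location keeps working through the deprecation shim added in this commit
from diffusers.models.cross_attention import AttnProcessor2_0 as LegacyAttnProcessor2_0

# both names resolve to the very same class object
assert AttnProcessor2_0 is LegacyAttnProcessor2_0
```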
@@ -713,7 +713,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
...
@@ -868,7 +868,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
...
@@ -911,7 +911,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
...
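The `cross_attention_kwargs` dict mentioned in these docstrings is forwarded untouched to whichever attention processor is installed on the UNet; the LoRA processors defined later in this commit read a `scale` entry from it. A hedged sketch of typical usage (the LoRA weights directory is hypothetical):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# hypothetical directory containing previously trained LoRA attention processors
pipe.unet.load_attn_procs("path/to/lora-weights")

# every key in cross_attention_kwargs reaches the processor's __call__;
# LoRAAttnProcessor interprets `scale` as the strength of the LoRA update
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```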
@@ -47,7 +47,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -723,9 +723,7 @@ def main(args):
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
     unet.set_attn_processor(lora_attn_procs)
     lora_layers = AttnProcsLayers(unet.attn_processors)
...
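For context, the surrounding loop that builds `lora_attn_procs` in these training scripts (mostly collapsed in the hunk) looks roughly like this sketch; only the processor class name changes in this commit, and loading the Stable Diffusion 1.5 UNet here is just an assumption for illustration:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no encoder states); attn2 conditions on the text embeddings
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
```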
@@ -22,7 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
@@ -561,9 +561,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
     unet.set_attn_processor(lora_attn_procs)
     lora_layers = AttnProcsLayers(unet.attn_processors)
...
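Only the LoRA processors are trained in these scripts; `AttnProcsLayers` wraps the processor dict so a standard optimizer sees their parameters while the UNet itself stays frozen. A minimal sketch, assuming `unet` already carries `LoRAAttnProcessor` instances as in the hunks above:

```python
import torch
from diffusers.loaders import AttnProcsLayers

lora_layers = AttnProcsLayers(unet.attn_processors)  # just the LoRA modules, not the UNet weights
optimizer = torch.optim.AdamW(lora_layers.parameters(), lr=1e-4)
```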
@@ -43,7 +43,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -536,9 +536,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
     unet.set_attn_processor(lora_attn_procs)
     lora_layers = AttnProcsLayers(unet.attn_processors)
...
@@ -41,7 +41,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -474,9 +474,7 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
     unet.set_attn_processor(lora_attn_procs)
...
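After training, these scripts persist only the attention processors, which the loader patched in the next hunk can read back. A hedged sketch with a hypothetical output directory (`save_attn_procs` and `load_attn_procs` come from `UNet2DConditionLoadersMixin`, not from this diff):

```python
# save just the LoRA attention-processor weights, not the whole UNet
unet.save_attn_procs("my-lora-output")

# later: load them into a fresh pipeline's UNet
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet.load_attn_procs("my-lora-output")
```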
@@ -17,7 +17,7 @@ from typing import Callable, Dict, Union
 import torch
-from .models.cross_attention import LoRACrossAttnProcessor
+from .models.attention_processor import LoRAAttnProcessor
 from .models.modeling_utils import _get_model_file
 from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
@@ -207,7 +207,7 @@ class UNet2DConditionLoadersMixin:
             cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
             hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-            attn_processors[key] = LoRACrossAttnProcessor(
+            attn_processors[key] = LoRAAttnProcessor(
                 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
             )
             attn_processors[key].load_state_dict(value_dict)
...
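`load_attn_procs` recovers the layer geometry purely from the LoRA state-dict shapes, as the changed lines above show; given the `LoRALinearLayer` definition later in this commit, the rank can be read off the same tensor. A small self-contained sketch:

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

proc = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
value_dict = proc.state_dict()

rank = value_dict["to_k_lora.down.weight"].shape[0]                 # 4   (down: in_features -> rank)
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]  # 768
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]            # 320 (up: rank -> out_features)
```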
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -220,7 +220,7 @@ class BasicTransformerBlock(nn.Module):
         )
         # 1. Self-Attn
-        self.attn1 = CrossAttention(
+        self.attn1 = Attention(
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
@@ -234,7 +234,7 @@ class BasicTransformerBlock(nn.Module):
         # 2. Cross-Attn
         if cross_attention_dim is not None:
-            self.attn2 = CrossAttention(
+            self.attn2 = Attention(
                 query_dim=dim,
                 cross_attention_dim=cross_attention_dim,
                 heads=num_attention_heads,
...
@@ -16,7 +16,7 @@ import flax.linen as nn
 import jax.numpy as jnp
-class FlaxCrossAttention(nn.Module):
+class FlaxAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module):
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
-        self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
...
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils.import_utils import is_xformers_available
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class Attention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm
self.scale = dim_head**-0.5 if scale_qk else 1.0
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
if cross_attention_norm:
self.norm_cross = nn.LayerNorm(cross_attention_dim)
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
)
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
)
if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
raise NotImplementedError(
"Memory efficient attention with `xformers` is currently not supported when"
" `self.added_kv_proj_dim` is defined."
)
elif not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
if is_lora:
processor = LoRAXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
if is_lora:
processor = LoRAAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
else:
processor = AttnProcessor()
self.set_processor(processor)
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
if slice_size is not None and self.added_kv_proj_dim is not None:
processor = SlicedAttnAddedKVProcessor(slice_size)
elif slice_size is not None:
processor = SlicedAttnProcessor(slice_size)
elif self.added_kv_proj_dim is not None:
processor = AttnAddedKVProcessor()
else:
processor = AttnProcessor()
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor"):
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
def batch_to_head_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def get_attention_scores(self, query, key, attention_mask=None):
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
if batch_size is None:
deprecate(
"batch_size=None",
"0.0.15",
(
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
" `prepare_attention_mask` when preparing the attention_mask."
),
)
batch_size = 1
head_size = self.heads
if attention_mask is None:
return attention_mask
if attention_mask.shape[-1] != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
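The renamed `Attention` module is self-contained enough to exercise in isolation. A minimal sketch of a cross-attention call; all sizes are arbitrary examples, and `AttnProcessor` is passed explicitly so the result does not depend on the PyTorch version:

```python
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40, processor=AttnProcessor())

hidden_states = torch.randn(2, 64, 320)          # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder states: (batch, tokens, cross_attention_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
assert out.shape == hidden_states.shape  # to_out projects inner_dim back to query_dim
```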
class AttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
return up_hidden_states.to(orig_dtype)
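Because `up.weight` is zero-initialized, a freshly constructed `LoRALinearLayer` is an exact no-op, which is what lets LoRA fine-tuning start from the base model's behavior. A quick check:

```python
import torch
from diffusers.models.attention_processor import LoRALinearLayer

layer = LoRALinearLayer(in_features=320, out_features=320, rank=4)
x = torch.randn(2, 10, 320)
assert torch.equal(layer(x), torch.zeros(2, 10, 320))  # stays zero until the up projection is trained
```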
class LoRAAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
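A fresh `LoRAAttnProcessor` therefore leaves its host layer's output unchanged, and the `scale` keyword reaches it through `forward`'s `**cross_attention_kwargs`. A minimal single-layer sketch (sizes are arbitrary):

```python
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor, LoRAAttnProcessor

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40, processor=AttnProcessor())
x, ctx = torch.randn(2, 64, 320), torch.randn(2, 77, 768)

base = attn(x, encoder_hidden_states=ctx)

attn.set_processor(LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4))
lora = attn(x, encoder_hidden_states=ctx, scale=1.0)  # `scale` is forwarded to the processor

assert torch.allclose(base, lora)  # identical until the LoRA weights are trained
```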
class AttnAddedKVProcessor:
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class XFormersAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class AttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
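With PyTorch 2.x, `AttnProcessor2_0` computes the same attention as the baseline `AttnProcessor`, just through `F.scaled_dot_product_attention` (whose default scaling matches `dim_head**-0.5`). A hedged sanity check, skipped on older PyTorch:

```python
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnProcessor2_0

if hasattr(F, "scaled_dot_product_attention"):  # PyTorch 2.x only
    attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
    x, ctx = torch.randn(2, 64, 320), torch.randn(2, 77, 768)

    attn.set_processor(AttnProcessor())
    out_ref = attn(x, encoder_hidden_states=ctx)

    attn.set_processor(AttnProcessor2_0())
    out_sdpa = attn(x, encoder_hidden_states=ctx)

    assert torch.allclose(out_ref, out_sdpa, atol=1e-5)  # agreement up to numerical tolerance
```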
class LoRAXFormersAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.attention_op = attention_op
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class SlicedAttnProcessor:
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
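`set_attention_slice` on the `Attention` class installs this processor; slicing only bounds how many batched heads are processed at once, it does not change the result. A small sketch reusing the layer shape from the earlier examples:

```python
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40, processor=AttnProcessor())
x, ctx = torch.randn(2, 64, 320), torch.randn(2, 77, 768)

full = attn(x, encoder_hidden_states=ctx)

attn.set_attention_slice(2)  # installs SlicedAttnProcessor(2): 2 of the 2*8 batched heads per step
sliced = attn(x, encoder_hidden_states=ctx)

assert torch.allclose(full, sliced, atol=1e-6)  # same math, lower peak memory
```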
class SlicedAttnAddedKVProcessor:
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
AttentionProcessor = Union[
AttnProcessor,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAXFormersAttnProcessor,
]
@@ -20,7 +20,7 @@ from torch.nn import functional as F
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .cross_attention import AttnProcessor
+from .attention_processor import AttentionProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -314,7 +314,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
     @property
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttnProcessor]:
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
         Returns:
             `dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -323,7 +323,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         # set recursively
         processors = {}
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "set_processor"):
                 processors[f"{name}.processor"] = module.processor
@@ -338,12 +338,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         return processors
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Parameters:
-            `processor (`dict` of `AttnProcessor` or `AttnProcessor`):
+            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                of **all** `CrossAttention` layers.
+                of **all** `Attention` layers.
             In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
         """
...
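These two accessors are the public surface touched by the rename: `attn_processors` returns a flat dict keyed by module path, and `set_attn_processor` accepts either one processor for every layer or a per-key dict. A hedged sketch (loading the Stable Diffusion 1.5 UNet is only for illustration):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# one entry per Attention layer, keyed "<module path>.processor"
print(list(unet.attn_processors)[:2])

# either a single processor shared by all layers...
unet.set_attn_processor(AttnProcessor())
# ...or an explicit dict with one processor per key
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})
```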
The final hunk replaces the body of the old `diffusers.models.cross_attention` module with a deprecation shim that simply re-exports the renamed classes from `attention_processor`; the remainder of the removed module body (the old `CrossAttention` class and its processors) follows below.
@@ -11,689 +11,86 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from ..utils import deprecate
+from .attention_processor import (  # noqa: F401
+    Attention,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor2_0,
+    LoRAAttnProcessor,
+    LoRALinearLayer,
+    LoRAXFormersAttnProcessor,
+    SlicedAttnAddedKVProcessor,
+    SlicedAttnProcessor,
+    XFormersAttnProcessor,
+)
+from .attention_processor import (  # noqa: F401
+    AttnProcessor as AttnProcessorRename,
+)
-from typing import Callable, Optional, Union
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from ..utils import deprecate, logging
-from ..utils.import_utils import is_xformers_available
-
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
class CrossAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm
self.scale = dim_head**-0.5 if scale_qk else 1.0
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
if cross_attention_norm:
self.norm_cross = nn.LayerNorm(cross_attention_dim)
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
)
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
)
if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
raise NotImplementedError(
"Memory efficient attention with `xformers` is currently not supported when"
" `self.added_kv_proj_dim` is defined."
)
elif not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
if is_lora:
processor = LoRAXFormersCrossAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
else:
processor = XFormersCrossAttnProcessor(attention_op=attention_op)
else:
if is_lora:
processor = LoRACrossAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
else:
processor = CrossAttnProcessor()
self.set_processor(processor)
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
if slice_size is not None and self.added_kv_proj_dim is not None:
processor = SlicedAttnAddedKVProcessor(slice_size)
elif slice_size is not None:
processor = SlicedAttnProcessor(slice_size)
elif self.added_kv_proj_dim is not None:
processor = CrossAttnAddedKVProcessor()
else:
processor = CrossAttnProcessor()
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor"):
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
# The `CrossAttention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
def batch_to_head_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def get_attention_scores(self, query, key, attention_mask=None):
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
if attention_mask is None:
baddbmm_input = torch.empty(
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
attention_scores = torch.baddbmm(
baddbmm_input,
query,
key.transpose(-1, -2),
beta=beta,
alpha=self.scale,
)
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
if batch_size is None:
deprecate(
"batch_size=None",
"0.0.15",
(
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
" `prepare_attention_mask` when preparing the attention_mask."
),
)
batch_size = 1
head_size = self.heads
if attention_mask is None:
return attention_mask
if attention_mask.shape[-1] != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if attention_mask.shape[0] < batch_size * head_size:
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
class CrossAttnProcessor:
def __call__(
self,
attn: CrossAttention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
return up_hidden_states.to(orig_dtype)
class LoRACrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
class XFormersCrossAttnProcessor:
    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
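This processor is normally installed indirectly rather than constructed by hand. Assuming the `xformers` package is available and the model id below is just a placeholder, a sketch of switching a pipeline to the memory-efficient path looks like this:

```Python
import torch
from diffusers import StableDiffusionPipeline

# "some/sd-checkpoint" is a placeholder model id, not taken from this PR.
pipe = StableDiffusionPipeline.from_pretrained("some/sd-checkpoint", torch_dtype=torch.float16).to("cuda")

# Requires the xformers package; swaps the attention layers to the memory-efficient processor.
pipe.enable_xformers_memory_efficient_attention()
```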
class AttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
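The main contract difference from the other processors is the tensor layout: heads stay a separate dimension instead of being folded into the batch. A self-contained sketch of the `scaled_dot_product_attention` call with arbitrary shapes:

```Python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 8, 64, 77, 40

query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Additive float mask broadcastable to (batch, heads, q_len, kv_len): 0.0 keeps a position, -inf drops it.
attn_mask = torch.zeros(batch, 1, 1, kv_len)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 64, 40])
```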
class LoRAXFormersCrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
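Slicing trades throughput for peak memory: only `slice_size` rows of the flattened batch×heads dimension are attended at a time, so the full score matrix is never materialized at once. In practice this is usually enabled through the pipeline helper rather than by constructing the processor directly; a hedged sketch, assuming `pipe` is an already-loaded Stable Diffusion pipeline:

```Python
# Sketch only: enable_attention_slicing is the public entry point for sliced attention.
pipe.enable_attention_slicing()                 # default "auto" picks a conservative per-layer slice size
# pipe.enable_attention_slicing(slice_size=1)   # lowest peak memory, slowest

image = pipe("an astronaut riding a horse").images[0]
```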
class SlicedAttnAddedKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


AttnProcessor = Union[
    CrossAttnProcessor,
    XFormersCrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    LoRACrossAttnProcessor,
    LoRAXFormersCrossAttnProcessor,
]


# New contents of cross_attention.py in this PR: deprecation shims that forward to the renamed
# classes in attention_processor.

deprecate(
    "cross_attention",
    "0.18.0",
    "Importing from cross_attention is deprecated. Please import from attention_processor instead.",
    standard_warn=False,
)


AttnProcessor = AttentionProcessor


class CrossAttention(Attention):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class CrossAttnProcessor(AttnProcessorRename):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class LoRACrossAttnProcessor(LoRAAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class XFormersCrossAttnProcessor(XFormersAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnProcessor(SlicedAttnProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
    def __init__(self, *args, **kwargs):
        deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
        deprecate("cross_attention", "0.18.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)
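For downstream code the change is purely a rename: the old module keeps working until 0.18.0 but warns on import. A quick before/after sketch of the import paths as they appear in this diff:

```Python
# Deprecated path (works until 0.18.0, emits a deprecation warning on import):
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, CrossAttnAddedKVProcessor

# Renamed equivalents introduced by this PR:
from diffusers.models.attention_processor import Attention, AttnProcessor, AttnAddedKVProcessor
```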
...@@ -114,7 +114,7 @@ class DualTransformer2DModel(nn.Module): ...@@ -114,7 +114,7 @@ class DualTransformer2DModel(nn.Module):
timestep ( `torch.long`, *optional*): timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*): attention_mask (`torch.FloatTensor`, *optional*):
Optional attention mask to be applied in CrossAttention Optional attention mask to be applied in Attention
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
......
...@@ -18,7 +18,7 @@ import torch ...@@ -18,7 +18,7 @@ import torch
from torch import nn from torch import nn
from .attention import AdaGroupNorm, AttentionBlock from .attention import AdaGroupNorm, AttentionBlock
from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor from .attention_processor import Attention, AttnAddedKVProcessor
from .dual_transformer_2d import DualTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
from .transformer_2d import Transformer2DModel from .transformer_2d import Transformer2DModel
...@@ -591,7 +591,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -591,7 +591,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
for _ in range(num_layers): for _ in range(num_layers):
attentions.append( attentions.append(
CrossAttention( Attention(
query_dim=in_channels, query_dim=in_channels,
cross_attention_dim=in_channels, cross_attention_dim=in_channels,
heads=self.num_heads, heads=self.num_heads,
...@@ -600,7 +600,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -600,7 +600,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
bias=True, bias=True,
upcast_softmax=True, upcast_softmax=True,
processor=CrossAttnAddedKVProcessor(), processor=AttnAddedKVProcessor(),
) )
) )
resnets.append( resnets.append(
...@@ -1365,7 +1365,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1365,7 +1365,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
) )
) )
attentions.append( attentions.append(
CrossAttention( Attention(
query_dim=out_channels, query_dim=out_channels,
cross_attention_dim=out_channels, cross_attention_dim=out_channels,
heads=self.num_heads, heads=self.num_heads,
...@@ -1374,7 +1374,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1374,7 +1374,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
bias=True, bias=True,
upcast_softmax=True, upcast_softmax=True,
processor=CrossAttnAddedKVProcessor(), processor=AttnAddedKVProcessor(),
) )
) )
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
...@@ -2358,7 +2358,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2358,7 +2358,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
) )
) )
attentions.append( attentions.append(
CrossAttention( Attention(
query_dim=out_channels, query_dim=out_channels,
cross_attention_dim=out_channels, cross_attention_dim=out_channels,
heads=self.num_heads, heads=self.num_heads,
...@@ -2367,7 +2367,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2367,7 +2367,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
norm_num_groups=resnet_groups, norm_num_groups=resnet_groups,
bias=True, bias=True,
upcast_softmax=True, upcast_softmax=True,
processor=CrossAttnAddedKVProcessor(), processor=AttnAddedKVProcessor(),
) )
) )
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
...@@ -2677,7 +2677,7 @@ class KAttentionBlock(nn.Module): ...@@ -2677,7 +2677,7 @@ class KAttentionBlock(nn.Module):
# 1. Self-Attn # 1. Self-Attn
if add_self_attention: if add_self_attention:
self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
self.attn1 = CrossAttention( self.attn1 = Attention(
query_dim=dim, query_dim=dim,
heads=num_attention_heads, heads=num_attention_heads,
dim_head=attention_head_dim, dim_head=attention_head_dim,
...@@ -2689,7 +2689,7 @@ class KAttentionBlock(nn.Module): ...@@ -2689,7 +2689,7 @@ class KAttentionBlock(nn.Module):
# 2. Cross-Attn # 2. Cross-Attn
self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
self.attn2 = CrossAttention( self.attn2 = Attention(
query_dim=dim, query_dim=dim,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
heads=num_attention_heads, heads=num_attention_heads,
......
...@@ -21,7 +21,7 @@ import torch.utils.checkpoint ...@@ -21,7 +21,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor from .attention_processor import AttentionProcessor
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .unet_2d_blocks import ( from .unet_2d_blocks import (
...@@ -362,7 +362,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -362,7 +362,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
) )
@property @property
def attn_processors(self) -> Dict[str, AttnProcessor]: def attn_processors(self) -> Dict[str, AttentionProcessor]:
r""" r"""
Returns: Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with `dict` of attention processors: A dictionary containing all attention processors used in the model with
...@@ -371,7 +371,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -371,7 +371,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# set recursively # set recursively
processors = {} processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"): if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor processors[f"{name}.processor"] = module.processor
...@@ -385,12 +385,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -385,12 +385,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors return processors
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r""" r"""
Parameters: Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`): `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers. of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
""" """
...@@ -505,7 +505,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) ...@@ -505,7 +505,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
...@@ -585,7 +585,7 @@ class AltDiffusionPipeline(DiffusionPipeline): ...@@ -585,7 +585,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
...@@ -588,7 +588,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -588,7 +588,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
...@@ -22,7 +22,7 @@ from torch.nn import functional as F ...@@ -22,7 +22,7 @@ from torch.nn import functional as F
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.cross_attention import CrossAttention from ...models.attention_processor import Attention
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -121,13 +121,13 @@ class AttentionStore: ...@@ -121,13 +121,13 @@ class AttentionStore:
self.attn_res = attn_res self.attn_res = attn_res
class AttendExciteCrossAttnProcessor: class AttendExciteAttnProcessor:
def __init__(self, attnstore, place_in_unet): def __init__(self, attnstore, place_in_unet):
super().__init__() super().__init__()
self.attnstore = attnstore self.attnstore = attnstore
self.place_in_unet = place_in_unet self.place_in_unet = place_in_unet
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
...@@ -679,9 +679,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): ...@@ -679,9 +679,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
continue continue
cross_att_count += 1 cross_att_count += 1
attn_procs[name] = AttendExciteCrossAttnProcessor( attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)
attnstore=self.attention_store, place_in_unet=place_in_unet
)
self.unet.set_attn_processor(attn_procs) self.unet.set_attn_processor(attn_procs)
self.attention_store.num_att_layers = cross_att_count self.attention_store.num_att_layers = cross_att_count
...@@ -777,7 +775,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): ...@@ -777,7 +775,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
cross_attention_kwargs (`dict`, *optional*): cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
max_iter_to_alter (`int`, *optional*, defaults to `25`): max_iter_to_alter (`int`, *optional*, defaults to `25`):
......