Unverified Commit e8282327 authored by Patrick von Platen, committed by GitHub

Rename attention (#2691)

* rename file

* rename attention

* fix more

* rename more

* up

* more deprecation imports

* fixes
parent 588e50bc
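The changes below replace the `diffusers.models.cross_attention` module with `diffusers.models.attention_processor` and drop the `Cross` prefix from the public attention classes. A minimal migration sketch, assuming a diffusers release that contains this commit (per the "more deprecation imports" bullet, the old module path presumably still resolves with a deprecation warning):

```python
# Old (pre-rename) imports, now kept only behind deprecation shims:
# from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor, CrossAttnAddedKVProcessor

# New imports introduced by this commit:
from diffusers.models.attention_processor import (
    Attention,             # was CrossAttention
    AttentionProcessor,    # was AttnProcessor (the generic processor type used in annotations)
    AttnAddedKVProcessor,  # was CrossAttnAddedKVProcessor
    LoRAAttnProcessor,     # was LoRACrossAttnProcessor
)
```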
......@@ -50,7 +50,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
```Python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0
from diffusers.models.attention_processor import AttnProcessor2_0
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())
......
......@@ -713,7 +713,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
......
......@@ -868,7 +868,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
......
......@@ -911,7 +911,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
......
......@@ -47,7 +47,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
......@@ -723,9 +723,7 @@ def main(args):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
......
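The LoRA training scripts touched in this commit all repeat the same per-layer setup that the hunk above ends with. A condensed sketch of that pattern, assuming `unet` is a stock `UNet2DConditionModel`; only the processor wiring is shown, not the surrounding training loop:

```python
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn2 layers cross-attend to the text embeddings; attn1 layers are self-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # wraps only the LoRA parameters for the optimizer
```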
......@@ -22,7 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
......@@ -561,9 +561,7 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
......
......@@ -43,7 +43,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
......@@ -536,9 +536,7 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
......
......@@ -41,7 +41,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
......@@ -474,9 +474,7 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs)
......
......@@ -17,7 +17,7 @@ from typing import Callable, Dict, Union
import torch
from .models.cross_attention import LoRACrossAttnProcessor
from .models.attention_processor import LoRAAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
......@@ -207,7 +207,7 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processors[key] = LoRACrossAttnProcessor(
attn_processors[key] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)
......
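For context on the loader hunk above: the dimensions are recovered from the `down`/`up` projection shapes that `LoRAAttnProcessor` produces. A small sketch; the concrete sizes (320/768/4) are illustrative values typical of an SD v1 UNet layer, not taken from the diff:

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

hidden_size, cross_attention_dim, rank = 320, 768, 4
proc = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank)

state = proc.state_dict()
# `to_k_lora.down` projects the encoder hidden states down to the LoRA rank,
# `to_k_lora.up` projects the rank back up to the attention hidden size,
# which is why the loader reads shape[1] and shape[0] of these two tensors.
assert state["to_k_lora.down.weight"].shape == (rank, cross_attention_dim)
assert state["to_k_lora.up.weight"].shape == (hidden_size, rank)
```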
......@@ -19,7 +19,7 @@ import torch.nn.functional as F
from torch import nn
from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
......@@ -220,7 +220,7 @@ class BasicTransformerBlock(nn.Module):
)
# 1. Self-Attn
self.attn1 = CrossAttention(
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
......@@ -234,7 +234,7 @@ class BasicTransformerBlock(nn.Module):
# 2. Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
......
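The `BasicTransformerBlock` hunk above is a pure rename: `attn1` stays self-attention and `attn2` stays cross-attention, both now built from `Attention`. A minimal standalone sketch of the two configurations; all dimensions here are illustrative, not taken from the diff:

```python
import torch
from diffusers.models.attention_processor import Attention

dim, heads, dim_head = 320, 8, 40

# attn1: keys/values come from the query sequence itself.
self_attn = Attention(query_dim=dim, heads=heads, dim_head=dim_head)
# attn2: keys/values come from the text encoder hidden states.
cross_attn = Attention(query_dim=dim, cross_attention_dim=768, heads=heads, dim_head=dim_head)

hidden_states = torch.randn(1, 64, dim)          # (batch, tokens, channels)
encoder_hidden_states = torch.randn(1, 77, 768)  # e.g. CLIP text embeddings

out1 = self_attn(hidden_states)
out2 = cross_attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out1.shape, out2.shape)  # both (1, 64, 320)
```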
......@@ -16,7 +16,7 @@ import flax.linen as nn
import jax.numpy as jnp
class FlaxCrossAttention(nn.Module):
class FlaxAttention(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
......@@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module):
def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
......
......@@ -20,7 +20,7 @@ from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
from .attention_processor import AttentionProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
......@@ -314,7 +314,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttnProcessor]:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
......@@ -323,7 +323,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
......@@ -338,12 +338,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
"""
......
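As the updated docstring says, `set_attn_processor` accepts either a single processor applied to every `Attention` layer or a dict keyed by the `<module path>.processor` names that the `attn_processors` property returns. A hedged usage sketch on a UNet (the ControlNet model exposes the same copied methods):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Option 1: one processor instance for all attention layers.
unet.set_attn_processor(AttnProcessor())

# Option 2: a per-layer dict, reusing the keys returned by `attn_processors`.
per_layer = {name: AttnProcessor() for name in unet.attn_processors.keys()}
unet.set_attn_processor(per_layer)
```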
......@@ -114,7 +114,7 @@ class DualTransformer2DModel(nn.Module):
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
Optional attention mask to be applied in CrossAttention
Optional attention mask to be applied in Attention
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
......
......@@ -18,7 +18,7 @@ import torch
from torch import nn
from .attention import AdaGroupNorm, AttentionBlock
from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
from .attention_processor import Attention, AttnAddedKVProcessor
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
from .transformer_2d import Transformer2DModel
......@@ -591,7 +591,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
for _ in range(num_layers):
attentions.append(
CrossAttention(
Attention(
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
......@@ -600,7 +600,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
processor=CrossAttnAddedKVProcessor(),
processor=AttnAddedKVProcessor(),
)
)
resnets.append(
......@@ -1365,7 +1365,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
)
)
attentions.append(
CrossAttention(
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
......@@ -1374,7 +1374,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
processor=CrossAttnAddedKVProcessor(),
processor=AttnAddedKVProcessor(),
)
)
self.attentions = nn.ModuleList(attentions)
......@@ -2358,7 +2358,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
)
)
attentions.append(
CrossAttention(
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
......@@ -2367,7 +2367,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
processor=CrossAttnAddedKVProcessor(),
processor=AttnAddedKVProcessor(),
)
)
self.attentions = nn.ModuleList(attentions)
......@@ -2677,7 +2677,7 @@ class KAttentionBlock(nn.Module):
# 1. Self-Attn
if add_self_attention:
self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
self.attn1 = CrossAttention(
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
......@@ -2689,7 +2689,7 @@ class KAttentionBlock(nn.Module):
# 2. Cross-Attn
self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
self.attn2 = CrossAttention(
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
......
......@@ -21,7 +21,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
from .attention_processor import AttentionProcessor
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
......@@ -362,7 +362,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
)
@property
def attn_processors(self) -> Dict[str, AttnProcessor]:
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
......@@ -371,7 +371,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
......@@ -385,12 +385,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
"""
......@@ -505,7 +505,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
......@@ -585,7 +585,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
......@@ -588,7 +588,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
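The `cross_attention_kwargs` dict described in these docstrings is forwarded unchanged to each attention processor's `__call__`. A hedged example with LoRA processors loaded into the UNet, where a `scale` entry is the typical use; the weights path here is hypothetical:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs("path/to/lora_weights")  # hypothetical directory containing trained LoRA attention processors

image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.5},  # forwarded to LoRAAttnProcessor.__call__
).images[0]
```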
......@@ -22,7 +22,7 @@ from torch.nn import functional as F
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.cross_attention import CrossAttention
from ...models.attention_processor import Attention
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
......@@ -121,13 +121,13 @@ class AttentionStore:
self.attn_res = attn_res
class AttendExciteCrossAttnProcessor:
class AttendExciteAttnProcessor:
def __init__(self, attnstore, place_in_unet):
super().__init__()
self.attnstore = attnstore
self.place_in_unet = place_in_unet
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
......@@ -679,9 +679,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
continue
cross_att_count += 1
attn_procs[name] = AttendExciteCrossAttnProcessor(
attnstore=self.attention_store, place_in_unet=place_in_unet
)
attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)
self.unet.set_attn_processor(attn_procs)
self.attention_store.num_att_layers = cross_att_count
......@@ -777,7 +775,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
max_iter_to_alter (`int`, *optional*, defaults to `25`):
......
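`AttendExciteAttnProcessor` above illustrates the processor interface every `Attention` layer dispatches to. The sketch below follows the same `__call__` signature; the map-logging logic is illustrative and not the pipeline's actual behavior:

```python
import torch
from diffusers.models.attention_processor import Attention

class AttnMapLoggingProcessor:
    """Stores the most recent attention probabilities (hypothetical helper)."""

    def __init__(self):
        self.last_attention_probs = None

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        # Self-attention if no encoder states are passed, cross-attention otherwise.
        encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.last_attention_probs = attention_probs.detach()

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Output projection and dropout, as in the default processor.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
```

Installing it on every layer then mirrors what the pipeline does internally: `unet.set_attn_processor({name: AttnMapLoggingProcessor() for name in unet.attn_processors})`.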