"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "b36f96707db15b571c88a61ff7429a5a88eed652"
Unverified commit e8282327 authored by Patrick von Platen, committed by GitHub

Rename attention (#2691)

* rename file

* rename attention

* fix more

* rename more

* up

* more deprecation imports

* fixes
parent 588e50bc
@@ -50,7 +50,7 @@ pip install --pre torch torchvision --index-url https://download.pytorch.org/whl
 ```Python
 import torch
 from diffusers import StableDiffusionPipeline
-from diffusers.models.cross_attention import AttnProcessor2_0
+from diffusers.models.attention_processor import AttnProcessor2_0
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
 pipe.unet.set_attn_processor(AttnProcessor2_0())
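The hunk above updates the documentation snippet for opting into PyTorch 2.0's scaled dot-product attention via the renamed module. A slightly expanded sketch of the same idea; the version guard and the fallback to `AttnProcessor` are additions here, not part of the diff, and assume `AttnProcessor` remains the default processor exported by `attention_processor`:

```Python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# AttnProcessor2_0 wraps torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0);
# fall back to the default processor on older PyTorch builds
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    pipe.unet.set_attn_processor(AttnProcessor2_0())
else:
    pipe.unet.set_attn_processor(AttnProcessor())

image = pipe("an astronaut riding a horse on the moon").images[0]
```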
@@ -713,7 +713,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
@@ -868,7 +868,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
@@ -911,7 +911,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
@@ -47,7 +47,7 @@ from diffusers import (
 UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -723,9 +723,7 @@ def main(args):
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 unet.set_attn_processor(lora_attn_procs)
 lora_layers = AttnProcsLayers(unet.attn_processors)
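The training-script hunks here and below only show the `down_blocks` branch of the processor-building loop. A sketch of the full loop as it appears in the diffusers LoRA examples of this period, reconstructed around the visible lines; the `mid_block`/`up_blocks` branches and the `attn1`-vs-`attn2` distinction are filled in here and should be treated as assumptions:

```Python
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 processors do self-attention (no encoder states); attn2 processors cross-attend
    # to the text-encoder hidden states, so they need the UNet's cross_attention_dim
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # the only parameters the LoRA scripts optimize
```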
@@ -22,7 +22,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
@@ -561,9 +561,7 @@ def main():
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 unet.set_attn_processor(lora_attn_procs)
 lora_layers = AttnProcsLayers(unet.attn_processors)
@@ -43,7 +43,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -536,9 +536,7 @@ def main():
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 unet.set_attn_processor(lora_attn_procs)
 lora_layers = AttnProcsLayers(unet.attn_processors)
@@ -41,7 +41,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -474,9 +474,7 @@ def main():
 block_id = int(name[len("down_blocks.")])
 hidden_size = unet.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
 unet.set_attn_processor(lora_attn_procs)
@@ -17,7 +17,7 @@ from typing import Callable, Dict, Union
 import torch
-from .models.cross_attention import LoRACrossAttnProcessor
+from .models.attention_processor import LoRAAttnProcessor
 from .models.modeling_utils import _get_model_file
 from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
@@ -207,7 +207,7 @@ class UNet2DConditionLoadersMixin:
 cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
-attn_processors[key] = LoRACrossAttnProcessor(
+attn_processors[key] = LoRAAttnProcessor(
 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
 )
 attn_processors[key].load_state_dict(value_dict)
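This hunk sits inside `UNet2DConditionLoadersMixin.load_attn_procs`, which rebuilds one `LoRAAttnProcessor` per attention layer from a saved state dict. A minimal usage sketch, assuming a previously saved set of LoRA attention processors at a placeholder path:

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# load_attn_procs infers hidden_size / cross_attention_dim from the saved LoRA weight
# shapes, exactly as the hunk above shows, then installs the processors on the UNet
pipe.unet.load_attn_procs("path/to/saved_lora_attn_procs")  # placeholder: local dir or Hub repo id

image = pipe("a prompt in the style the LoRA was trained on").images[0]
```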
@@ -19,7 +19,7 @@ import torch.nn.functional as F
 from torch import nn
 from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -220,7 +220,7 @@ class BasicTransformerBlock(nn.Module):
 )
 # 1. Self-Attn
-self.attn1 = CrossAttention(
+self.attn1 = Attention(
 query_dim=dim,
 heads=num_attention_heads,
 dim_head=attention_head_dim,
@@ -234,7 +234,7 @@ class BasicTransformerBlock(nn.Module):
 # 2. Cross-Attn
 if cross_attention_dim is not None:
-self.attn2 = CrossAttention(
+self.attn2 = Attention(
 query_dim=dim,
 cross_attention_dim=cross_attention_dim,
 heads=num_attention_heads,
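`BasicTransformerBlock` now builds both its self- and cross-attention out of the renamed `Attention` class. A standalone sketch of constructing such a layer directly; the concrete sizes are illustrative, and `AttnProcessor` is assumed to be the default processor exported by the renamed module:

```Python
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

# a self-attention layer built the way attn1 is built above (no cross_attention_dim);
# 320 channels split into 8 heads of 40 are illustrative numbers, not taken from the diff
attn = Attention(query_dim=320, heads=8, dim_head=40, dropout=0.0, bias=False)
attn.set_processor(AttnProcessor())  # or AttnProcessor2_0() / an xFormers processor

hidden_states = torch.randn(2, 64, 320)  # (batch, sequence, query_dim)
out = attn(hidden_states)                # self-attention: encoder_hidden_states defaults to None
assert out.shape == hidden_states.shape
```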
@@ -16,7 +16,7 @@ import flax.linen as nn
 import jax.numpy as jnp
-class FlaxCrossAttention(nn.Module):
+class FlaxAttention(nn.Module):
 r"""
 A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -118,9 +118,9 @@ class FlaxBasicTransformerBlock(nn.Module):
 def setup(self):
 # self attention (or cross_attention if only_cross_attention is True)
-self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
 # cross attention
-self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
 self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
 self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
 self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
This diff is collapsed.
@@ -20,7 +20,7 @@ from torch.nn import functional as F
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .cross_attention import AttnProcessor
+from .attention_processor import AttentionProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -314,7 +314,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
 @property
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
-def attn_processors(self) -> Dict[str, AttnProcessor]:
+def attn_processors(self) -> Dict[str, AttentionProcessor]:
 r"""
 Returns:
 `dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -323,7 +323,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
 # set recursively
 processors = {}
-def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
+def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
 if hasattr(module, "set_processor"):
 processors[f"{name}.processor"] = module.processor
@@ -338,12 +338,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
 return processors
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
+def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
 r"""
 Parameters:
-`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
+`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-of **all** `CrossAttention` layers.
+of **all** `Attention` layers.
 In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
 """
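The `attn_processors` property and `set_attn_processor` shown above (copied into `ControlNetModel` from `UNet2DConditionModel`) walk the module tree recursively and key every processor by its module path. A short sketch of how the two entry points are used; the example key is only an illustration of the path format:

```Python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# attn_processors traverses the model (fn_recursive_add_processors above) and returns a dict
# keyed by "<module path>.processor", e.g.
# "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
print(len(unet.attn_processors), next(iter(unet.attn_processors)))

# set_attn_processor accepts a single processor, broadcast to every Attention layer ...
unet.set_attn_processor(AttnProcessor())
# ... or a dict with exactly those keys, the form recommended for trainable processors
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})
```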
This diff is collapsed.
@@ -114,7 +114,7 @@ class DualTransformer2DModel(nn.Module):
 timestep ( `torch.long`, *optional*):
 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
 attention_mask (`torch.FloatTensor`, *optional*):
-Optional attention mask to be applied in CrossAttention
+Optional attention mask to be applied in Attention
 return_dict (`bool`, *optional*, defaults to `True`):
 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from .attention import AdaGroupNorm, AttentionBlock
-from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
+from .attention_processor import Attention, AttnAddedKVProcessor
 from .dual_transformer_2d import DualTransformer2DModel
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from .transformer_2d import Transformer2DModel
@@ -591,7 +591,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
 for _ in range(num_layers):
 attentions.append(
-CrossAttention(
+Attention(
 query_dim=in_channels,
 cross_attention_dim=in_channels,
 heads=self.num_heads,
@@ -600,7 +600,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
 norm_num_groups=resnet_groups,
 bias=True,
 upcast_softmax=True,
-processor=CrossAttnAddedKVProcessor(),
+processor=AttnAddedKVProcessor(),
 )
 )
 resnets.append(
@@ -1365,7 +1365,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
 )
 )
 attentions.append(
-CrossAttention(
+Attention(
 query_dim=out_channels,
 cross_attention_dim=out_channels,
 heads=self.num_heads,
@@ -1374,7 +1374,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
 norm_num_groups=resnet_groups,
 bias=True,
 upcast_softmax=True,
-processor=CrossAttnAddedKVProcessor(),
+processor=AttnAddedKVProcessor(),
 )
 )
 self.attentions = nn.ModuleList(attentions)
@@ -2358,7 +2358,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
 )
 )
 attentions.append(
-CrossAttention(
+Attention(
 query_dim=out_channels,
 cross_attention_dim=out_channels,
 heads=self.num_heads,
@@ -2367,7 +2367,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
 norm_num_groups=resnet_groups,
 bias=True,
 upcast_softmax=True,
-processor=CrossAttnAddedKVProcessor(),
+processor=AttnAddedKVProcessor(),
 )
 )
 self.attentions = nn.ModuleList(attentions)
@@ -2677,7 +2677,7 @@ class KAttentionBlock(nn.Module):
 # 1. Self-Attn
 if add_self_attention:
 self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
-self.attn1 = CrossAttention(
+self.attn1 = Attention(
 query_dim=dim,
 heads=num_attention_heads,
 dim_head=attention_head_dim,
@@ -2689,7 +2689,7 @@ class KAttentionBlock(nn.Module):
 # 2. Cross-Attn
 self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
-self.attn2 = CrossAttention(
+self.attn2 = Attention(
 query_dim=dim,
 cross_attention_dim=cross_attention_dim,
 heads=num_attention_heads,
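`UNetMidBlock2DSimpleCrossAttn` and the `SimpleCrossAttn*Block2D` classes now construct `Attention` with an `AttnAddedKVProcessor`. A standalone sketch of such a layer; `added_kv_proj_dim` and the group norm are assumptions here (they sit in lines the hunks truncate), but they are what the added-KV processor needs, and the sizes are illustrative:

```Python
import torch
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

attn = Attention(
    query_dim=64,
    cross_attention_dim=64,
    heads=4,
    dim_head=16,
    added_kv_proj_dim=64,   # assumed: enables add_k_proj / add_v_proj used by the processor
    norm_num_groups=8,      # assumed: the processor group-normalizes the spatial features
    bias=True,
    upcast_softmax=True,
    processor=AttnAddedKVProcessor(),
)

hidden_states = torch.randn(1, 64, 8, 8)        # (batch, channels, height, width) feature map
encoder_hidden_states = torch.randn(1, 10, 64)  # e.g. text embeddings projected to 64 dims
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([1, 64, 8, 8]); the processor flattens, attends, and reshapes back
```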
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
-from .cross_attention import AttnProcessor
+from .attention_processor import AttentionProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -362,7 +362,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 )
 @property
-def attn_processors(self) -> Dict[str, AttnProcessor]:
+def attn_processors(self) -> Dict[str, AttentionProcessor]:
 r"""
 Returns:
 `dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -371,7 +371,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 # set recursively
 processors = {}
-def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
+def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
 if hasattr(module, "set_processor"):
 processors[f"{name}.processor"] = module.processor
@@ -385,12 +385,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 return processors
-def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
+def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
 r"""
 Parameters:
-`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
+`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-of **all** `CrossAttention` layers.
+of **all** `Attention` layers.
 In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
 """
@@ -505,7 +505,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 return_dict (`bool`, *optional*, defaults to `True`):
 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
@@ -585,7 +585,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
@@ -588,7 +588,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
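The docstring hunks above all describe `cross_attention_kwargs`, which the pipeline forwards verbatim to the active attention processor. A sketch of its most common use, scaling LoRA processors at inference time; the checkpoint path is a placeholder, and `scale` is the keyword accepted by `LoRAAttnProcessor.__call__`:

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("path/to/saved_lora_attn_procs")  # placeholder path

# every key in cross_attention_kwargs is passed straight to the processor's __call__;
# for LoRAAttnProcessor, `scale` blends the LoRA update into the frozen attention weights
image = pipe("a photo of a sks dog", cross_attention_kwargs={"scale": 0.5}).images[0]
```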
@@ -22,7 +22,7 @@ from torch.nn import functional as F
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.cross_attention import CrossAttention
+from ...models.attention_processor import Attention
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
@@ -121,13 +121,13 @@ class AttentionStore:
 self.attn_res = attn_res
-class AttendExciteCrossAttnProcessor:
+class AttendExciteAttnProcessor:
 def __init__(self, attnstore, place_in_unet):
 super().__init__()
 self.attnstore = attnstore
 self.place_in_unet = place_in_unet
-def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
 batch_size, sequence_length, _ = hidden_states.shape
 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -679,9 +679,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
 continue
 cross_att_count += 1
-attn_procs[name] = AttendExciteCrossAttnProcessor(
-attnstore=self.attention_store, place_in_unet=place_in_unet
-)
+attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)
 self.unet.set_attn_processor(attn_procs)
 self.attention_store.num_att_layers = cross_att_count
@@ -777,7 +775,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
 `self.processor` in
 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 max_iter_to_alter (`int`, *optional*, defaults to `25`):
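`AttendExciteAttnProcessor` above illustrates the custom-processor pattern: any object whose `__call__` accepts `(attn, hidden_states, encoder_hidden_states, attention_mask)` can be installed with `set_attn_processor`. A minimal sketch of a processor in the same spirit that records attention maps; the `Attention` helper methods it relies on (`head_to_batch_dim`, `get_attention_scores`, `batch_to_head_dim`) are assumed from diffusers releases of this period:

```Python
import torch
from diffusers.models.attention_processor import Attention


class RecordingAttnProcessor:
    """Hypothetical processor modeled on AttendExciteAttnProcessor: run standard attention
    and keep a copy of the attention probabilities for later inspection."""

    def __init__(self, store):
        self.store = store  # any list-like sink; an assumption of this sketch

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.store.append(attention_probs.detach().cpu())  # record, as Attend-and-Excite stores its maps

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states


# usage sketch: broadcast to every layer of a UNet, then read the maps back out of `maps`
# maps = []
# unet.set_attn_processor({name: RecordingAttnProcessor(maps) for name in unet.attn_processors})
```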