Unverified Commit e8282327 authored by Patrick von Platen, committed by GitHub

Rename attention (#2691)

* rename file

* rename attention

* fix more

* rename more

* up

* more deprecation imports

* fixes
parent 588e50bc
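In short, the attention classes move from `diffusers.models.cross_attention` to `diffusers.models.attention_processor` and drop the `Cross` prefix. A minimal import sketch of the mapping implied by the diff below; the old module path is expected to keep importing with a warning via the deprecation shims the commit message mentions, though that behavior is not shown in the hunks here.

```python
# New-style imports after this commit (name mapping taken from the renames in the diff below).
from diffusers.models.attention_processor import (
    Attention,             # was CrossAttention
    AttnProcessor,         # was CrossAttnProcessor
    AttnAddedKVProcessor,  # was CrossAttnAddedKVProcessor
    LoRAAttnProcessor,     # was LoRACrossAttnProcessor
    AttentionProcessor,    # union type alias, was AttnProcessor
)

# The old path (diffusers.models.cross_attention) should still resolve for a while,
# per the "more deprecation imports" item in the commit message above.
```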
@@ -789,7 +789,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
......
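The docstring change above (and the identical ones further down) only renames the class that `cross_attention_kwargs` is forwarded to; how the argument is used does not change. A usage sketch, where the checkpoint id and LoRA path are placeholders and `"scale"` is the key consumed by the LoRA processors (compare the `{"scale": 0.0}` test at the bottom of this diff):

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint id and LoRA weights path - illustrative only; assumes a CUDA GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("path/to/lora_weights")

# The dict is passed through to the AttentionProcessor's __call__; with
# LoRAAttnProcessor, "scale" weights the LoRA contribution (0.0 turns it off).
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```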
@@ -525,7 +525,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
@@ -29,7 +29,7 @@ from transformers import (
)
from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.cross_attention import CrossAttention
+from ...models.attention_processor import Attention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from ...utils import (
@@ -200,10 +200,10 @@ def prepare_unet(unet: UNet2DConditionModel):
module_name = name.replace(".processor", "")
module = unet.get_submodule(module_name)
if "attn2" in name:
-pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True)
+pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True)
module.requires_grad_(True)
else:
-pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False)
+pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False)
module.requires_grad_(False)
unet.set_attn_processor(pix2pix_zero_attn_procs)
@@ -218,7 +218,7 @@ class Pix2PixZeroL2Loss:
self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)
-class Pix2PixZeroCrossAttnProcessor:
+class Pix2PixZeroAttnProcessor:
"""An attention processor class to store the attention weights.
In Pix2Pix Zero, it happens during computations in the cross-attention blocks."""
@@ -229,7 +229,7 @@ class Pix2PixZeroCrossAttnProcessor:
def __call__(
self,
-attn: CrossAttention,
+attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
......
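The signature above is what every attention processor now receives: the module argument is typed `Attention` instead of `CrossAttention`. Below is a generic processor sketch against the renamed class, showing the standard attention path rather than the Pix2Pix Zero logic (the optional cross-attention norm and mask edge cases are omitted for brevity):

```python
import torch
from diffusers.models.attention_processor import Attention


class StoreAttnProbsProcessor:
    """Example processor: runs standard attention and keeps the attention map."""

    def __init__(self):
        self.attention_probs = None

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:  # self-attention
            encoder_hidden_states = hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.attention_probs = attention_probs  # stored for later inspection

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states
```

An instance of a class like this can be handed to `unet.set_attn_processor(...)` just like the built-in processors renamed in this commit.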
@@ -530,7 +530,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
......
@@ -684,7 +684,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
noise_level (`int`, *optional*, defaults to `0`):
......
@@ -653,7 +653,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
noise_level (`int`, *optional*, defaults to `0`):
......
@@ -6,8 +6,8 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
-from ...models.attention import CrossAttention
-from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
+from ...models.attention import Attention
+from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel
@@ -452,7 +452,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
)
@property
-def attn_processors(self) -> Dict[str, AttnProcessor]:
+def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -461,7 +461,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# set recursively
processors = {}
-def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
+def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
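The `attn_processors` property above and its `set_attn_processor` counterpart in the next hunk keep their behavior; only the type names change. The same pair exists on the regular `UNet2DConditionModel`, so a usage sketch (the checkpoint id is just an example) looks like this:

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Keys are module paths ending in ".processor"; values are the processor instances.
print(list(unet.attn_processors)[:2])

# A single instance is applied to every Attention layer; passing a dict keyed like
# attn_processors sets them per layer instead.
unet.set_attn_processor(AttnProcessor())
```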
@@ -475,12 +475,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
return processors
-def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
+def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Parameters:
-`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
+`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
-of **all** `CrossAttention` layers.
+of **all** `Attention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
"""
@@ -595,7 +595,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
-A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
@@ -1425,7 +1425,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
for _ in range(num_layers):
attentions.append(
-CrossAttention(
+Attention(
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
@@ -1434,7 +1434,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
-processor=CrossAttnAddedKVProcessor(),
+processor=AttnAddedKVProcessor(),
)
)
resnets.append(
......
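Only the class and processor names change in the constructor call above; the keyword arguments stay the same. A standalone construction sketch with arbitrary example sizes, using the plain `AttnProcessor` for simplicity (the `AttnAddedKVProcessor` used above additionally relies on the module's added key/value projections, configured via `added_kv_proj_dim`):

```python
import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

# Arbitrary example sizes: 320-dim image queries, 768-dim text context,
# 8 heads of 40 dims each (heads * dim_head = inner projection size).
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    heads=8,
    dim_head=40,
    processor=AttnProcessor(),
)

latents = torch.randn(2, 64, 320)   # (batch, image tokens, query_dim)
context = torch.randn(2, 77, 768)   # (batch, text tokens, cross_attention_dim)
out = attn(latents, encoder_hidden_states=context)
print(out.shape)                    # torch.Size([2, 64, 320])
```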
@@ -22,7 +22,7 @@ import torch
from parameterized import parameterized
from diffusers import UNet2DConditionModel
-from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
from diffusers.utils import (
floats_tensor,
load_hf_numpy,
@@ -54,9 +54,7 @@ def create_lora_layers(model):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
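The fixture above only shows the `down_blocks` branch; the change itself is just the one-line `LoRAAttnProcessor(...)` call. For context, the full per-layer loop as it appears in diffusers' LoRA examples of this period looks roughly like the sketch below, reusing the `unet` from the earlier sketch:

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no text context); attn2 cross-attends to the text encoder.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

unet.set_attn_processor(lora_attn_procs)
```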
@@ -119,7 +117,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
-== "XFormersCrossAttnProcessor"
+== "XFormersAttnProcessor"
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
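The class name asserted above comes from switching the model to xformers attention. A short sketch, reusing the `unet` from the earlier sketches and requiring the optional xformers package:

```python
# After enabling, every attention layer's processor becomes an XFormersAttnProcessor,
# which is what the assert above checks on the mid block.
unet.enable_xformers_memory_efficient_attention()
print(type(unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor).__name__)
# -> "XFormersAttnProcessor"
```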
@@ -324,9 +322,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
# add 1 to weights to mock trained weights
with torch.no_grad():
@@ -413,9 +409,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
@@ -468,9 +462,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
-lora_attn_procs[name] = LoRACrossAttnProcessor(
-hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-)
+lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
model.set_attn_processor(lora_attn_procs)
@@ -502,7 +494,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
-model.set_attn_processor(CrossAttnProcessor())
+model.set_attn_processor(AttnProcessor())
with torch.no_grad():
new_sample = model(**inputs_dict).sample
......