Unverified Commit e8282327, authored by Patrick von Platen, committed by GitHub

Rename attention (#2691)

* rename file

* rename attention

* fix more

* rename more

* up

* more deprecation imports

* fixes
parent 588e50bc
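For downstream code, the practical effect of this PR is a new module path and shorter class names for the attention machinery. The sketch below is based only on the renames visible in the hunks that follow; treating the old `diffusers.models.cross_attention` path as a deprecated alias is an assumption suggested by the "more deprecation imports" commit message, not something shown in this diff.

```python
# Before this commit, the old names lived in diffusers.models.cross_attention:
#   from diffusers.models.cross_attention import (
#       CrossAttention,
#       CrossAttnProcessor,
#       CrossAttnAddedKVProcessor,
#       LoRACrossAttnProcessor,
#   )

# After this commit, the same objects are imported under the renamed module and classes:
from diffusers.models.attention_processor import (
    Attention,             # was CrossAttention
    AttentionProcessor,    # used in type hints; was AttnProcessor
    AttnProcessor,         # was CrossAttnProcessor
    AttnAddedKVProcessor,  # was CrossAttnAddedKVProcessor
    LoRAAttnProcessor,     # was LoRACrossAttnProcessor
)
```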
@@ -789,7 +789,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
...
@@ -525,7 +525,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
...
@@ -29,7 +29,7 @@ from transformers import (
 )
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.cross_attention import CrossAttention
+from ...models.attention_processor import Attention
 from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
 from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
 from ...utils import (
@@ -200,10 +200,10 @@ def prepare_unet(unet: UNet2DConditionModel):
         module_name = name.replace(".processor", "")
         module = unet.get_submodule(module_name)
         if "attn2" in name:
-            pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True)
+            pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True)
             module.requires_grad_(True)
         else:
-            pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False)
+            pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False)
             module.requires_grad_(False)
     unet.set_attn_processor(pix2pix_zero_attn_procs)
@@ -218,7 +218,7 @@ class Pix2PixZeroL2Loss:
         self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)
-class Pix2PixZeroCrossAttnProcessor:
+class Pix2PixZeroAttnProcessor:
     """An attention processor class to store the attention weights.
     In Pix2Pix Zero, it happens during computations in the cross-attention blocks."""
@@ -229,7 +229,7 @@ class Pix2PixZeroCrossAttnProcessor:
     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
...
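The `prepare_unet` hunk above installs one processor per attention module, keyed by the names reported by `attn_processors`, and only unfreezes the cross-attention ("attn2") modules. Below is a minimal sketch of that pattern with the renamed public classes, assuming `unet` is an already loaded `UNet2DConditionModel`; the plain `AttnProcessor` is only a stand-in for the pipeline-internal `Pix2PixZeroAttnProcessor`.

```python
from diffusers.models.attention_processor import AttnProcessor

attn_procs = {}
for name in unet.attn_processors.keys():
    # Keys look like "<module path>.processor"; strip the suffix to reach the module.
    module = unet.get_submodule(name.replace(".processor", ""))
    # Mirror the diff: only cross-attention ("attn2") modules stay trainable.
    module.requires_grad_("attn2" in name)
    # Stand-in for Pix2PixZeroAttnProcessor(is_pix2pix_zero="attn2" in name).
    attn_procs[name] = AttnProcessor()

unet.set_attn_processor(attn_procs)
```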
@@ -530,7 +530,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
...
@@ -684,7 +684,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 noise_level (`int`, *optional*, defaults to `0`):
...
@@ -653,7 +653,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
     The frequency at which the `callback` function will be called. If not specified, the callback will be
     called at every step.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 noise_level (`int`, *optional*, defaults to `0`):
...
@@ -6,8 +6,8 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
-from ...models.attention import CrossAttention
-from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
+from ...models.attention import Attention
+from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -452,7 +452,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         )
     @property
-    def attn_processors(self) -> Dict[str, AttnProcessor]:
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
         Returns:
             `dict` of attention processors: A dictionary containing all attention processors used in the model with
@@ -461,7 +461,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # set recursively
         processors = {}
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "set_processor"):
                 processors[f"{name}.processor"] = module.processor
@@ -475,12 +475,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         return processors
-    def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
         Parameters:
-            `processor (`dict` of `AttnProcessor` or `AttnProcessor`):
+            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                of **all** `CrossAttention` layers.
+                of **all** `Attention` layers.
             In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
         """
@@ -595,7 +595,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 return_dict (`bool`, *optional*, defaults to `True`):
     Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
 cross_attention_kwargs (`dict`, *optional*):
-    A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+    A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
     `self.processor` in
     [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
@@ -1425,7 +1425,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         for _ in range(num_layers):
             attentions.append(
-                CrossAttention(
+                Attention(
                     query_dim=in_channels,
                     cross_attention_dim=in_channels,
                     heads=self.num_heads,
@@ -1434,7 +1434,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                     norm_num_groups=resnet_groups,
                     bias=True,
                     upcast_softmax=True,
-                    processor=CrossAttnAddedKVProcessor(),
+                    processor=AttnAddedKVProcessor(),
                 )
             )
             resnets.append(
...
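The mid-block hunk above builds the renamed `Attention` layer directly and hands it an `AttnAddedKVProcessor`. The standalone sketch below keeps only the keyword arguments visible in the diff; `in_channels`, `num_heads`, and `added_kv_proj_dim` are illustrative values filled in here, since the collapsed lines hide the remaining arguments the real call passes.

```python
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor

in_channels, num_heads = 64, 8  # example values, not taken from the diff
attn = Attention(
    query_dim=in_channels,
    cross_attention_dim=in_channels,
    heads=num_heads,
    norm_num_groups=32,             # stands in for resnet_groups above
    bias=True,
    upcast_softmax=True,
    added_kv_proj_dim=in_channels,  # assumed: the added-KV processor needs these projections
    processor=AttnAddedKVProcessor(),
)
```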
@@ -22,7 +22,7 @@ import torch
 from parameterized import parameterized
 from diffusers import UNet2DConditionModel
-from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -54,9 +54,7 @@ def create_lora_layers(model):
         block_id = int(name[len("down_blocks.")])
         hidden_size = model.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
         # add 1 to weights to mock trained weights
@@ -119,7 +117,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         assert (
             model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
-            == "XFormersCrossAttnProcessor"
+            == "XFormersAttnProcessor"
         ), "xformers is not enabled"
     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
@@ -324,9 +322,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         block_id = int(name[len("down_blocks.")])
         hidden_size = model.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         # add 1 to weights to mock trained weights
         with torch.no_grad():
@@ -413,9 +409,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         block_id = int(name[len("down_blocks.")])
         hidden_size = model.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
         # add 1 to weights to mock trained weights
@@ -468,9 +462,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         block_id = int(name[len("down_blocks.")])
         hidden_size = model.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
         model.set_attn_processor(lora_attn_procs)
@@ -502,7 +494,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
-        model.set_attn_processor(CrossAttnProcessor())
+        model.set_attn_processor(AttnProcessor())
         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
...
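The test updates above reduce to building per-layer `LoRAAttnProcessor` objects (the new name) and installing them with `set_attn_processor`. Here is a sketch of that setup, assuming `model` is a `UNet2DConditionModel`; only the `down_blocks` branch is visible in the truncated hunks, so the `mid_block`/`up_blocks` handling and the `cross_attention_dim` lookup below are assumptions.

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

lora_attn_procs = {}
for name in model.attn_processors.keys():
    # attn1 is self-attention (no encoder states), attn2 is cross-attention.
    cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
    if name.startswith("mid_block"):      # assumed branch, not shown in the diff
        hidden_size = model.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):    # assumed branch, not shown in the diff
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(model.config.block_out_channels))[block_id]
    else:                                 # down_blocks, as in the hunks above
        block_id = int(name[len("down_blocks.")])
        hidden_size = model.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(
        hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
    )

model.set_attn_processor(lora_attn_procs)
```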