# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
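    # Usage sketch (illustrative only; the sizes and shapes below are assumptions, not taken from this file):
    #
    #     attn = AttentionBlock(channels=64, num_head_channels=8)
    #     x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    #     y = attn(x)                     # output has the same shape as the input
    #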
    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
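            # The dummy tensors used in the branch below have arbitrary shapes; the call is only a smoke test
            # so that a broken xformers installation is detected here, before any real inference is run.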
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        scale = 1 / math.sqrt(self.channels / self.num_heads)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states


class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether the first attention layer should attend only to `encoder_hidden_states` instead of performing
            self-attention.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Whether to upcast the attention computation to `float32` inside the attention layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
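        # The block applies three pre-norm residual sub-layers: self-attention (attn1), optional
        # cross-attention (attn2), and a feed-forward network (ff); in `forward`, each sub-layer
        # normalizes its input first and adds its output back to `hidden_states`.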
        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.attn2 = None

        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
    ):
        # 1. Self-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

        if self.attn2 is not None:
            # 2. Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
    """
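    # Usage sketch (illustrative only; the sizes and shapes below are assumptions, not taken from this file):
    #
    #     ff = FeedForward(dim=320, activation_fn="geglu")
    #     x = torch.randn(2, 64, 320)  # (batch, sequence, dim); only the last dimension is transformed
    #     y = ff(x)                    # same shape as the input: (2, 64, 320)
    #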
""" def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) elif activation_fn == "geglu": act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": act_fn = ApproximateGELU(dim, inner_dim) self.net = nn.ModuleList([]) # project in self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out)) def forward(self, hidden_states): for module in self.net: hidden_states = module(hidden_states) return hidden_states class GELU(nn.Module): r""" GELU activation function """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out) def gelu(self, gate): if gate.device.type != "mps": return F.gelu(gate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) hidden_states = self.gelu(hidden_states) return hidden_states # feedforward class GEGLU(nn.Module): r""" A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": return F.gelu(gate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) class ApproximateGELU(nn.Module): """ The approximate form of Gaussian Error Linear Unit (GELU) For more details, see section 2: https://arxiv.org/abs/1606.08415 """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out) def forward(self, x): x = self.proj(x) return x * torch.sigmoid(1.702 * x) class AdaLayerNorm(nn.Module): """ Norm layer modified to incorporate timestep embeddings. """ def __init__(self, embedding_dim, num_embeddings): super().__init__() self.emb = nn.Embedding(num_embeddings, embedding_dim) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, embedding_dim * 2) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) def forward(self, x, timestep): emb = self.linear(self.silu(self.emb(timestep))) scale, shift = torch.chunk(emb, 2) x = self.norm(x) * (1 + scale) + shift return x