Unverified Commit 909742db authored by Will Berman, committed by GitHub

attention refactor: the trilogy (#3387)

* Replace `AttentionBlock` with `Attention`

* use _from_deprecated_attn_block check re: @patrickvonplaten
parent 28f40434
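
For orientation, here is a minimal sketch (not part of the commit; the channel/head values are illustrative) of how a deprecated `AttentionBlock(...)` construction maps onto the refactored `Attention` class, mirroring the argument mapping used in the unet_2d_blocks.py hunks below:

```python
import torch
from diffusers.models.attention_processor import Attention

channels = 512
attn_num_head_channels = 64  # in the old API, None meant a single head

# old (deprecated):
# attn = AttentionBlock(channels, num_head_channels=attn_num_head_channels,
#                       norm_num_groups=32, rescale_output_factor=1.0, eps=1e-5)

# new:
attn = Attention(
    channels,
    heads=channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
    dim_head=attn_num_head_channels if attn_num_head_channels is not None else channels,
    norm_num_groups=32,
    rescale_output_factor=1.0,
    eps=1e-5,
    residual_connection=True,  # the old block always added the residual
    bias=True,                 # the old query/key/value/proj_attn linears used bias
    upcast_softmax=True,
    _from_deprecated_attn_block=True,
)

# with the processors' new input_ndim == 4 path, the refactored module consumes
# the same (batch, channels, height, width) feature maps the old block did
hidden_states = attn(torch.randn(1, channels, 64, 64))
assert hidden_states.shape == (1, channels, 64, 64)
```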
@@ -11,189 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (`int`): The number of channels in the input and output.
num_head_channels (`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
norm_num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, bias=True)
self._use_memory_efficient_attention_xformers = False
self._use_2_0_attn = True
self._attention_op = None
def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if merge_head_and_batch:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
head_size = self.num_heads
if unmerge_head_and_batch:
batch_head_size, seq_len, dim = tensor.shape
batch_size = batch_head_size // head_size
tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
else:
batch_size, _, seq_len, dim = tensor.shape
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
return tensor
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
scale = 1 / math.sqrt(self.channels / self.num_heads)
_use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
)
hidden_states = hidden_states.to(query_proj.dtype)
elif use_torch_2_0_attn:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(
torch.empty(
query_proj.shape[0],
query_proj.shape[1],
key_proj.shape[1],
dtype=query_proj.dtype,
device=query_proj.device,
),
query_proj,
key_proj.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value_proj)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
...
@@ -65,6 +65,10 @@ class Attention(nn.Module):
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block=False,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@@ -72,6 +76,12 @@ class Attention(nn.Module):
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
# we make use of this private variable to know whether this class is loaded
# with a deprecated state dict so that we can convert it on the fly
self._from_deprecated_attn_block = _from_deprecated_attn_block
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
@@ -91,7 +101,7 @@ class Attention(nn.Module):
)
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
else:
self.group_norm = None
@@ -407,10 +417,22 @@ class AttnProcessor:
encoder_hidden_states=None,
attention_mask=None,
):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
@@ -434,6 +456,14 @@ class AttnProcessor:
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -474,11 +504,22 @@ class LoRAAttnProcessor(nn.Module):
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
@@ -502,6 +543,14 @@ class LoRAAttnProcessor(nn.Module):
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -762,12 +811,23 @@ class XFormersAttnProcessor:
self.attention_op = attention_op
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
@@ -792,6 +852,15 @@ class XFormersAttnProcessor:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -801,6 +870,14 @@ class AttnProcessor2_0:
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
@@ -812,6 +889,9 @@ class AttnProcessor2_0:
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
@@ -840,6 +920,15 @@ class AttnProcessor2_0:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -858,11 +947,22 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
@@ -887,6 +987,14 @@ class LoRAXFormersAttnProcessor(nn.Module):
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
@@ -980,11 +1088,22 @@ class SlicedAttnProcessor:
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
@@ -1025,6 +1144,14 @@ class SlicedAttnProcessor:
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
...
@@ -583,6 +583,7 @@ class ModelMixin(torch.nn.Module):
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
@@ -625,6 +626,7 @@ class ModelMixin(torch.nn.Module):
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
@@ -803,3 +805,47 @@ class ModelMixin(torch.nn.Module):
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _convert_deprecated_attention_blocks(self, state_dict):
deprecated_attention_block_paths = []
def recursive_find_attn_block(name, module):
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
deprecated_attention_block_paths.append(name)
for sub_name, sub_module in module.named_children():
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
recursive_find_attn_block(sub_name, sub_module)
recursive_find_attn_block("", self)
# NOTE: we have to check if the deprecated parameters are in the state dict
# because it is possible we are loading from a state dict that was already
# converted
for path in deprecated_attention_block_paths:
# group_norm path stays the same
# query -> to_q
if f"{path}.query.weight" in state_dict:
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
if f"{path}.query.bias" in state_dict:
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
# key -> to_k
if f"{path}.key.weight" in state_dict:
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
if f"{path}.key.bias" in state_dict:
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
# value -> to_v
if f"{path}.value.weight" in state_dict:
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
if f"{path}.value.bias" in state_dict:
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
# proj_attn -> to_out.0
if f"{path}.proj_attn.weight" in state_dict:
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from .attention import AdaGroupNorm, AttentionBlock
from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
@@ -427,12 +427,17 @@ class UNetMidBlock2D(nn.Module):
for _ in range(num_layers):
if self.add_attention:
attentions.append(
AttentionBlock(
Attention(
in_channels,
num_head_channels=attn_num_head_channels,
heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
else:
@@ -711,12 +716,17 @@ class AttnDownBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -1060,12 +1070,17 @@ class AttnDownEncoderBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -1134,11 +1149,17 @@ class AttnSkipDownBlock2D(nn.Module):
)
)
self.attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=32,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -1703,12 +1724,17 @@ class AttnUpBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -2037,12 +2063,17 @@ class AttnUpDecoderBlock2D(nn.Module):
)
)
attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
@@ -2109,11 +2140,17 @@ class AttnSkipUpBlock2D(nn.Module):
)
self.attentions.append(
AttentionBlock(
Attention(
out_channels,
num_head_channels=attn_num_head_channels,
heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=32,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
...
@@ -19,11 +19,11 @@ from typing import Any, Callable, List, Optional, Union
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -709,12 +709,14 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin
# make sure the VAE is in float32 mode, as it overflows in float16
self.vae.to(dtype=torch.float32)
# TODO(Patrick, William) - clean up when attention is refactored
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
]
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if not use_torch_2_0_attn and not use_xformers:
if not use_torch_2_0_or_xformers:
self.vae.post_quant_conv.to(latents.dtype)
self.vae.decoder.conv_in.to(latents.dtype)
self.vae.decoder.mid_block.to(latents.dtype)
...
@@ -20,7 +20,7 @@ import numpy as np
import torch
from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
@@ -314,59 +314,6 @@ class ResnetBlock2DTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class AttentionBlockTests(unittest.TestCase):
@unittest.skipIf(
torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"
)
def test_attention_block_default(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
attentionBlock = AttentionBlock(
channels=32,
num_head_channels=1,
rescale_output_factor=1.0,
eps=1e-6,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_attention_block_sd(self):
# This version uses SD params and is compatible with mps
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 512, 64, 64).to(torch_device)
attentionBlock = AttentionBlock(
channels=512,
rescale_output_factor=1.0,
eps=1e-6,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
assert attention_scores.shape == (1, 512, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)
...