Commit 1a4bd9e9 authored by comfyanonymous

Refactor the attention functions.

There's no reason for the whole CrossAttention object to be repeated when
only the operation in the middle changes.
parent 8cc75c64
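
The pattern in a nutshell, as an illustrative sketch only (simplified from the diff below): the projection module stays fixed while the attention math becomes a swappable function with the shared signature (q, k, v, heads, mask). The attention_fallback helper, the trimmed-down layer list, and the demo shapes here are not part of this commit; the real module keeps comfy.ops linear layers, dropout, and the optional value input.

# Illustrative sketch, not the committed code.
import torch
import torch.nn as nn

def attention_fallback(q, k, v, heads, mask=None):
    # Plain scaled-dot-product attention, using the same
    # (batch, tokens, heads * dim_head) layout as the real backends.
    b, _, inner_dim = q.shape
    dim_head = inner_dim // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

# In the real code this is chosen once at import time
# (xformers / pytorch / split / sub-quadratic).
optimized_attention = attention_fallback

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        context_dim = context_dim or query_dim
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        context = x if context is None else context
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        # Only this call changes between backends; the module stays the same.
        return self.to_out(optimized_attention(q, k, v, self.heads, mask))

if __name__ == "__main__":
    attn = CrossAttention(query_dim=320, context_dim=768)
    x, ctx = torch.randn(2, 64, 320), torch.randn(2, 77, 768)
    print(attn(x, ctx).shape)  # torch.Size([2, 64, 320])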
comfy/ldm/modules/attention.py

@@ -94,95 +94,41 @@ def zero_module(module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0)
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b,c,h,w = q.shape
-        q = rearrange(q, 'b c h w -> b (h w) c')
-        k = rearrange(k, 'b c h w -> b c (h w)')
-        w_ = torch.einsum('bij,bjk->bik', q, k)
-        w_ = w_ * (int(c)**(-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, 'b c h w -> b c (h w)')
-        w_ = rearrange(w_, 'b i j -> b j i')
-        h_ = torch.einsum('bij,bjk->bik', v, w_)
-        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
-        h_ = self.proj_out(h_)
-        return x+h_
-
-class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-
-        self.to_out = nn.Sequential(
-            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x, context=None, value=None, mask=None):
-        h = self.heads
-
-        query = self.to_q(x)
-        context = default(context, x)
-        key = self.to_k(context)
-        if value is not None:
-            value = self.to_v(value)
-        else:
-            value = self.to_v(context)
-        del context, x
-
-        query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
-        key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
-        del key
-        value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
-
-        dtype = query.dtype
-        upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+def attention_basic(q, k, v, heads, mask=None):
+    h = heads
+    scale = (q.shape[-1] // heads) ** -0.5
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    # force cast to fp32 to avoid overflowing
+    if _ATTN_PRECISION =="fp32":
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            q, k = q.float(), k.float()
+            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+    else:
+        sim = einsum('b i d, b j d -> b i j', q, k) * scale
+
+    del q, k
+
+    if exists(mask):
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    sim = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return out
+
+def attention_sub_quad(query, key, value, heads, mask=None):
+    scale = (query.shape[-1] // heads) ** -0.5
+    query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
+    key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
+    del key
+    value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
+
+    dtype = query.dtype
+    upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
@@ -230,54 +176,19 @@ class CrossAttentionBirchSan(nn.Module):
         query_chunk_size=query_chunk_size,
         kv_chunk_size=kv_chunk_size,
         kv_chunk_size_min=kv_chunk_size_min,
-        use_checkpoint=self.training,
+        use_checkpoint=False,
         upcast_attention=upcast_attention,
     )
 
     hidden_states = hidden_states.to(dtype)
-        hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
-
-        out_proj, dropout = self.to_out
-        hidden_states = out_proj(hidden_states)
-        hidden_states = dropout(hidden_states)
+    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states
 
-class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-
-        self.to_out = nn.Sequential(
-            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x, context=None, value=None, mask=None):
-        h = self.heads
-
-        q_in = self.to_q(x)
-        context = default(context, x)
-        k_in = self.to_k(context)
-        if value is not None:
-            v_in = self.to_v(value)
-            del value
-        else:
-            v_in = self.to_v(context)
-        del context, x
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
+def attention_split(q, k, v, heads, mask=None):
+    scale = (q.shape[-1] // heads) ** -0.5
+    h = heads
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 
     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@@ -310,9 +221,9 @@ class CrossAttentionDoggettx(nn.Module):
         end = i + slice_size
         if _ATTN_PRECISION =="fp32":
             with torch.autocast(enabled=False, device_type = 'cuda'):
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
+                s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
         else:
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
 
         first_op_done = True
         s2 = s1.softmax(dim=-1).to(v.dtype)
@@ -339,115 +250,66 @@ class CrossAttentionDoggettx(nn.Module):
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
-        return self.to_out(r2)
-
-class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-
-        self.to_out = nn.Sequential(
-            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x, context=None, value=None, mask=None):
-        h = self.heads
-        q = self.to_q(x)
-        context = default(context, x)
-        k = self.to_k(context)
-        if value is not None:
-            v = self.to_v(value)
-            del value
-        else:
-            v = self.to_v(context)
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-        # force cast to fp32 to avoid overflowing
-        if _ATTN_PRECISION =="fp32":
-            with torch.autocast(enabled=False, device_type = 'cuda'):
-                q, k = q.float(), k.float()
-                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        else:
-            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
-
-        out = einsum('b i j, b j d -> b i d', sim, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
-
-class MemoryEfficientCrossAttention(nn.Module):
-    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
-        self.heads = heads
-        self.dim_head = dim_head
-
-        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-
-        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
-
-    def forward(self, x, context=None, value=None, mask=None):
-        q = self.to_q(x)
-        context = default(context, x)
-        k = self.to_k(context)
-        if value is not None:
-            v = self.to_v(value)
-            del value
-        else:
-            v = self.to_v(context)
+    return r2
+
+def attention_xformers(q, k, v, heads, mask=None):
     b, _, _ = q.shape
     q, k, v = map(
         lambda t: t.unsqueeze(3)
-        .reshape(b, t.shape[1], self.heads, self.dim_head)
+        .reshape(b, t.shape[1], heads, -1)
         .permute(0, 2, 1, 3)
-        .reshape(b * self.heads, t.shape[1], self.dim_head)
+        .reshape(b * heads, t.shape[1], -1)
         .contiguous(),
         (q, k, v),
     )
 
     # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
 
     if exists(mask):
         raise NotImplementedError
     out = (
         out.unsqueeze(0)
-        .reshape(b, self.heads, out.shape[1], self.dim_head)
+        .reshape(b, heads, out.shape[1], -1)
         .permute(0, 2, 1, 3)
-        .reshape(b, out.shape[1], self.heads * self.dim_head)
+        .reshape(b, out.shape[1], -1)
     )
-        return self.to_out(out)
-
-class CrossAttentionPytorch(nn.Module):
+    return out
+
+def attention_pytorch(q, k, v, heads, mask=None):
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    q, k, v = map(
+        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+        (q, k, v),
+    )
+
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+    if exists(mask):
+        raise NotImplementedError
+    out = (
+        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+    )
+    return out
+
+optimized_attention = attention_basic
+
+if model_management.xformers_enabled():
+    print("Using xformers cross attention")
+    optimized_attention = attention_xformers
+elif model_management.pytorch_attention_enabled():
+    print("Using pytorch cross attention")
+    optimized_attention = attention_pytorch
+else:
+    if args.use_split_cross_attention:
+        print("Using split optimization for cross attention")
+        optimized_attention = attention_split
    else:
+        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        optimized_attention = attention_sub_quad
+
+class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         inner_dim = dim_head * heads
@@ -461,7 +323,6 @@ class CrossAttentionPytorch(nn.Module):
         self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
@@ -473,36 +334,9 @@ class CrossAttentionPytorch(nn.Module):
         else:
             v = self.to_v(context)
 
-        b, _, _ = q.shape
-        q, k, v = map(
-            lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
-            (q, k, v),
-        )
-
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-
-        if exists(mask):
-            raise NotImplementedError
-        out = (
-            out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
-        )
+        out = optimized_attention(q, k, v, self.heads, mask)
 
         return self.to_out(out)
 
-
-if model_management.xformers_enabled():
-    print("Using xformers cross attention")
-    CrossAttention = MemoryEfficientCrossAttention
-elif model_management.pytorch_attention_enabled():
-    print("Using pytorch cross attention")
-    CrossAttention = CrossAttentionPytorch
-else:
-    if args.use_split_cross_attention:
-        print("Using split optimization for cross attention")
-        CrossAttention = CrossAttentionDoggettx
-    else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
-
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
comfy/ldm/modules/diffusionmodules/model.py

@@ -6,7 +6,6 @@ import numpy as np
 from einops import rearrange
 from typing import Optional, Any
 
-from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
 import comfy.ops
@@ -352,15 +351,6 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         out = self.proj_out(out)
         return x+out
 
-class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
-    def forward(self, x, context=None, mask=None):
-        b, c, h, w = x.shape
-        x = rearrange(x, 'b c h w -> b (h w) c')
-        out = super().forward(x, context=context, mask=mask)
-        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
-        return x + out
-
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled_vae() and attn_type == "vanilla":
@@ -376,9 +366,6 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
         return MemoryEfficientAttnBlock(in_channels)
     elif attn_type == "vanilla-pytorch":
         return MemoryEfficientAttnBlockPytorch(in_channels)
-    elif type == "memory-efficient-cross-attn":
-        attn_kwargs["query_dim"] = in_channels
-        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
     elif attn_type == "none":
         return nn.Identity(in_channels)
     else:
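
A side note on the shared (q, k, v, heads, mask) convention: every backend above first splits the (batch, tokens, heads * dim_head) projections into per-head batches and merges them back afterwards. The standalone check below is illustrative only and not part of the commit; it shows that the einops pattern used by attention_basic and attention_split produces the same layout as the view/transpose route used by attention_pytorch.

# Standalone layout check (illustrative, not committed code).
import torch
from einops import rearrange

b, n, heads, dim_head = 2, 16, 8, 64
x = torch.randn(b, n, heads * dim_head)

# einops route: (b, n, h*d) -> (b*h, n, d)
a = rearrange(x, 'b n (h d) -> (b h) n d', h=heads)

# view/transpose route: (b, n, h*d) -> (b, h, n, d) -> (b*h, n, d)
bview = x.view(b, n, heads, dim_head).transpose(1, 2).reshape(b * heads, n, dim_head)
print(torch.equal(a, bview))  # True

# inverse used on the way out: (b*h, n, d) -> (b, n, h*d)
merged = rearrange(a, '(b h) n d -> b n (h d)', h=heads)
print(torch.equal(merged, x))  # True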