Commit 52b3ff5e authored by Patrick von Platen

unify ldm and glide attention

parent fff981df

attention2d.py
import math
import torch
import torch.nn.functional as F
from torch import nn
def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# unet_grad_tts.py
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
...@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)


# unet.py
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
...@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):
        return x + h_
# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
...@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
    ):
        super().__init__()
        self.channels = channels
...@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
...@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
        return a.reshape(bs, -1, length)
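For context, a minimal shape check of the unified block, assuming it is importable from the new attention2d module; the import path and the 64-channel / 4-head configuration below are illustrative, not part of this commit.

# Illustrative smoke test of the unified AttentionBlock; module path and
# hyper-parameters are assumptions for the sketch.
import torch

from diffusers.models.attention2d import AttentionBlock  # assumed module path

block = AttentionBlock(channels=64, num_heads=4)
x = torch.randn(2, 64, 16, 16)  # [N, C, H, W]
out = block(x)                  # every spatial position attends to every other
assert out.shape == x.shape     # residual connection keeps the input shape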
# unet_ldm.py  (duplicate of the unified block above; removed in this commit)
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    :param channels: number of input channels.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


def normalization(channels, swish=0.0):
    """
    Make a standard normalization layer, with an optional swish activation.
    :param channels: number of input channels. :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
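A short sketch of how the helpers added above fit together, assuming they are in scope; the channel counts and shapes are arbitrary.

# Illustrative composition of conv_nd / normalization / zero_module as the
# attention block uses them; sizes are arbitrary.
import torch

channels = 64
norm = normalization(channels, swish=0.0)                   # GroupNorm32 with 32 groups, no swish
qkv_proj = conv_nd(1, channels, channels * 3, 1)            # 1x1 Conv1d producing fused q, k, v
proj_out = zero_module(conv_nd(1, channels, channels, 1))   # residual branch starts at zero

x = torch.randn(2, channels, 256)                           # [N, C, T] with T = H * W flattened
qkv = qkv_proj(norm(x))                                     # [N, 3 * C, T]
assert qkv.shape == (2, channels * 3, 256)
assert proj_out(x).abs().sum() == 0                         # zeroed conv: the block is an identity at init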
# unet_score_estimation.py
# class AttnBlockpp(nn.Module):
#     """Channel-wise self-attention block. Modified from DDPM."""
#
#     def __init__(self, channels, skip_rescale=False, init_scale=0.0):
#         super().__init__()
#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
#         self.NIN_0 = NIN(channels, channels)
#         self.NIN_1 = NIN(channels, channels)
#         self.NIN_2 = NIN(channels, channels)
#         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
#         self.skip_rescale = skip_rescale
#
#     def forward(self, x):
#         B, C, H, W = x.shape
#         h = self.GroupNorm_0(x)
#         q = self.NIN_0(h)
#         k = self.NIN_1(h)
#         v = self.NIN_2(h)
#
#         w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
#         w = torch.reshape(w, (B, H, W, H * W))
#         w = F.softmax(w, dim=-1)
#         w = torch.reshape(w, (B, H, W, H, W))
#         h = torch.einsum("bhwij,bcij->bchw", w, v)
#         h = self.NIN_3(h)
#         if not self.skip_rescale:
#             return x + h
#         else:
#             return (x + h) / np.sqrt(2.0)
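The commented-out block keeps H and W explicit in its einsums; the unified block above does the same computation over flattened positions. A small equivalence sketch, under the assumptions of a single head and identity projections in place of the NIN layers.

# Equivalence sketch (assumptions: one head, no NIN projections): the
# "bchw,bcij->bhwij" formulation above matches scaled dot-product attention
# over the H * W flattened positions used by the unified QKVAttention.
import torch

B, C, H, W = 2, 16, 8, 8
q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

# spatially explicit form, as in the commented-out AttnBlockpp
w = torch.einsum("bchw,bcij->bhwij", q, k) * C ** -0.5
w = torch.softmax(w.reshape(B, H, W, H * W), dim=-1).reshape(B, H, W, H, W)
h_spatial = torch.einsum("bhwij,bcij->bchw", w, v)

# flattened form, as in the unified attention block
qf, kf, vf = (t.reshape(B, C, H * W) for t in (q, k, v))
attn = torch.softmax(torch.einsum("bct,bcs->bts", qf, kf) * C ** -0.5, dim=-1)
h_flat = torch.einsum("bts,bcs->bct", attn, vf).reshape(B, C, H, W)

assert torch.allclose(h_spatial, h_flat, atol=1e-5)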
unet_glide.py

import math
from abc import abstractmethod

import torch
...@@ -7,6 +6,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
...@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels, swish=0.0)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, encoder_out=None):
        b, c, *spatial = x.shape
        qkv = self.qkv(self.norm(x).view(b, c, -1))
        if encoder_out is not None:
            encoder_out = self.encoder_kv(encoder_out)
            h = self.attention(qkv, encoder_out)
        else:
            h = self.attention(qkv)
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None):
        """
        Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
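Two details in the block above (now living in attention2d) are easy to miss: encoder states enter as extra key/value tokens prepended along the sequence axis, and q and k are each pre-scaled by 1 / ch**0.25 so their product carries the usual 1 / sqrt(ch) factor without a large fp16 intermediate. A shape walk-through with arbitrary sizes:

# Shape walk-through for QKVAttention with encoder_kv; all sizes are illustrative.
import torch

N, H, C, T, S = 2, 4, 16, 64, 77           # batch, heads, head dim, image tokens, text tokens
qkv = torch.randn(N, H * 3 * C, T)          # fused q, k, v from the 1x1 conv
encoder_kv = torch.randn(N, H * 2 * C, S)   # extra keys / values from the text encoder

q, k, v = qkv.reshape(N * H, 3 * C, T).split(C, dim=1)
ek, ev = encoder_kv.reshape(N * H, 2 * C, S).split(C, dim=1)
k = torch.cat([ek, k], dim=-1)              # [N*H, C, S + T]
v = torch.cat([ev, v], dim=-1)

scale = 1 / C ** 0.25                       # same as 1 / math.sqrt(math.sqrt(ch))
weight = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
out = torch.einsum("bts,bcs->bct", weight, v).reshape(N, H * C, T)
assert out.shape == (N, H * C, T)           # heads folded back into the channel axis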
class GlideUNetModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.
...
unet_ldm.py

...@@ -9,6 +9,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
...@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
        k = self.to_k(context)
        v = self.to_v(context)

        # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)
...@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # mask = rearrange(mask, "b ... -> b (...)")
            mask = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            # mask = repeat(mask, "b j -> (b h) () j", h=h)
            mask = mask[:, None, :].repeat(h, 1, 1)
            # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
...@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
        return self.to_out(out)
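The removed rearrange comments describe the reshapes that reshape_heads_to_batch_dim / reshape_batch_dim_to_heads are expected to perform; their bodies are outside this hunk, so the sketch below is an assumption about their behavior. One observation: mask[:, None, :].repeat(h, 1, 1) tiles copies along the batch axis in a different order than einops' "(b h)" grouping, which only matters when the mask differs across a batch.

# Round trip behind the removed rearrange comments; h and d are arbitrary, and
# the real reshape_heads_to_batch_dim implementation is not part of this diff.
import torch

b, n, h, d = 2, 64, 8, 40
t = torch.randn(b, n, h * d)

# "b n (h d) -> (b h) n d"
to_batch = t.reshape(b, n, h, d).permute(0, 2, 1, 3).reshape(b * h, n, d)

# "(b h) n d -> b n (h d)"
back = to_batch.reshape(b, h, n, d).permute(0, 2, 1, 3).reshape(b, n, h * d)
assert torch.equal(back, t)

# For comparison, the mask path above tiles copies along the batch axis:
mask = torch.rand(b, n) > 0.5
tiled = mask[:, None, :].repeat(h, 1, 1)   # [b * h, 1, n]; repeat index is outermost, not "(b h)"
assert tiled.shape == (b * h, 1, n)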
...@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
...@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
    model.total_ops += torch.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
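For reference, the difference between the two variants: QKVAttentionLegacy (removed here) expects the fused projection as [N, H*3*C, T] and folds the heads into the batch before splitting q/k/v, while the QKVAttention kept above expects [N, 3*H*C, T] and chunks q/k/v first. A minimal check that the two layouts are just permutations of each other; sizes are arbitrary.

# Layout comparison between QKVAttentionLegacy ([N, H*3*C, T]) and
# QKVAttention ([N, 3*H*C, T]); all sizes are illustrative.
import torch

N, H, C, T = 2, 4, 8, 10
qkv_legacy = torch.randn(N, H * 3 * C, T)

# legacy: fold heads into the batch first, then split into q, k, v
q1, k1, v1 = qkv_legacy.reshape(N * H, 3 * C, T).split(C, dim=1)

# non-legacy layout: q, k, v are contiguous blocks of H * C channels each
qkv_new = qkv_legacy.reshape(N, H, 3, C, T).permute(0, 2, 1, 3, 4).reshape(N, 3 * H * C, T)
q2, k2, v2 = (t.reshape(N * H, C, T) for t in qkv_new.chunk(3, dim=1))

assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)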
class UNetLDMModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
...