"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "a4bc8a86f5fc67ffff77bf1239a829c0af8d2bf2"
Commit 52b3ff5e authored by Patrick von Platen

unify ldm and glide attention

parent fff981df
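The change below folds the GLIDE- and LDM-style attention into one shared `AttentionBlock` that the UNets import from `attention2d`. As a rough smoke test of that shared block (a sketch only: the import path is inferred from the `from .attention2d import AttentionBlock` hunks further down, and the forward signature from the GLIDE-style block shown in this diff):

```python
import torch

# Hypothetical import path, inferred from the `from .attention2d import AttentionBlock`
# hunks in unet_glide.py / unet_ldm.py below; adjust to the actual package layout.
from diffusers.models.attention2d import AttentionBlock

block = AttentionBlock(channels=64, num_head_channels=32, encoder_channels=16)
x = torch.randn(2, 64, 8, 8)             # [B, C, H, W] feature map
text_states = torch.randn(2, 16, 77)     # [B, encoder_channels, T] conditioning sequence

out = block(x, encoder_out=text_states)  # assumes the GLIDE-style forward(x, encoder_out=None)
assert out.shape == x.shape              # residual attention keeps the input shape
```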
import math
import torch
import torch.nn.functional as F
from torch import nn
def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
# unet_grad_tts.py
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
@@ -24,6 +35,7 @@ class LinearAttention(torch.nn.Module):
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)
# unet.py
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
@@ -62,7 +74,8 @@ class AttnBlock(nn.Module):
        return x + h_
# unet_glide.py
# unet_glide.py & unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
@@ -78,6 +91,7 @@ class AttentionBlock(nn.Module):
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
    ):
        super().__init__()
        self.channels = channels
@@ -108,6 +122,7 @@ class AttentionBlock(nn.Module):
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@@ -140,106 +155,78 @@ class QKVAttention(nn.Module):
        return a.reshape(bs, -1, length)
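The shape convention in the docstring can be traced with plain tensors; the sketch below mirrors the "q/k/v first, then heads" packing used by this class (illustration only, not repo code):

```python
import math

import torch

# qkv packs q, k, v along the channel axis: [N, 3 * H * C, T]
N, H, C, T = 2, 4, 32, 64
qkv = torch.randn(N, 3 * H * C, T)

q, k, v = qkv.chunk(3, dim=1)            # each [N, H * C, T]
q = q.reshape(N * H, C, T)               # fold heads into the batch dimension
k = k.reshape(N * H, C, T)
v = v.reshape(N * H, C, T)

scale = 1 / math.sqrt(math.sqrt(C))      # split 1/sqrt(C) between q and k (more stable in fp16)
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v).reshape(N, H * C, T)

assert out.shape == (N, H * C, T)        # [N, H * C, T], as the docstring states
```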
# unet_ldm.py
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


def normalization(channels, swish=0.0):
    """
    Make a standard normalization layer, with an optional swish activation.
    :param channels: number of input channels. :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
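A quick sanity check of the helper functions added above (a sketch only; the import path is the same hypothetical one inferred from the hunks below):

```python
import torch

# Hypothetical import path; alternatively run this with the definitions above in scope.
from diffusers.models.attention2d import conv_nd, normalization, zero_module

norm = normalization(64, swish=1.0)    # GroupNorm32 followed by SiLU
h = torch.randn(2, 64, 16, 16)
assert norm(h).shape == h.shape        # normalization is shape-preserving

proj = zero_module(conv_nd(1, 64, 64, 1))               # zero-initialized 1x1 Conv1d projection
assert all((p == 0).all() for p in proj.parameters())   # outputs start at zero (residual-friendly)
```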
# unet_score_estimation.py
# class AttnBlockpp(nn.Module):
#     """Channel-wise self-attention block. Modified from DDPM."""
#
#     def __init__(self, channels, skip_rescale=False, init_scale=0.0):
#         super().__init__()
#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
#         self.NIN_0 = NIN(channels, channels)
#         self.NIN_1 = NIN(channels, channels)
#         self.NIN_2 = NIN(channels, channels)
#         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
#         self.skip_rescale = skip_rescale
#
#     def forward(self, x):
#         B, C, H, W = x.shape
#         h = self.GroupNorm_0(x)
#         q = self.NIN_0(h)
#         k = self.NIN_1(h)
#         v = self.NIN_2(h)
#
#         w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
#         w = torch.reshape(w, (B, H, W, H * W))
#         w = F.softmax(w, dim=-1)
#         w = torch.reshape(w, (B, H, W, H, W))
#         h = torch.einsum("bhwij,bcij->bchw", w, v)
#         h = self.NIN_3(h)
#         if not self.skip_rescale:
#             return x + h
#         else:
#             return (x + h) / np.sqrt(2.0)
import math
from abc import abstractmethod

import torch
@@ -7,6 +6,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -226,84 +226,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        encoder_channels=None,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels, swish=0.0)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention(self.num_heads)

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x, encoder_out=None):
        b, c, *spatial = x.shape
        qkv = self.qkv(self.norm(x).view(b, c, -1))
        if encoder_out is not None:
            encoder_out = self.encoder_kv(encoder_out)
            h = self.attention(qkv, encoder_out)
        else:
            h = self.attention(qkv)
        h = self.proj_out(h)
        return x + h.reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, encoder_kv=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
        attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        if encoder_kv is not None:
            assert encoder_kv.shape[1] == self.n_heads * ch * 2
            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
            k = torch.cat([ek, k], dim=-1)
            v = torch.cat([ev, v], dim=-1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
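For reference, the `encoder_kv` path above conditions image attention on text by concatenating projected encoder keys/values along the token axis; a plain-tensor sketch (not repo code):

```python
import torch

N, H, C = 2, 2, 32          # batch, heads, channels per head
T_img, T_txt = 64, 77       # image tokens (H*W) and text tokens

k = torch.randn(N * H, C, T_img)
v = torch.randn(N * H, C, T_img)
ek = torch.randn(N * H, C, T_txt)   # from the encoder_kv 1x1 conv, split into ek, ev
ev = torch.randn(N * H, C, T_txt)

k = torch.cat([ek, k], dim=-1)      # [N*H, C, T_txt + T_img]
v = torch.cat([ev, v], dim=-1)
# every spatial query now also attends over the text tokens
assert k.shape == (N * H, C, T_txt + T_img)
```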
class GlideUNetModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.
...
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention2d import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
@@ -172,8 +173,6 @@ class CrossAttention(nn.Module):
        k = self.to_k(context)
        v = self.to_v(context)

        # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)
@@ -181,12 +180,9 @@ class CrossAttention(nn.Module):
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            # mask = rearrange(mask, "b ... -> b (...)")
            maks = mask.reshape(batch_size, -1)
            mask = mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            # mask = repeat(mask, "b j -> (b h) () j", h=h)
            mask = mask[:, None, :].repeat(h, 1, 1)
            # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
@@ -194,7 +190,6 @@ class CrossAttention(nn.Module):
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = self.reshape_batch_dim_to_heads(out)
        # out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
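The einops-free masking kept above can be exercised on its own; a small sketch with made-up shapes (`heads`, `n_img`, `n_txt` are illustrative names, not repo identifiers):

```python
import torch

batch_size, heads, n_img, n_txt = 2, 4, 16, 7
sim = torch.randn(batch_size * heads, n_img, n_txt)     # attention logits
mask = torch.rand(batch_size, n_txt) > 0.5              # True = keep token

mask = mask.reshape(batch_size, -1)                     # was: rearrange(mask, "b ... -> b (...)")
mask = mask[:, None, :].repeat(heads, 1, 1)             # was: repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)    # masked positions get ~-inf before softmax
```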
@@ -487,47 +482,6 @@ class ResBlock(TimestepBlock):
        return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
@@ -577,35 +531,6 @@ def count_flops_attn(model, _x, y):
    model.total_ops += torch.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
        T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
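The "splits in a different order" note is the whole difference between `QKVAttentionLegacy` and `QKVAttention`; a plain-tensor sketch of the two packings (illustration only, not repo code):

```python
import torch

N, H, C, T = 1, 2, 4, 3
qkv = torch.randn(N, H * 3 * C, T)   # legacy packing: heads first, then q/k/v

# QKVAttentionLegacy: fold heads into the batch dim, then split q, k, v
q1, k1, v1 = qkv.reshape(N * H, 3 * C, T).split(C, dim=1)

# QKVAttention ("new order"): split q, k, v first, then fold heads
q2, k2, v2 = (t.reshape(N * H, C, T) for t in qkv.chunk(3, dim=1))

# Same buffer, different grouping: weights trained with one ordering are not
# directly interchangeable with the other.
print(torch.allclose(q1, q2))  # generally False
```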
class UNetLDMModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param
...