Commit 635da723 authored by Patrick von Platen

one attention module only

parent 79db3eb6
import math
import torch
import torch.nn.functional as F
from torch import nn
# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
...@@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q, k, v = (
qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
.permute(1, 0, 2, 3, 4, 5)
...@@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module):
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
return self.to_out(out)
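A standalone sketch of what the two einsums above compute (illustrative shapes, not part of the commit): keys are softmaxed over the spatial positions, a small per-head context matrix is aggregated, and the queries are read out against it, so memory stays linear in h * w.

import torch

# Illustrative shapes for the linear-attention einsums above (a sketch, not module code).
b, heads, d, h, w = 2, 4, 32, 8, 8
q = torch.randn(b, heads, d, h * w)
k = torch.randn(b, heads, d, h * w).softmax(dim=-1)   # keys normalized over positions
v = torch.randn(b, heads, d, h * w)
context = torch.einsum("bhdn,bhen->bhde", k, v)       # (b, heads, d, d)
out = torch.einsum("bhde,bhdn->bhen", context, q)     # (b, heads, d, h*w)
out = out.reshape(b, heads * d, h, w)                 # (2, 128, 8, 8)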
# unet_glide.py & unet_ldm.py # the main attention block that is used for all models
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
""" """
An attention block that allows spatial positions to attend to each other. An attention block that allows spatial positions to attend to each other.
...@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module): ...@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module):
channels,
num_heads=1,
num_head_channels=-1,
num_groups=32,
use_checkpoint=False,
encoder_channels=None,
use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
):
super().__init__()
self.channels = channels
...@@ -62,23 +63,34 @@ class AttentionBlock(nn.Module):
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)  # was: normalization(channels, swish=0.0)
self.qkv = nn.Conv1d(channels, channels * 3, 1)  # was: conv_nd(1, channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)  # was: conv_nd(1, encoder_channels, channels * 2, 1)
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))  # was: zero_module(conv_nd(1, channels, channels, 1))
self.overwrite_qkv = overwrite_qkv
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.overwrite_linear = overwrite_linear
if self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.is_overwritten = False
def set_weights(self, module):
...@@ -89,11 +101,17 @@ class AttentionBlock(nn.Module):
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))  # was: zero_module(conv_nd(1, self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj_out = proj_out
elif self.overwrite_linear:
self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
self.proj_out.bias.data = self.NIN_3.b.data
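The overwrite_linear branch of set_weights relies on a channel-wise dense layer (the NIN class defined at the bottom of this file) being the same map as a kernel-size-1 Conv1d with a transposed weight. A minimal numerical check of that equivalence (a sketch; it assumes W has shape (in_dim, num_units) as in the NIN class below):

import torch
from torch import nn

# Sketch: a dense layer over channels with weight W of shape (C_in, C_out)
# equals a 1x1 Conv1d whose weight is W.T[:, :, None], as used in set_weights.
C_in, C_out, length = 16, 48, 10
W = torch.randn(C_in, C_out)
bias = torch.randn(C_out)
x = torch.randn(2, C_in, length)
dense = torch.einsum("bcl,ck->bkl", x, W) + bias[None, :, None]
conv = nn.Conv1d(C_in, C_out, 1)
conv.weight.data = W.T[:, :, None]
conv.bias.data = bias
assert torch.allclose(dense, conv(x), atol=1e-5)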
def forward(self, x, encoder_out=None):
if self.overwrite_qkv and not self.is_overwritten:
...@@ -124,69 +142,74 @@ class AttentionBlock(nn.Module):
h = a.reshape(bs, -1, length)
h = self.proj_out(h)
h = h.reshape(b, c, *spatial)
result = x + h  # was: return x + h.reshape(b, c, *spatial)
result = result / self.rescale_output_factor
return result
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
y = y * F.sigmoid(y * float(self.swish))
return y
def normalization(channels, swish=0.0, eps=1e-5):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels. :return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
# unet_score_estimation.py
# class AttnBlockpp(nn.Module):
# """Channel-wise self-attention block. Modified from DDPM."""
#
# def __init__(
# self,
# channels,
# skip_rescale=False,
# init_scale=0.0,
# num_heads=1,
# num_head_channels=-1,
# use_checkpoint=False,
# encoder_channels=None,
# use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
# overwrite_qkv=False,
# overwrite_from_grad_tts=False,
# ):
# super().__init__()
# num_groups = min(channels // 4, 32)
# self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
# self.NIN_0 = NIN(channels, channels)
# self.NIN_1 = NIN(channels, channels)
# self.NIN_2 = NIN(channels, channels)
# self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
# self.skip_rescale = skip_rescale
#
# self.channels = channels
# if num_head_channels == -1:
# self.num_heads = num_heads
# else:
# assert (
# channels % num_head_channels == 0
# ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
# self.num_heads = channels // num_head_channels
#
# self.use_checkpoint = use_checkpoint
# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
# self.qkv = conv_nd(1, channels, channels * 3, 1)
# self.n_heads = self.num_heads
#
# if encoder_channels is not None:
# self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
#
# self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
#
# self.is_weight_set = False
#
# def set_weights(self):
# self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
# self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
#
# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
# self.proj_out.bias.data = self.NIN_3.b.data
#
# def forward(self, x):
# if not self.is_weight_set:
# self.set_weights()
# self.is_weight_set = True
#
# B, C, H, W = x.shape
# h = self.GroupNorm_0(x)
# q = self.NIN_0(h)
...@@ -199,7 +222,58 @@ def zero_module(module):
# w = torch.reshape(w, (B, H, W, H, W))
# h = torch.einsum("bhwij,bcij->bchw", w, v)
# h = self.NIN_3(h)
#
# if not self.skip_rescale:
# result = x + h
# else:
# result = (x + h) / np.sqrt(2.0)
#
# result = self.forward_2(x)
#
# return result
#
# def forward_2(self, x, encoder_out=None):
# b, c, *spatial = x.shape
# hid_states = self.norm(x).view(b, c, -1)
#
# qkv = self.qkv(hid_states)
# bs, width, length = qkv.shape
# assert width % (3 * self.n_heads) == 0
# ch = width // (3 * self.n_heads)
# q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
#
# if encoder_out is not None:
# encoder_kv = self.encoder_kv(encoder_out)
# assert encoder_kv.shape[1] == self.n_heads * ch * 2
# ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
# k = torch.cat([ek, k], dim=-1)
# v = torch.cat([ev, v], dim=-1)
#
# scale = 1 / math.sqrt(math.sqrt(ch))
# weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
# weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
#
# a = torch.einsum("bts,bcs->bct", weight, v)
# h = a.reshape(bs, -1, length)
#
# h = self.proj_out(h)
# h = h.reshape(b, c, *spatial)
#
# return (x + h) / np.sqrt(2.0)
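The scale = 1 / math.sqrt(math.sqrt(ch)) trick in forward_2 above applies ch ** -0.25 to q and k separately instead of dividing the logits by sqrt(ch); the result is identical, but the intermediates stay smaller, which is what the "more stable with f16" comment refers to. A quick check (illustrative only):

import math
import torch

ch, n = 64, 10
q, k = torch.randn(1, ch, n), torch.randn(1, ch, n)
scale = 1 / math.sqrt(math.sqrt(ch))
w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # scale applied to q and k
w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)  # scale applied to the logits
assert torch.allclose(w1, w2, atol=1e-5)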
# TODO(Patrick) - this can and should be removed
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
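zero_module is typically used to zero-initialise the output projection so that a freshly initialised attention block contributes nothing to the residual path at first. A usage sketch with the helper defined above:

from torch import nn

proj = zero_module(nn.Conv1d(64, 64, 1))   # all weights and biases start at 0
assert proj.weight.abs().sum().item() == 0.0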
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
class NIN(nn.Module):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
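NIN's forward is not shown in this hunk; in the score_sde reference implementation it is a dense layer applied over the channel axis, roughly as sketched below (an assumption, included only so the weight transposition in set_weights above is easier to follow).

import torch

# Assumed behaviour of NIN.forward (not shown in this diff), following the
# score_sde reference: a dense layer contracted over the channel axis.
def nin_forward(x, W, b):
    # x: (B, C, H, W); W: (in_dim, num_units); b: (num_units,)
    return torch.einsum("bchw,ck->bkhw", x, W) + b[None, :, None, None]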
...@@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module):
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
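A usage sketch for the Fourier projection above (the constructor arguments are assumptions; only the forward is shown in this hunk): each continuous timestep is multiplied by fixed random frequencies W and mapped to sin/cos features, giving an embedding of size 2 * len(W).

import torch

emb_layer = GaussianFourierProjection(embedding_size=128, scale=16.0)  # hypothetical args
t = torch.rand(8)          # continuous timesteps
emb = emb_layer(t)         # (8, 256): [sin(2*pi*t*W), cos(2*pi*t*W)]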
# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
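A quick usage sketch for SinusoidalPosEmb (the class is fully shown above): the first half of the output holds sine features, the second half cosine features, over geometrically spaced frequencies.

import torch

pos_emb = SinusoidalPosEmb(dim=128)
t = torch.arange(16).float()   # e.g. integer diffusion timesteps
emb = pos_emb(t)               # shape (16, 128)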
...@@ -5,6 +5,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .attention2d import LinearAttention
class Mish(torch.nn.Module):
...@@ -54,7 +55,7 @@ class ResnetBlock(torch.nn.Module):
return output
class old_LinearAttention(torch.nn.Module):  # was: class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
...
...@@ -22,10 +22,12 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .attention2d import AttentionBlock
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
...@@ -414,37 +416,6 @@ class Combine(nn.Module):
raise ValueError(f"Method {self.method} not recognized.")
class AttnBlockpp(nn.Module):
"""Channel-wise self-attention block. Modified from DDPM."""
def __init__(self, channels, skip_rescale=False, init_scale=0.0):
super().__init__()
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
self.skip_rescale = skip_rescale
def forward(self, x):
B, C, H, W = x.shape
h = self.GroupNorm_0(x)
q = self.NIN_0(h)
k = self.NIN_1(h)
v = self.NIN_2(h)
w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
w = torch.reshape(w, (B, H, W, H * W))
w = F.softmax(w, dim=-1)
w = torch.reshape(w, (B, H, W, H, W))
h = torch.einsum("bhwij,bcij->bchw", w, v)
h = self.NIN_3(h)
if not self.skip_rescale:
return x + h
else:
return (x + h) / np.sqrt(2.0)
class Upsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
...@@ -756,7 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules[-1].weight.data = default_init()(modules[-1].weight.shape)
nn.init.zeros_(modules[-1].bias)
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))  # was: functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
...
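A sketch of how the new partial above is meant to be used in NCSNpp (the channel count is illustrative): overwrite_linear=True makes AttentionBlock expose the NIN_*/GroupNorm parameters the original checkpoints expect, and rescale_output_factor=sqrt(2) matches the old (x + h) / np.sqrt(2.0) path of AttnBlockpp when skip_rescale was enabled.

import functools
import math

# Illustrative: instantiating the shared attention block via the partial.
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
attn = AttnBlock(channels=256)  # same as AttentionBlock(256, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))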