"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f912f39b50f87e50a9d99346f5c1b6e644653262"
Commit 635da723 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

one attention module only

parent 79db3eb6
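Note (reviewer sketch, not part of the diff): this commit folds the per-model attention layers into a single AttentionBlock living in attention2d. A minimal sketch of how the unified block is meant to be constructed for the two call sites touched below; the absolute import path is assumed from the relative imports added in this diff, and channel counts are arbitrary:

import math
from diffusers.models.attention2d import AttentionBlock  # module path assumed

# unet_glide.py / unet_ldm.py style: plain multi-head self-attention
attn = AttentionBlock(channels=256, num_heads=4)

# unet_score_estimation.py style: emulate NCSN++'s AttnBlockpp via overwrite_linear,
# plus the sqrt(2) skip rescaling it applied when skip_rescale=True
attn_pp = AttentionBlock(channels=256, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))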
import math
import torch
import torch.nn.functional as F
from torch import nn


# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
@@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
-       # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        q, k, v = (
            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
            .permute(1, 0, 2, 3, 4, 5)
@@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module):
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
-       # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)
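Note (standalone sketch): a quick shape check of the einsum-based linear attention above, using dummy per-head q/k/v tensors shaped like the output of the reshape in forward:

import torch

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
q = torch.randn(b, heads, dim_head, h * w)
k = torch.randn(b, heads, dim_head, h * w)
v = torch.randn(b, heads, dim_head, h * w)

k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)    # (b, heads, dim_head, dim_head)
out = torch.einsum("bhde,bhdn->bhen", context, q)  # (b, heads, dim_head, h * w)
out = out.reshape(b, heads, dim_head, h, w).reshape(b, heads * dim_head, h, w)
assert out.shape == (b, heads * dim_head, h, w)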
-# unet_glide.py & unet_ldm.py
+# the main attention block that is used for all models
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module):
        channels,
        num_heads=1,
        num_head_channels=-1,
+       num_groups=32,
        use_checkpoint=False,
        encoder_channels=None,
        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
        overwrite_qkv=False,
+       overwrite_linear=False,
+       rescale_output_factor=1.0,
    ):
        super().__init__()
        self.channels = channels
@@ -62,23 +63,34 @@ class AttentionBlock(nn.Module):
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
-       self.norm = normalization(channels, swish=0.0)
-       self.qkv = conv_nd(1, channels, channels * 3, 1)
+       self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+       self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.n_heads = self.num_heads
+       self.rescale_output_factor = rescale_output_factor

        if encoder_channels is not None:
-           self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+           self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

-       self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+       self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

        self.overwrite_qkv = overwrite_qkv

        if overwrite_qkv:
            in_channels = channels
+           self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

+       self.overwrite_linear = overwrite_linear
+       if self.overwrite_linear:
+           num_groups = min(channels // 4, 32)
+           self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+           self.NIN_0 = NIN(channels, channels)
+           self.NIN_1 = NIN(channels, channels)
+           self.NIN_2 = NIN(channels, channels)
+           self.NIN_3 = NIN(channels, channels)

        self.is_overwritten = False
    def set_weights(self, module):
@@ -89,11 +101,17 @@ class AttentionBlock(nn.Module):
            self.qkv.weight.data = qkv_weight
            self.qkv.bias.data = qkv_bias

-           proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
+           proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
            proj_out.bias.data = module.proj_out.bias.data

            self.proj_out = proj_out
+       elif self.overwrite_linear:
+           self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+           self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+
+           self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+           self.proj_out.bias.data = self.NIN_3.b.data
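Note (standalone sketch, not part of the diff): the .T and [:, :, None] in set_weights rely on NIN being a per-channel linear layer, i.e. equivalent to a 1x1 Conv1d whose weight is W transposed with a trailing kernel axis. Assuming NIN computes out[b, d, n] = sum_c x[b, c, n] * W[c, d] + b[d], as in the original score-sde code, the mapping can be checked like this:

import torch

c = 16
W = torch.randn(c, c)   # NIN weight
bias = torch.randn(c)   # NIN bias

x = torch.randn(2, c, 8 * 8)  # (batch, channels, flattened spatial)

conv = torch.nn.Conv1d(c, c, 1)
conv.weight.data = W.T[:, :, None]  # (out_channels, in_channels, 1)
conv.bias.data = bias

nin_out = torch.einsum("bcn,cd->bdn", x, W) + bias[None, :, None]
assert torch.allclose(conv(x), nin_out, atol=1e-5)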
    def forward(self, x, encoder_out=None):
        if self.overwrite_qkv and not self.is_overwritten:
@@ -124,69 +142,74 @@ class AttentionBlock(nn.Module):
        h = a.reshape(bs, -1, length)

        h = self.proj_out(h)
+       h = h.reshape(b, c, *spatial)

-       return x + h.reshape(b, c, *spatial)
+       result = x + h
+
+       result = result / self.rescale_output_factor
+
+       return result


-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
-        self.swish = swish
-
-    def forward(self, x):
-        y = super().forward(x.float()).to(x.dtype)
-        if self.swish == 1.0:
-            y = F.silu(y)
-        elif self.swish:
-            y = y * F.sigmoid(y * float(self.swish))
-        return y
-
-
-def normalization(channels, swish=0.0, eps=1e-5):
-    """
-    Make a standard normalization layer, with an optional swish activation.
-    :param channels: number of input channels. :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
-
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
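Note (standalone sketch): rescale_output_factor is what absorbs AttnBlockpp's skip_rescale branch, i.e. (x + h) / sqrt(2) becomes a plain division at the end of forward, and the default factor of 1.0 is a no-op:

import math
import torch

x = torch.randn(2, 16, 8, 8)
h = torch.randn_like(x)

rescale_output_factor = math.sqrt(2.0)  # what NCSNpp passes further down in this diff
result = (x + h) / rescale_output_factor

assert torch.allclose(result, (x + h) / math.sqrt(2.0))  # matches the old skip_rescale=True behaviour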
# unet_score_estimation.py
-# class AttnBlockpp(nn.Module):
+#class AttnBlockpp(nn.Module):
#    """Channel-wise self-attention block. Modified from DDPM."""
#
-#    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
+#    def __init__(
+#        self,
+#        channels,
+#        skip_rescale=False,
+#        init_scale=0.0,
+#        num_heads=1,
+#        num_head_channels=-1,
+#        use_checkpoint=False,
+#        encoder_channels=None,
+#        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+#        overwrite_qkv=False,
+#        overwrite_from_grad_tts=False,
+#    ):
#        super().__init__()
-#        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
+#        num_groups = min(channels // 4, 32)
+#        self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
#        self.NIN_0 = NIN(channels, channels)
#        self.NIN_1 = NIN(channels, channels)
#        self.NIN_2 = NIN(channels, channels)
#        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
#        self.skip_rescale = skip_rescale
#
+#        self.channels = channels
+#        if num_head_channels == -1:
+#            self.num_heads = num_heads
+#        else:
+#            assert (
+#                channels % num_head_channels == 0
+#            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+#            self.num_heads = channels // num_head_channels
+#
+#        self.use_checkpoint = use_checkpoint
+#        self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
+#        self.qkv = conv_nd(1, channels, channels * 3, 1)
+#        self.n_heads = self.num_heads
+#
+#        if encoder_channels is not None:
+#            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+#
+#        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+#
+#        self.is_weight_set = False
+#
+#    def set_weights(self):
+#        self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+#        self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+#
+#        self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+#        self.proj_out.bias.data = self.NIN_3.b.data
+#
#    def forward(self, x):
+#        if not self.is_weight_set:
+#            self.set_weights()
+#            self.is_weight_set = True
+#
#        B, C, H, W = x.shape
#        h = self.GroupNorm_0(x)
#        q = self.NIN_0(h)
@@ -199,7 +222,58 @@ def zero_module(module):
#        w = torch.reshape(w, (B, H, W, H, W))
#        h = torch.einsum("bhwij,bcij->bchw", w, v)
#        h = self.NIN_3(h)
+#
#        if not self.skip_rescale:
-#            return x + h
+#            result = x + h
#        else:
-#            return (x + h) / np.sqrt(2.0)
+#            result = (x + h) / np.sqrt(2.0)
+#
+#        result = self.forward_2(x)
+#
+#        return result
+#
+#    def forward_2(self, x, encoder_out=None):
+#        b, c, *spatial = x.shape
+#        hid_states = self.norm(x).view(b, c, -1)
+#
+#        qkv = self.qkv(hid_states)
+#        bs, width, length = qkv.shape
+#        assert width % (3 * self.n_heads) == 0
+#        ch = width // (3 * self.n_heads)
+#        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+#
+#        if encoder_out is not None:
+#            encoder_kv = self.encoder_kv(encoder_out)
+#            assert encoder_kv.shape[1] == self.n_heads * ch * 2
+#            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+#            k = torch.cat([ek, k], dim=-1)
+#            v = torch.cat([ev, v], dim=-1)
+#
+#        scale = 1 / math.sqrt(math.sqrt(ch))
+#        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+#        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+#
+#        a = torch.einsum("bts,bcs->bct", weight, v)
+#        h = a.reshape(bs, -1, length)
+#
+#        h = self.proj_out(h)
+#        h = h.reshape(b, c, *spatial)
+#
+#        return (x + h) / np.sqrt(2.0)
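Note (standalone sketch): a shape trace of the qkv split and scaled dot-product used in forward_2 above, with arbitrary sizes:

import math
import torch

bs, n_heads, ch, length = 2, 4, 32, 64
qkv = torch.randn(bs, 3 * n_heads * ch, length)

q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)

scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # (bs * n_heads, length, length)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

a = torch.einsum("bts,bcs->bct", weight, v)  # (bs * n_heads, ch, length)
h = a.reshape(bs, -1, length)
assert h.shape == (bs, n_heads * ch, length)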
+# TODO(Patrick) - this can and should be removed
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
+class NIN(nn.Module):
+    def __init__(self, in_dim, num_units, init_scale=0.1):
+        super().__init__()
+        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
+        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
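Note (standalone sketch): because proj_out is wrapped in zero_module, a freshly constructed AttentionBlock is an identity mapping up to rescale_output_factor, which defaults to 1.0. Assuming the default (non-overwrite) forward path behaves like the GLIDE/LDM block it replaces and that the module lives at diffusers.models.attention2d:

import torch
from diffusers.models.attention2d import AttentionBlock  # import path assumed

attn = AttentionBlock(channels=32, num_heads=4)
x = torch.randn(1, 32, 8, 8)
with torch.no_grad():
    out = attn(x)
assert torch.allclose(out, x)  # proj_out starts at zero, so the attention branch contributes nothing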
@@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module):
    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


-# unet_rl.py - TODO(need test)
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
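Note (standalone sketch): the GaussianFourierProjection.forward kept above maps a batch of scalar timesteps to 2 * embedding_size features; the same computation with a frozen random W, as in the score-sde implementation:

import numpy as np
import torch

embedding_size, scale = 16, 1.0
W = torch.randn(embedding_size) * scale  # frozen Gaussian frequencies

t = torch.rand(4)  # batch of timesteps
x_proj = t[:, None] * W[None, :] * 2 * np.pi
emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
assert emb.shape == (4, 2 * embedding_size)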
@@ -5,6 +5,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
+from .attention2d import LinearAttention


class Mish(torch.nn.Module):
@@ -54,7 +55,7 @@ class ResnetBlock(torch.nn.Module):
        return output


-class LinearAttention(torch.nn.Module):
+class old_LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
...
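Note (standalone sketch): with the import added above, unet_grad_tts.py keeps calling LinearAttention with the same constructor arguments. Assuming to_out projects back to dim channels as in the grad-tts implementation, and the module path diffusers.models.attention2d:

import torch
from diffusers.models.attention2d import LinearAttention  # import path assumed

attn = LinearAttention(dim=64, heads=4, dim_head=32)
x = torch.randn(1, 64, 16, 16)
out = attn(x)
assert out.shape == x.shape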
@@ -22,10 +22,12 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
+import math

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, get_timestep_embedding
+from .attention2d import AttentionBlock


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -414,37 +416,6 @@ class Combine(nn.Module):
        raise ValueError(f"Method {self.method} not recognized.")


-class AttnBlockpp(nn.Module):
-    """Channel-wise self-attention block. Modified from DDPM."""
-
-    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
-        super().__init__()
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
-        self.NIN_0 = NIN(channels, channels)
-        self.NIN_1 = NIN(channels, channels)
-        self.NIN_2 = NIN(channels, channels)
-        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
-        self.skip_rescale = skip_rescale
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        h = self.GroupNorm_0(x)
-        q = self.NIN_0(h)
-        k = self.NIN_1(h)
-        v = self.NIN_2(h)
-        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
-        w = torch.reshape(w, (B, H, W, H * W))
-        w = F.softmax(w, dim=-1)
-        w = torch.reshape(w, (B, H, W, H, W))
-        h = torch.einsum("bhwij,bcij->bchw", w, v)
-        h = self.NIN_3(h)
-
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)


class Upsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
@@ -756,7 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

-        AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
+        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
...
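Note (standalone sketch): the functools.partial above turns the unified AttentionBlock into a drop-in factory for the old AttnBlockpp; NCSNpp later calls it with just a channel count, e.g.:

import functools
import math
from diffusers.models.attention2d import AttentionBlock  # import path assumed

AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

block = AttnBlock(channels=128)
# equivalent to:
# AttentionBlock(channels=128, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))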