"...text-generation-inference.git" did not exist on "b8b950b37cc1d65f1c02f6c1615a6f0cf40b09c5"
Commit 635da723 authored by Patrick von Platen

one attention module only

parent 79db3eb6
import math

import torch
import torch.nn.functional as F
from torch import nn


# unet_grad_tts.py
# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
@@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        q, k, v = (
            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
            .permute(1, 0, 2, 3, 4, 5)
@@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module):
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
        return self.to_out(out)
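# Shape sketch for the linear attention above (assuming `to_qkv` is a 1x1 Conv2d producing
# 3 * heads * dim_head channels and `to_out` projects back to `dim`, as in the grad-tts code):
#   x:       (b, dim, h, w)
#   q, k, v: (b, heads, dim_head, h*w)
#   context: (b, heads, dim_head, dim_head)  -- k is softmaxed over positions, then reduced against v
#   out:     (b, heads, dim_head, h*w) -> reshaped to (b, heads * dim_head, h, w)
# Because the dim_head x dim_head context is formed first, the cost grows linearly with the
# number of spatial positions instead of quadratically.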
# unet_glide.py & unet_ldm.py
# the main attention block that is used for all models
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module):
        channels,
        num_heads=1,
        num_head_channels=-1,
        num_groups=32,
        use_checkpoint=False,
        encoder_channels=None,
        use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
        overwrite_qkv=False,
        overwrite_linear=False,
        rescale_output_factor=1.0,
    ):
        super().__init__()
        self.channels = channels
@@ -62,23 +63,34 @@ class AttentionBlock(nn.Module):
        self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels, swish=0.0)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        self.n_heads = self.num_heads
        self.rescale_output_factor = rescale_output_factor

        if encoder_channels is not None:
            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

        self.overwrite_qkv = overwrite_qkv
        if overwrite_qkv:
            in_channels = channels
            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.overwrite_linear = overwrite_linear
        if self.overwrite_linear:
            num_groups = min(channels // 4, 32)
            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
            self.NIN_0 = NIN(channels, channels)
            self.NIN_1 = NIN(channels, channels)
            self.NIN_2 = NIN(channels, channels)
            self.NIN_3 = NIN(channels, channels)

        self.is_overwritten = False

    def set_weights(self, module):
@@ -89,11 +101,17 @@ class AttentionBlock(nn.Module):
            self.qkv.weight.data = qkv_weight
            self.qkv.bias.data = qkv_bias

            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
            proj_out.bias.data = module.proj_out.bias.data

            self.proj_out = proj_out
        elif self.overwrite_linear:
            self.qkv.weight.data = torch.concat(
                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
            )[:, :, None]
            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)

            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
            self.proj_out.bias.data = self.NIN_3.b.data
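        # Weight porting note: NIN stores its matrix as (in_dim, num_units) and is applied
        # per spatial position as x @ W + b, while nn.Conv1d expects weights of shape
        # (out_channels, in_channels, kernel_size). Hence the transpose and the trailing
        # [:, :, None] that adds a kernel dimension of size 1; the three NIN projections
        # are concatenated along the output axis to fill the fused qkv convolution.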
    def forward(self, x, encoder_out=None):
        if self.overwrite_qkv and not self.is_overwritten:
@@ -124,69 +142,74 @@ class AttentionBlock(nn.Module):
        h = a.reshape(bs, -1, length)
        h = self.proj_out(h)
        h = h.reshape(b, c, *spatial)
        return x + h.reshape(b, c, *spatial)
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class GroupNorm32(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
        self.swish = swish

    def forward(self, x):
        y = super().forward(x.float()).to(x.dtype)
        if self.swish == 1.0:
            y = F.silu(y)
        elif self.swish:
            y = y * F.sigmoid(y * float(self.swish))
        return y


def normalization(channels, swish=0.0, eps=1e-5):
    """
    Make a standard normalization layer, with an optional swish activation.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
        result = x + h
        result = result / self.rescale_output_factor
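        # Note: dividing by `rescale_output_factor` generalizes the `(x + h) / np.sqrt(2.0)`
        # skip rescaling of the score-SDE AttnBlockpp; with the default factor of 1.0 the
        # residual is returned unscaled, matching the GLIDE/LDM attention blocks.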
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
        return result
# unet_score_estimation.py
# class AttnBlockpp(nn.Module):
#class AttnBlockpp(nn.Module):
# """Channel-wise self-attention block. Modified from DDPM."""
#
# def __init__(self, channels, skip_rescale=False, init_scale=0.0):
# def __init__(
# self,
# channels,
# skip_rescale=False,
# init_scale=0.0,
# num_heads=1,
# num_head_channels=-1,
# use_checkpoint=False,
# encoder_channels=None,
# use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
# overwrite_qkv=False,
# overwrite_from_grad_tts=False,
# ):
# super().__init__()
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
# num_groups = min(channels // 4, 32)
# self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
# self.NIN_0 = NIN(channels, channels)
# self.NIN_1 = NIN(channels, channels)
# self.NIN_2 = NIN(channels, channels)
# self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
# self.skip_rescale = skip_rescale
#
# self.channels = channels
# if num_head_channels == -1:
# self.num_heads = num_heads
# else:
# assert (
# channels % num_head_channels == 0
# ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
# self.num_heads = channels // num_head_channels
#
# self.use_checkpoint = use_checkpoint
# self.norm = normalization(channels, num_groups=num_groups, eps=1e-6, swish=None)
# self.qkv = conv_nd(1, channels, channels * 3, 1)
# self.n_heads = self.num_heads
#
# if encoder_channels is not None:
# self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
#
# self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
#
# self.is_weight_set = False
#
# def set_weights(self):
# self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
# self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
#
# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
# self.proj_out.bias.data = self.NIN_3.b.data
#
# def forward(self, x):
# if not self.is_weight_set:
# self.set_weights()
# self.is_weight_set = True
#
# B, C, H, W = x.shape
# h = self.GroupNorm_0(x)
# q = self.NIN_0(h)
......@@ -199,7 +222,58 @@ def zero_module(module):
#         w = torch.reshape(w, (B, H, W, H, W))
#         h = torch.einsum("bhwij,bcij->bchw", w, v)
#         h = self.NIN_3(h)
#
#         if not self.skip_rescale:
#             return x + h
#             result = x + h
#         else:
#             return (x + h) / np.sqrt(2.0)
#             result = (x + h) / np.sqrt(2.0)
#
#         result = self.forward_2(x)
#
#         return result
#
#     def forward_2(self, x, encoder_out=None):
#         b, c, *spatial = x.shape
#         hid_states = self.norm(x).view(b, c, -1)
#
#         qkv = self.qkv(hid_states)
#         bs, width, length = qkv.shape
#         assert width % (3 * self.n_heads) == 0
#         ch = width // (3 * self.n_heads)
#         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
#
#         if encoder_out is not None:
#             encoder_kv = self.encoder_kv(encoder_out)
#             assert encoder_kv.shape[1] == self.n_heads * ch * 2
#             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
#             k = torch.cat([ek, k], dim=-1)
#             v = torch.cat([ev, v], dim=-1)
#
#         scale = 1 / math.sqrt(math.sqrt(ch))
#         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
#         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
#
#         a = torch.einsum("bts,bcs->bct", weight, v)
#         h = a.reshape(bs, -1, length)
#
#         h = self.proj_out(h)
#         h = h.reshape(b, c, *spatial)
#
#         return (x + h) / np.sqrt(2.0)
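# Note on the commented-out forward_2 above: the factor 1 / sqrt(sqrt(ch)) is applied to both
# q and k, so the attention logits are effectively divided by sqrt(ch); splitting the scale
# this way keeps the intermediate values small, which is more stable in float16.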
# TODO(Patrick) - this can and should be removed
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
class NIN(nn.Module):
    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
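# In the original score_sde code, NIN applies a per-position dense layer, roughly
# y = torch.einsum("bchw,cd->bdhw", x, self.W) + self.b[None, :, None, None]
# (i.e. a 1x1 convolution), which is why its weights can be folded into the fused
# Conv1d qkv projection in AttentionBlock.set_weights above.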
@@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module):
    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
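# GaussianFourierProjection embeds a scalar input (e.g. the diffusion time step) with random
# Fourier features: W holds frequencies drawn once at construction (and, in the score_sde
# setup, not trained), and the sin/cos concatenation yields a 2 * embedding_size vector.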
# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
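# This follows the standard transformer sinusoidal embedding: for dim=D the first D/2 features
# are sin(x * 10000^(-i / (D/2 - 1))) and the last D/2 are the matching cos terms; for example,
# with dim=8 the frequencies are 10000^(-i/3) for i in {0, 1, 2, 3}.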
@@ -5,6 +5,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample, Upsample
from .attention2d import LinearAttention
class Mish(torch.nn.Module):
@@ -54,7 +55,7 @@ class ResnetBlock(torch.nn.Module):
        return output


class LinearAttention(torch.nn.Module):
class old_LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
......
@@ -22,10 +22,12 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .attention2d import AttentionBlock
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -414,37 +416,6 @@ class Combine(nn.Module):
            raise ValueError(f"Method {self.method} not recognized.")
class AttnBlockpp(nn.Module):
    """Channel-wise self-attention block. Modified from DDPM."""

    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)
        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
        w = torch.reshape(w, (B, H, W, H * W))
        w = F.softmax(w, dim=-1)
        w = torch.reshape(w, (B, H, W, H, W))
        h = torch.einsum("bhwij,bcij->bchw", w, v)
        h = self.NIN_3(h)
        if not self.skip_rescale:
            return x + h
        else:
            return (x + h) / np.sqrt(2.0)
class Upsample(nn.Module):
    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
@@ -756,7 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
        modules[-1].weight.data = default_init()(modules[-1].weight.shape)
        nn.init.zeros_(modules[-1].bias)

        AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
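        # With this change NCSNpp builds its attention layers from the shared AttentionBlock:
        # overwrite_linear=True makes the block allocate NIN_0..NIN_3 so existing score-SDE
        # weights can still be loaded and then folded into the fused qkv/proj_out convs by
        # set_weights, and rescale_output_factor=sqrt(2.0) matches the (x + h) / sqrt(2)
        # rescaling that AttnBlockpp applied when skip_rescale was enabled.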
......