Commit 9dccc7dc authored by Patrick von Platen

refactor unet's attention

parent 52b3ff5e
@@ -5,10 +5,6 @@ import torch.nn.functional as F
 from torch import nn
 
-
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
 
 # unet_grad_tts.py
 class LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
@@ -42,31 +38,48 @@ class AttnBlock(nn.Module):
         super().__init__()
         self.in_channels = in_channels
 
-        self.norm = Normalize(in_channels)
+        self.norm = normalization(in_channels, swish=None, eps=1e-6)
         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
+        print("x", x.abs().sum())
         h_ = x
         h_ = self.norm(h_)
+        print("hid_states shape", h_.shape)
+        print("hid_states", h_.abs().sum())
+        print("hid_states - 3 - 3", h_.view(h_.shape[0], h_.shape[1], -1)[:, :3, -3:])
         q = self.q(h_)
         k = self.k(h_)
         v = self.v(h_)
+        print(self.q)
+        print("q_shape", q.shape)
+        print("q", q.abs().sum())
+        # print("k_shape", k.shape)
+        # print("k", k.abs().sum())
+        # print("v_shape", v.shape)
+        # print("v", v.abs().sum())
 
         # compute attention
         b, c, h, w = q.shape
         q = q.reshape(b, c, h * w)
         q = q.permute(0, 2, 1)  # b,hw,c
         k = k.reshape(b, c, h * w)  # b,c,hw
         w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
         w_ = w_ * (int(c) ** (-0.5))
         w_ = torch.nn.functional.softmax(w_, dim=2)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        print("weight", w_.abs().sum())
 
         # attend to values
         v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
         h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
         h_ = h_.reshape(b, c, h, w)
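For readers tracing the shapes in the hunk above, here is a minimal standalone sketch of the spatial self-attention that AttnBlock.forward computes (assumed (b, c, h, w) inputs already projected by the 1x1 convolutions; a sketch, not the module itself):

import torch

def spatial_attention(q, k, v):
    # q, k, v: (b, c, h, w) feature maps from the 1x1 conv projections
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)    # (b, hw, c)
    k = k.reshape(b, c, h * w)                     # (b, c, hw)
    w_ = torch.bmm(q, k) * c ** -0.5               # (b, hw_q, hw_k): scaled dot products
    w_ = torch.softmax(w_, dim=2)                  # normalize over key positions
    v = v.reshape(b, c, h * w)                     # (b, c, hw_k)
    h_ = torch.bmm(v, w_.permute(0, 2, 1))         # (b, c, hw_q): weighted sum of values
    return h_.reshape(b, c, h, w)                  # proj_out and the residual add follow in the module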
@@ -92,6 +105,7 @@ class AttentionBlock(nn.Module):
         use_checkpoint=False,
         encoder_channels=None,
         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+        overwrite_qkv=False,
     ):
         super().__init__()
         self.channels = channels
@@ -102,57 +116,72 @@ class AttentionBlock(nn.Module):
                 channels % num_head_channels == 0
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
         self.norm = normalization(channels, swish=0.0)
         self.qkv = conv_nd(1, channels, channels * 3, 1)
-        self.attention = QKVAttention(self.num_heads)
+        self.n_heads = self.num_heads
 
         if encoder_channels is not None:
             self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
 
-    def forward(self, x, encoder_out=None):
-        b, c, *spatial = x.shape
-        qkv = self.qkv(self.norm(x).view(b, c, -1))
-        if encoder_out is not None:
-            encoder_out = self.encoder_kv(encoder_out)
-            h = self.attention(qkv, encoder_out)
-        else:
-            h = self.attention(qkv)
-        h = self.proj_out(h)
-        return x + h.reshape(b, c, *spatial)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv, encoder_kv=None):
-        """
-        Apply QKV attention.
-
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
+        self.overwrite_qkv = overwrite_qkv
+        if overwrite_qkv:
+            in_channels = channels
+            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+        self.is_overwritten = False
+
+    def set_weights(self, module):
+        if self.overwrite_qkv:
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
+
+            self.qkv.weight.data = qkv_weight
+            self.qkv.bias.data = qkv_bias
+
+            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
+            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
+            proj_out.bias.data = module.proj_out.bias.data
+
+            self.proj_out = proj_out
+
+    def forward(self, x, encoder_out=None):
+        if self.overwrite_qkv and not self.is_overwritten:
+            self.set_weights(self)
+            self.is_overwritten = True
+
+        b, c, *spatial = x.shape
+        hid_states = self.norm(x).view(b, c, -1)
+
+        qkv = self.qkv(hid_states)
         bs, width, length = qkv.shape
         assert width % (3 * self.n_heads) == 0
         ch = width // (3 * self.n_heads)
         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        if encoder_kv is not None:
+
+        if encoder_out is not None:
+            encoder_kv = self.encoder_kv(encoder_out)
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
             k = torch.cat([ek, k], dim=-1)
             v = torch.cat([ev, v], dim=-1)
+
         scale = 1 / math.sqrt(math.sqrt(ch))
         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
         a = torch.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
+        h = a.reshape(bs, -1, length)
+
+        h = self.proj_out(h)
+
+        return x + h.reshape(b, c, *spatial)
 
 
 def conv_nd(dims, *args, **kwargs):
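The new set_weights above folds the three 1x1 Conv2d projections (q, k, v) into the single fused 1-D qkv convolution: the weights are concatenated along the output-channel axis and the trailing spatial dimension of the 1x1 kernels is dropped with [:, :, :, 0]. A small sketch of that fusion under assumed shapes; note that this concatenation order matches the per-head split in forward only in the single-head case, which appears to be how these blocks are constructed below:

import torch

c, hw = 8, 16
x = torch.randn(1, c, hw)                       # flattened (b, c, h*w) hidden states

q2d = torch.nn.Conv2d(c, c, 1)                  # legacy 1x1 projections
k2d = torch.nn.Conv2d(c, c, 1)
v2d = torch.nn.Conv2d(c, c, 1)

qkv = torch.nn.Conv1d(c, 3 * c, 1)              # fused projection over flattened positions
qkv.weight.data = torch.cat([q2d.weight.data, k2d.weight.data, v2d.weight.data], dim=0)[:, :, :, 0]
qkv.bias.data = torch.cat([q2d.bias.data, k2d.bias.data, v2d.bias.data], dim=0)

q_ref = q2d(x.view(1, c, 4, 4)).view(1, c, hw)  # old projection applied to the (b, c, h, w) map
q_new, k_new, v_new = qkv(x).split(c, dim=1)    # fused projection applied to the flat map
print(torch.allclose(q_ref, q_new, atol=1e-6))  # True: the fused layer reproduces q (same for k, v)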
@@ -169,8 +198,8 @@ def conv_nd(dims, *args, **kwargs):
 
 
 class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
+        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
         self.swish = swish
 
     def forward(self, x):
@@ -182,13 +211,13 @@ class GroupNorm32(nn.GroupNorm):
         return y
 
 
-def normalization(channels, swish=0.0):
+def normalization(channels, swish=0.0, eps=1e-5):
     """
     Make a standard normalization layer, with an optional swish activation.
 
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
 
 
 def zero_module(module):
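With the new eps and affine arguments, normalization can now reproduce the GroupNorm settings of the Normalize helper removed from the top of this file (a sketch, assuming the helpers above are in scope and that swish=None leaves the activation off, as in the AttnBlock change above):

import torch

norm_old = torch.nn.GroupNorm(num_groups=32, num_channels=64, eps=1e-6, affine=True)  # former Normalize(64)
norm_new = normalization(64, swish=None, eps=1e-6)                                    # GroupNorm32 with the same settings

x = torch.randn(2, 64, 8, 8)
print((norm_old(x) - norm_new(x)).abs().max())  # ~0 at init: both start from weight=1, bias=0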
...
@@ -32,6 +32,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
+from .attention2d import AttnBlock, AttentionBlock
 
 
 def nonlinearity(x):
@@ -85,42 +86,42 @@ class ResnetBlock(nn.Module):
         return x + h
 
 
-class AttnBlock(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = q.reshape(b, c, h * w)
-        q = q.permute(0, 2, 1)  # b,hw,c
-        k = k.reshape(b, c, h * w)  # b,c,hw
-        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = v.reshape(b, c, h * w)
-        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
-        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
-        h_ = h_.reshape(b, c, h, w)
-
-        h_ = self.proj_out(h_)
-
-        return x + h_
+#class AttnBlock(nn.Module):
+#     def __init__(self, in_channels):
+#         super().__init__()
+#         self.in_channels = in_channels
+#
+#         self.norm = Normalize(in_channels)
+#         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#
+#     def forward(self, x):
+#         h_ = x
+#         h_ = self.norm(h_)
+#         q = self.q(h_)
+#         k = self.k(h_)
+#         v = self.v(h_)
+#
+#         # compute attention
+#         b, c, h, w = q.shape
+#         q = q.reshape(b, c, h * w)
+#         q = q.permute(0, 2, 1)  # b,hw,c
+#         k = k.reshape(b, c, h * w)  # b,c,hw
+#         w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+#         w_ = w_ * (int(c) ** (-0.5))
+#         w_ = torch.nn.functional.softmax(w_, dim=2)
+#
+#         # attend to values
+#         v = v.reshape(b, c, h * w)
+#         w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+#         h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+#         h_ = h_.reshape(b, c, h, w)
+#
+#         h_ = self.proj_out(h_)
+#
+#         return x + h_
 
 
 class UNetModel(ModelMixin, ConfigMixin):
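The in-file AttnBlock above is kept only as a commented-out reference; the construction sites below switch to the imported AttentionBlock with overwrite_qkv=True. That flag also registers q, k, v and proj_out as 1x1 Conv2d submodules, so weights saved for the old block load under their original names, and the first forward pass folds them into the fused qkv / proj_out layers. A hypothetical usage sketch (assuming the default single-head configuration):

import torch
# assumes: from .attention2d import AttentionBlock, as added to the imports above

block = AttentionBlock(128, overwrite_qkv=True)
# legacy_state_dict would hold the norm/q/k/v/proj_out tensors of an old AttnBlock(128);
# qkv.* is intentionally absent and gets filled in by set_weights on the first call:
# block.load_state_dict(legacy_state_dict, strict=False)

x = torch.randn(2, 128, 16, 16)
out = block(x)  # first call runs set_weights(self) and fuses q/k/v into block.qkv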
@@ -174,6 +175,7 @@ class UNetModel(ModelMixin, ConfigMixin):
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
+            attn_2 = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
@@ -184,10 +186,12 @@ class UNetModel(ModelMixin, ConfigMixin):
                 )
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    attn.append(AttnBlock(block_in))
+                    # attn.append(AttnBlock(block_in))
+                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
             down = nn.Module()
             down.block = block
             down.attn = attn
+            down.attn_2 = attn_2
             if i_level != self.num_resolutions - 1:
                 down.downsample = Downsample(block_in, use_conv=resamp_with_conv, padding=0)
                 curr_res = curr_res // 2
@@ -198,7 +202,8 @@ class UNetModel(ModelMixin, ConfigMixin):
         self.mid.block_1 = ResnetBlock(
             in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
         )
-        self.mid.attn_1 = AttnBlock(block_in)
+        # self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
         self.mid.block_2 = ResnetBlock(
             in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
         )
@@ -223,7 +228,8 @@ class UNetModel(ModelMixin, ConfigMixin):
                 )
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    attn.append(AttnBlock(block_in))
+                    # attn.append(AttnBlock(block_in))
+                    attn.append(AttentionBlock(block_in, overwrite_qkv=True))
             up = nn.Module()
             up.block = block
             up.attn = attn
@@ -254,7 +260,11 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
+                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
+                    # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
+                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
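The commented-out lines above hint at the check this refactor is working towards: after the weights are folded over, the new AttentionBlock should reproduce the old AttnBlock output. A standalone version of that check (hypothetical sizes; the result is small but not exactly zero because the two norm layers use different eps values, 1e-6 versus 1e-5):

import torch
# assumes AttnBlock and AttentionBlock imported as at the top of this file

old_block = AttnBlock(in_channels=128)
new_block = AttentionBlock(128, overwrite_qkv=True)
# copy the legacy norm/q/k/v/proj_out weights; qkv.* is fused from them on the first forward
new_block.load_state_dict(old_block.state_dict(), strict=False)

x = torch.randn(1, 128, 32, 32)
print("Result", (old_block(x) - new_block(x)).abs().sum())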
...