Commit 3b804999 by chenzk: v1.0
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Union
from einops import rearrange
from .movq_enc_3d import CausalConv3d, DownSample3D, Upsample3D
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t, ) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
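# A note on get_timestep_embedding: with half_dim = embedding_dim // 2,
# emb[i, j] = sin(t_i * 10000^(-j / (half_dim - 1))) for the first half and
# the matching cos(...) for the second half, e.g.
# get_timestep_embedding(torch.arange(4), 128) has shape (4, 128).
# (Unused below, since every module in this file sets temb_ch = 0.)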
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=False,
pad_mode='constant',
**norm_layer_params,
):
super().__init__()
self.norm_layer = norm_layer(num_channels=f_channels,
**norm_layer_params)
        if freeze_norm_layer:
            for p in self.norm_layer.parameters():
                p.requires_grad = False
self.add_conv = add_conv
if self.add_conv:
# self.conv = nn.Conv3d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
self.conv = CausalConv3d(zq_channels,
zq_channels,
kernel_size=3,
pad_mode=pad_mode)
# self.conv_y = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
# self.conv_b = nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_y = CausalConv3d(zq_channels,
f_channels,
kernel_size=1,
pad_mode=pad_mode)
self.conv_b = CausalConv3d(zq_channels,
f_channels,
kernel_size=1,
pad_mode=pad_mode)
def forward(self, f, zq):
if zq.shape[2] > 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first,
size=f_first_size,
mode='nearest')
zq_rest = torch.nn.functional.interpolate(zq_rest,
size=f_rest_size,
mode='nearest')
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq,
size=f.shape[-3:],
mode='nearest')
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(in_channels, zq_ch, add_conv):
return SpatialNorm3D(
in_channels,
zq_ch,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
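# SpatialNorm3D is the MoVQ-style spatially conditioned GroupNorm: features f
# are group-normalized, then scaled and shifted per position by 1x1(x1) causal
# convolutions of the quantized latents zq, which are nearest-resized to f's
# grid (the first frame is resized separately to preserve the causal temporal
# layout). A minimal shape sketch; the sizes below are illustrative, not taken
# from this repo's configs:
def _spatial_norm3d_demo():
    norm = Normalize3D(in_channels=64, zq_ch=16, add_conv=False)
    f = torch.randn(1, 64, 5, 32, 32)  # decoder features (B, C, T, H, W)
    zq = torch.randn(1, 16, 2, 8, 8)   # quantized latents on a coarser grid
    assert norm(f, zq).shape == f.shape  # scale/shift only, shape preserved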
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
pad_mode='constant',
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
# self.conv1 = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv1 = CausalConv3d(in_channels,
out_channels,
kernel_size=3,
pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize3D(out_channels, zq_ch, add_conv=add_conv)
self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv3d(out_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv2 = CausalConv3d(out_channels,
out_channels,
kernel_size=3,
pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_shortcut = CausalConv3d(in_channels,
out_channels,
kernel_size=3,
pad_mode=pad_mode)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, x, temb, zq):
h = x
h = self.norm1(h, zq)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
h = self.norm2(h, zq)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock2D(nn.Module):
def __init__(self, in_channels, zq_ch=None, add_conv=False):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize3D(in_channels, zq_ch, add_conv=add_conv)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, zq):
h_ = x
h_ = self.norm(h_, zq)
t = h_.shape[2]
h_ = rearrange(h_, 'b c t h w -> (b t) c h w')
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
        # scale q before the matmul: the post-hoc `w_ * c**-0.5` variant can
        # overflow to nan in fp16 (the issue is noted in movq_enc_3d's
        # AttnBlock2D)
        q = q * (int(c)**(-0.5))
        w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
h_ = rearrange(h_, '(b t) c h w -> b c t h w', t=t)
return x + h_
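# AttnBlock2D folds time into the batch ('b c t h w -> (b t) c h w') and runs
# vanilla 2D self-attention per frame, so frames never attend to each other;
# cost is O((H*W)^2) per frame. Illustrative shape sketch (sizes assumed):
def _attn_block2d_demo():
    attn = AttnBlock2D(in_channels=64, zq_ch=16)
    x = torch.randn(1, 64, 5, 16, 16)
    zq = torch.randn(1, 16, 2, 8, 8)
    assert attn(x, zq).shape == x.shape  # residual connection keeps the shape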
class MOVQDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode='first',
temporal_compress_times=4,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1, ) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,
# block_in,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(z_channels,
block_in,
kernel_size=3,
pad_mode=pad_mode)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in, zq_ch,
add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in,
resamp_with_conv,
compress_time=False)
else:
up.upsample = Upsample3D(block_in,
resamp_with_conv,
compress_time=True)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
# self.conv_out = torch.nn.Conv3d(block_in,
# out_ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_out = CausalConv3d(block_in,
out_ch,
kernel_size=3,
pad_mode=pad_mode)
def forward(self, z, use_cp=False):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
# h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
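# Usage sketch for MOVQDecoder3D. With temporal_compress_times=4 the decoder
# applies log2(4) = 2 time-expanding upsamples (the first frame is never
# duplicated, so T -> 2T-1 each time) plus num_resolutions-1 spatial 2x
# upsamples. The configuration below is illustrative, not from this repo:
def _movq_decoder3d_demo():
    dec = MOVQDecoder3D(ch=128, out_ch=3, ch_mult=(1, 2, 2, 4),
                        num_res_blocks=2, attn_resolutions=[],
                        in_channels=3, resolution=256, z_channels=16)
    z = torch.randn(1, 16, 5, 32, 32)  # latents (B, z_channels, T, H, W)
    x = dec(z)                         # 5 -> 9 -> 17 frames, 32 -> 256 px
    assert x.shape == (1, 3, 17, 256, 256)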
class NewDecoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode='first',
temporal_compress_times=4,
post_quant_conv=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
if post_quant_conv:
self.post_quant_conv = CausalConv3d(zq_ch,
z_channels,
kernel_size=3,
pad_mode=pad_mode)
else:
self.post_quant_conv = None
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1, ) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,
# block_in,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(z_channels,
block_in,
kernel_size=3,
pad_mode=pad_mode)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
pad_mode=pad_mode,
))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in, zq_ch,
add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in,
resamp_with_conv,
compress_time=False)
else:
up.upsample = Upsample3D(block_in,
resamp_with_conv,
compress_time=True)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv)
# self.conv_out = torch.nn.Conv3d(block_in,
# out_ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_out = CausalConv3d(block_in,
out_ch,
kernel_size=3,
pad_mode=pad_mode)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
# h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
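# NewDecoder3D differs from MOVQDecoder3D only in the optional post_quant_conv:
# when enabled, z passes through an extra CausalConv3d before conv_in, while
# the SpatialNorm3D conditioning still sees the raw zq = z.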
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Union
from einops import rearrange
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t, ) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class CausalConv3d(nn.Module):
@beartype
def __init__(self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode='constant',
**kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
dilation = kwargs.pop('dilation', 1)
stride = kwargs.pop('stride', 1)
self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad,
height_pad, time_pad, 0)
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
def forward(self, x):
        if self.pad_mode == 'constant':
            # F.pad orders pads from the last dim backwards, i.e. (W, W, H, H,
            # T_front, T_back) for 5D input, which is exactly the
            # self.time_causal_padding built in __init__
            x = F.pad(x, self.time_causal_padding, mode='constant', value=0)
        elif self.pad_mode == 'first':
            if self.time_pad > 0:  # kernel_size=1 in time needs no padding
                pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
                x = torch.cat([pad_x, x], dim=2)
causal_padding_2d = (self.width_pad, self.width_pad,
self.height_pad, self.height_pad)
x = F.pad(x, causal_padding_2d, mode='constant', value=0)
elif self.pad_mode == 'reflect':
# reflect padding
reflect_x = x[:, :, 1:self.time_pad + 1, :, :].flip(dims=[2])
if reflect_x.shape[2] < self.time_pad:
reflect_x = torch.cat([torch.zeros_like(x[:, :, :1, :, :])] *
(self.time_pad - reflect_x.shape[2]) +
[reflect_x],
dim=2)
x = torch.cat([reflect_x, x], dim=2)
causal_padding_2d = (self.width_pad, self.width_pad,
self.height_pad, self.height_pad)
x = F.pad(x, causal_padding_2d, mode='constant', value=0)
else:
raise ValueError('Invalid pad mode')
return self.conv(x)
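# CausalConv3d pads only at the temporal front (time_pad = dilation * (k_t - 1)
# frames), so output frame t never depends on input frames after t, while
# height and width get ordinary symmetric 'same' padding. A small causality
# check (illustrative; assumes the default stride and dilation):
def _causal_conv3d_demo():
    conv = CausalConv3d(chan_in=4, chan_out=4, kernel_size=3, pad_mode='first')
    x = torch.randn(1, 4, 8, 16, 16)
    x2 = x.clone()
    x2[:, :, -1] += 1.0  # perturb only the last input frame
    y1, y2 = conv(x), conv(x2)
    # every output frame before the last is unchanged -> causal in time
    assert torch.allclose(y1[:, :, :-1], y2[:, :, :-1])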
def Normalize3D(in_channels): # same for 3D and 2D
return torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
class Upsample3D(nn.Module):
def __init__(self, in_channels, with_conv, compress_time=False):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
if x.shape[2] > 1:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first,
scale_factor=2.0,
mode='nearest')
x_rest = torch.nn.functional.interpolate(x_rest,
scale_factor=2.0,
mode='nearest')
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = x.squeeze(2)
x = torch.nn.functional.interpolate(x,
scale_factor=2.0,
mode='nearest')
x = x[:, :, None, :, :]
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = torch.nn.functional.interpolate(x,
scale_factor=2.0,
mode='nearest')
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.conv(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x
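# Upsample3D doubles H and W; with compress_time=True it also expands time but
# keeps the first frame un-duplicated, so T frames become 2T-1 (the inverse of
# DownSample3D below). Illustrative shapes:
def _upsample3d_demo():
    up = Upsample3D(in_channels=8, with_conv=True, compress_time=True)
    x = torch.randn(1, 8, 5, 16, 16)
    assert up(x).shape == (1, 8, 9, 32, 32)  # 5 -> 2*5 - 1 = 9 frames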
class DownSample3D(nn.Module):
def __init__(self,
in_channels,
with_conv,
compress_time=False,
out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
h, w = x.shape[-2:]
x = rearrange(x, 'b c t h w -> (b h w) c t')
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest,
kernel_size=2,
stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.conv(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
else:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x
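# DownSample3D halves H and W; with compress_time=True the first frame is kept
# as-is and the remaining T-1 frames are average-pooled in pairs, so T frames
# become 1 + (T - 1) // 2. Illustrative shapes:
def _downsample3d_demo():
    down = DownSample3D(in_channels=8, with_conv=True, compress_time=True)
    x = torch.randn(1, 8, 9, 32, 32)
    assert down(x).shape == (1, 8, 5, 16, 16)  # 9 -> 1 + 8 // 2 = 5 frames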
class ResnetBlock3D(nn.Module):
def __init__(self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
pad_mode='constant'):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize3D(in_channels)
# self.conv1 = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv1 = CausalConv3d(in_channels,
out_channels,
kernel_size=3,
pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize3D(out_channels)
self.dropout = torch.nn.Dropout(dropout)
# self.conv2 = torch.nn.Conv3d(out_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv2 = CausalConv3d(out_channels,
out_channels,
kernel_size=3,
pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# self.conv_shortcut = torch.nn.Conv3d(in_channels,
# out_channels,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_shortcut = CausalConv3d(in_channels,
out_channels,
kernel_size=3,
pad_mode=pad_mode)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock2D(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize3D(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
t = h_.shape[2]
h_ = rearrange(h_, 'b c t h w -> (b t) c h w')
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
# # original version, nan in fp16
# w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
# w_ = w_ * (int(c)**(-0.5))
# # implement c**-0.5 on q
q = q * (int(c)**(-0.5))
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
h_ = rearrange(h_, '(b t) c h w -> b c t h w', t=t)
return x + h_
class Encoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode='first',
temporal_compress_times=4,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
# downsampling
# self.conv_in = torch.nn.Conv3d(in_channels,
# self.ch,
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_in = CausalConv3d(in_channels,
self.ch,
kernel_size=3,
pad_mode=pad_mode)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
pad_mode=pad_mode,
))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock2D(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in,
resamp_with_conv,
compress_time=True)
else:
down.downsample = DownSample3D(block_in,
resamp_with_conv,
compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
pad_mode=pad_mode)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in)
self.mid.block_2 = ResnetBlock3D(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
pad_mode=pad_mode)
# end
self.norm_out = Normalize3D(block_in)
# self.conv_out = torch.nn.Conv3d(block_in,
# 2*z_channels if double_z else z_channels,
# kernel_size=3,
# stride=1,
# padding=1)
        self.conv_out = CausalConv3d(block_in,
                                     2 * z_channels if double_z else z_channels,
                                     kernel_size=3,
                                     pad_mode=pad_mode)
def forward(self, x, use_cp=False):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
# h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
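# Usage sketch for Encoder3D, the mirror of MOVQDecoder3D: the first
# temporal_compress_level levels compress time (T -> 1 + (T-1)//2 each) and
# every level but the last halves H and W. Illustrative configuration only:
def _encoder3d_demo():
    enc = Encoder3D(ch=128, out_ch=3, ch_mult=(1, 2, 2, 4), num_res_blocks=2,
                    attn_resolutions=[], in_channels=3, resolution=256,
                    z_channels=16, double_z=False)
    x = torch.randn(1, 3, 17, 256, 256)  # video (B, C, T, H, W)
    assert enc(x).shape == (1, 16, 5, 32, 32)  # 17 -> 9 -> 5, 256 -> 32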
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
class SpatialNorm(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=False,
**norm_layer_params,
):
super().__init__()
self.norm_layer = norm_layer(num_channels=f_channels,
**norm_layer_params)
        if freeze_norm_layer:
            for p in self.norm_layer.parameters():
                p.requires_grad = False
self.add_conv = add_conv
if self.add_conv:
self.conv = nn.Conv2d(zq_channels,
zq_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv_y = nn.Conv2d(zq_channels,
f_channels,
kernel_size=1,
stride=1,
padding=0)
self.conv_b = nn.Conv2d(zq_channels,
f_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, f, zq):
f_size = f.shape[-2:]
zq = torch.nn.functional.interpolate(zq, size=f_size, mode='nearest')
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize(in_channels, zq_ch, add_conv):
return SpatialNorm(
in_channels,
zq_ch,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x,
scale_factor=2.0,
mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb, zq):
h = x
h = self.norm1(h, zq)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h, zq)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels, zq_ch=None, add_conv=False):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, zq):
h_ = x
h_ = self.norm(h_, zq)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
        # scale q before the matmul rather than scaling w_ afterwards; the
        # post-hoc scaling variant can overflow to nan in fp16
        q = q * (int(c)**(-0.5))
        w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class MOVQDecoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1, ) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z, zq):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, zq)
h = self.mid.attn_1(h, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def forward_with_features_output(self, z, zq):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
output_features = {}
# z to block_in
h = self.conv_in(z)
output_features['conv_in'] = h
# middle
h = self.mid.block_1(h, temb, zq)
output_features['mid_block_1'] = h
h = self.mid.attn_1(h, zq)
output_features['mid_attn_1'] = h
h = self.mid.block_2(h, temb, zq)
output_features['mid_block_2'] = h
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq)
output_features[f'up_{i_level}_block_{i_block}'] = h
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
output_features[f'up_{i_level}_attn_{i_block}'] = h
if i_level != 0:
h = self.up[i_level].upsample(h)
output_features[f'up_{i_level}_upsample'] = h
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
output_features['norm_out'] = h
h = nonlinearity(h)
output_features['nonlinearity'] = h
h = self.conv_out(h)
output_features['conv_out'] = h
return h, output_features
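# forward_with_features_output mirrors forward but additionally returns a dict
# of intermediate activations keyed by stage name ('conv_in', 'mid_block_1',
# 'up_{level}_block_{block}', ...), convenient for feature losses or probing.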
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum
class VectorQuantizer2(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self,
n_e,
e_dim,
beta,
remap=None,
unknown_index='random',
sane_index_shape=False,
legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
if self.remap is not None:
self.register_buffer('used', torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == 'extra':
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(f'Remapping {self.n_e} indices to {self.re_embed} indices. '
f'Using {self.unknown_index} for unknown indices.')
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == 'random':
new[unknown] = torch.randint(
0, self.re_embed,
size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp == 1.0, 'Only for interface compatible with Gumbel'
assert rescale_logits == False, 'Only for interface compatible with Gumbel'
assert return_logits == False, 'Only for interface compatible with Gumbel'
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.e_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (torch.sum(z_flattened**2, dim=1, keepdim=True) +
torch.sum(self.embedding.weight**2, dim=1) -
2 * torch.einsum('bd,dn->bn', z_flattened,
rearrange(self.embedding.weight, 'n d -> d n')))
min_encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean(
(z_q - z.detach())**2)
else:
loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean(
(z_q - z.detach())**2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1,
1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q
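# Usage sketch for VectorQuantizer2: nearest-neighbour codebook lookup with a
# straight-through estimator (z_q = z + (z_q - z).detach()), so gradients flow
# back to the encoder while the beta-weighted commitment loss shapes codebook
# and encoder. beta=0.25 below is the common VQ-VAE choice, an assumption here:
def _vector_quantizer_demo():
    vq = VectorQuantizer2(n_e=512, e_dim=16, beta=0.25)
    z = torch.randn(2, 16, 8, 8, requires_grad=True)  # encoder output (B,C,H,W)
    z_q, loss, (_, _, indices) = vq(z)
    assert z_q.shape == z.shape and indices.numel() == 2 * 8 * 8
    (loss + z_q.sum()).backward()  # gradients reach z through the STE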
class GumbelQuantize(nn.Module):
"""
credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
Gumbel Softmax trick quantizer
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
https://arxiv.org/abs/1611.01144
"""
def __init__(
self,
num_hiddens,
embedding_dim,
n_embed,
straight_through=True,
kl_weight=5e-4,
temp_init=1.0,
use_vqinterface=True,
remap=None,
unknown_index='random',
):
super().__init__()
self.embedding_dim = embedding_dim
self.n_embed = n_embed
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
self.embed = nn.Embedding(n_embed, embedding_dim)
self.use_vqinterface = use_vqinterface
self.remap = remap
if self.remap is not None:
self.register_buffer('used', torch.tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == 'extra':
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(
f'Remapping {self.n_embed} indices to {self.re_embed} indices. '
f'Using {self.unknown_index} for unknown indices.')
else:
self.re_embed = n_embed
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
match = (inds[:, :, None] == used[None, None, ...]).long()
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == 'random':
new[unknown] = torch.randint(
0, self.re_embed,
size=new[unknown].shape).to(device=new.device)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z, temp=None, return_logits=False):
# force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
hard = self.straight_through if self.training else True
temp = self.temperature if temp is None else temp
logits = self.proj(z)
if self.remap is not None:
# continue only with used logits
full_zeros = torch.zeros_like(logits)
logits = logits[:, self.used, ...]
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
if self.remap is not None:
# go back to all entries but unused set to zero
full_zeros[:, self.used, ...] = soft_one_hot
soft_one_hot = full_zeros
z_q = einsum('b n h w, n d -> b d h w', soft_one_hot,
self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(
qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
ind = soft_one_hot.argmax(dim=1)
if self.remap is not None:
ind = self.remap_to_used(ind)
if self.use_vqinterface:
if return_logits:
return z_q, diff, (None, None, ind), logits
return z_q, diff, (None, None, ind)
return z_q, diff, ind
def get_codebook_entry(self, indices, shape):
b, h, w, c = shape
assert b * h * w == indices.shape[0]
indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
if self.remap is not None:
indices = self.unmap_to_all(indices)
one_hot = F.one_hot(indices,
num_classes=self.n_embed).permute(0, 3, 1,
2).float()
z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
return z_q
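# GumbelQuantize relaxes the discrete codebook choice with Gumbel-Softmax:
# soft one-hot weights mix codebook rows during training (hard straight-through
# in eval), and `diff` is a KL term pulling code usage toward uniform.
# Illustrative sketch (sizes are assumptions):
def _gumbel_quantize_demo():
    gq = GumbelQuantize(num_hiddens=32, embedding_dim=16, n_embed=512)
    z = torch.randn(2, 32, 8, 8)
    z_q, diff, (_, _, ind) = gq(z)
    assert z_q.shape == (2, 16, 8, 8) and ind.shape == (2, 8, 8)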
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x,
scale_factor=2.0,
mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
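# Padding sketch: the asymmetric (left=0, right=1, top=0, bottom=1) pad makes
# the stride-2 3x3 conv halve even spatial dims exactly, e.g.
#   Downsample(64, with_conv=True)(torch.randn(1, 64, 32, 32))  # -> (1, 64, 16, 16)
# matching the avg_pool2d branch used when with_conv=False.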
class ResnetBlock(nn.Module):
def __init__(self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
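# Usage sketch: when in_channels != out_channels the residual path goes
# through the 1x1 nin_shortcut (or a 3x3 conv_shortcut if requested), e.g.
#   blk = ResnetBlock(in_channels=128, out_channels=256, dropout=0.0,
#                     temb_channels=0)
#   h = blk(torch.randn(1, 128, 32, 32), temb=None)   # -> (1, 256, 32, 32)
# temb_channels=0 skips the timestep projection, matching the temb=None
# calls in the Encoder/Decoder below.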
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
# original version (matmul first, then scale) produced NaNs in fp16:
#   w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
#   w_ = w_ * (int(c)**(-0.5))
# so the c**-0.5 scaling is applied to q before the matmul instead
q = q * (int(c)**(-0.5))
w_ = torch.bmm(q, k)  # b,hw,hw  w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
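# Shape sketch: q is reshaped to (b, hw, c) and k to (b, c, hw), so bmm gives
# the (b, hw, hw) attention matrix, which is softmaxed over the key axis and
# applied to v, e.g.
#   attn = AttnBlock(128)
#   attn(torch.randn(1, 128, 16, 16))   # -> (1, 128, 16, 16), residual added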
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def forward_with_features_output(self, x):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb = None
output_features = {}
# downsampling
hs = [self.conv_in(x)]
output_features['conv_in'] = hs[-1]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
output_features[f'down{i_level}_block{i_block}'] = h
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
output_features[f'down{i_level}_attn{i_block}'] = h
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
output_features[f'down{i_level}_downsample'] = hs[-1]
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
output_features['mid_block_1'] = h
h = self.mid.attn_1(h)
output_features['mid_attn_1'] = h
h = self.mid.block_2(h, temb)
output_features['mid_block_2'] = h
# end
h = self.norm_out(h)
output_features['norm_out'] = h
h = nonlinearity(h)
output_features['nonlinearity'] = h
h = self.conv_out(h)
output_features['conv_out'] = h
return h, output_features
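# Usage sketch: a 256x256 RGB encoder with three downsamples; double_z=True
# doubles the latent channels for a diagonal-Gaussian posterior:
#   enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
#                 attn_resolutions=[32], in_channels=3, resolution=256,
#                 z_channels=4, double_z=True)
#   x = torch.randn(1, 3, 256, 256)
#   z = enc(x)                                       # -> (1, 8, 32, 32)
#   z, feats = enc.forward_with_features_output(x)   # feats: named intermediates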
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# compute block_in and curr_res at lowest res
# (in_ch_mult is not used below; kept for symmetry with the Encoder)
in_ch_mult = (1, ) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttnBlock(block_in))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
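# Usage sketch mirroring the Encoder example above: decode a (1, 4, 32, 32)
# latent back to a 256x256 image (shapes are illustrative):
#   dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
#                 attn_resolutions=[32], in_channels=3, resolution=256,
#                 z_channels=4)
#   x = dec(torch.randn(1, 4, 32, 32))   # -> (1, 3, 256, 256)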
import math
import torch
import torch.distributed
import torch.nn as nn
from ..util import (get_context_parallel_group, get_context_parallel_rank,
get_context_parallel_world_size)
_USE_CP = True
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t, ) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def _split(input_, dim):
cp_world_size = get_context_parallel_world_size()
if cp_world_size == 1:
return input_
cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size
input_list = torch.split(input_, dim_size, dim=dim)
output = input_list[cp_rank]
if cp_rank == 0:
output = torch.cat([input_first_frame_, output], dim=dim)
output = output.contiguous()
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
def _gather(input_, dim):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0,
dim)[:1].transpose(0,
dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
tensor_list = [
torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))
] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
if cp_rank == 0:
input_ = torch.cat([input_first_frame_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
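# Worked example (comment sketch) for _split/_gather on the time axis with
# cp_world_size=2 and 9 frames [f0..f8]: _split peels off f0, deals the
# remaining 8 frames out evenly, and re-attaches f0 to rank 0, so rank 0
# holds [f0..f4] and rank 1 holds [f5..f8]; _gather all-gathers the chunks
# and concatenates them back into [f0..f8] on every rank.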
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
if cp_rank == 0:
output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose(
dim, 0)
else:
output = input_.transpose(
dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size +
kernel_size].transpose(dim, 0)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _conv_gather(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(
0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(
0, dim).contiguous()
else:
input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(
0, dim).contiguous()
tensor_list = [
torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))
] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
if cp_rank == 0:
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
# torch.cat already returns a contiguous tensor, so this .contiguous() is a no-op kept for safety
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
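# Worked example (comment sketch) for _conv_split/_conv_gather with
# cp_world_size=2, kernel_size=2, and 10 frames [f0..f9]: dim_size =
# (10 - 2) // 2 = 4, so rank 0 keeps [f0..f5] (dim_size + kernel_size frames)
# and rank 1 keeps [f5..f9] (its 4 frames plus kernel_size - 1 overlapping
# frames as a causal-conv halo); _conv_gather strips the duplicated halo
# frames before all-gathering, reconstructing [f0..f9] on every rank.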
from .denoiser import Denoiser
from .discretizer import Discretization
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper