Commit 0926dc24 authored by Patrick von Platen

save intermediate grad tts

parent 814133ec
......@@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def Normalize(in_channels, num_groups=32, eps=1e-6):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
def nonlinearity(x, swish=1.0):
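For orientation, a short sketch (not part of the diff; the channel count is arbitrary) of what the widened Normalize signature allows: the old call stays valid, and the Grad-TTS settings used later in this commit (8 groups, eps=1e-5) can now go through the same helper.

    norm_default = Normalize(128)                            # GroupNorm(32, 128, eps=1e-6), as before
    norm_grad_tts = Normalize(128, num_groups=8, eps=1e-5)   # Grad-TTS-style normalization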
......@@ -341,7 +341,7 @@ class ResBlock(TimestepBlock):
# unet.py
class ResnetBlock(nn.Module):
class OLD_ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
......@@ -383,11 +383,129 @@ class ResnetBlock(nn.Module):
return x + h
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.is_overwritten = False
self.overwrite_for_grad_tts = overwrite_for_grad_tts
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
time_emb_dim = temb_channels
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
# num_groups = 8
# self.pre_norm = False
# eps = 1e-5
# non_linearity = "mish"
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def forward(self, x, temb, mask=None):
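# Grad-TTS call sites pass the block (x, mask, temb), matching ResnetBlockGradTTS.forward below,
# so in the pre_norm=False (Grad-TTS) configuration the last two positional arguments arrive
# swapped and are restored to (temb, mask) here.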
if not self.pre_norm:
temp = mask
mask = temb
temb = temp
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
h = x
h = h * mask if mask is not None else h
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask if mask is not None else h
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h * mask if mask is not None else h
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask if mask is not None else h
x = x * mask if mask is not None else x
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
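A quick usage sketch (not from the commit; shapes and sizes are invented, and it assumes the classes in this file are importable) showing the two configurations the unified block is meant to cover: the default DDPM-style call and the Grad-TTS-style call that relies on the pre_norm=False argument swap above.

    import torch

    # default configuration (pre-norm, swish): called as block(x, temb)
    ddpm_block = ResnetBlock(in_channels=64, out_channels=128, temb_channels=512)
    x = torch.randn(2, 64, 32, 32)
    temb = torch.randn(2, 512)
    out = ddpm_block(x, temb)  # -> (2, 128, 32, 32)

    # Grad-TTS configuration: called as block(x, mask, temb); the swap restores the order internally
    tts_block = ResnetBlock(
        in_channels=64, out_channels=64, temb_channels=64,
        groups=8, pre_norm=False, eps=1e-5, non_linearity="mish",
    )
    mask = torch.ones(2, 1, 32, 32)
    temb_tts = torch.randn(2, 64)
    out_tts = tts_block(x, mask, temb_tts)  # -> (2, 64, 32, 32)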
# unet_grad_tts.py
class ResnetBlockGradTTS(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
def __init__(self, dim, dim_out, time_emb_dim, groups=8, eps=1e-6, overwrite=True, conv_shortcut=False, pre_norm=True):
super(ResnetBlockGradTTS, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
......@@ -396,45 +514,126 @@ class ResnetBlockGradTTS(torch.nn.Module):
else:
self.res_conv = torch.nn.Identity()
self.overwrite = overwrite
if self.overwrite:
in_channels = dim
out_channels = dim_out
temb_channels = time_emb_dim
# To set via init
self.pre_norm = False
eps = 1e-5
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
dropout = 0.0
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.nonlinearity = Mish()
self.is_overwritten = False
def set_weights(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
output = self.forward_2(x, time_emb, mask=mask)
return output
def forward_2(self, x, temb, mask=None):
if not self.is_overwritten:
self.set_weights()
self.is_overwritten = True
if mask is None:
mask = torch.ones_like(x)
h = x
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask
x = x * mask
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns: out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
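What this intermediate state saves: once set_weights() copies the Grad-TTS weights into the generic layers, forward_2 is meant to reproduce the original block1/block2 computation exactly. A hedged sanity-check sketch (not part of the commit; shapes and sizes are invented, and it assumes the classes in this file are importable):

    import torch

    torch.manual_seed(0)
    block = ResnetBlockGradTTS(dim=64, dim_out=64, time_emb_dim=64, groups=8)
    block.eval()

    x = torch.randn(2, 64, 32, 32)
    mask = torch.ones(2, 1, 32, 32)
    t = torch.randn(2, 64)

    with torch.no_grad():
        # original Grad-TTS path: Block -> time MLP -> Block -> residual conv
        h = block.block1(x, mask)
        h = h + block.mlp(t).unsqueeze(-1).unsqueeze(-1)
        h = block.block2(h, mask)
        reference = h + block.res_conv(x * mask)

        # generic path: copies the Grad-TTS weights on the first call, then runs the unified forward
        unified = block.forward_2(x, t, mask=mask)

    assert torch.allclose(reference, unified, atol=1e-6)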
# unet_score_estimation.py
......@@ -570,6 +769,39 @@ class ResnetBlockDDPMpp(nn.Module):
return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
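For the RL block, a minimal usage sketch (not in the diff; sizes are invented) of the shape contract stated in its docstring. It assumes the full module, including the Conv1dBlock and RearrangeDim helpers defined further down this file, is importable, and that Conv1dBlock preserves the horizon dimension.

    import torch

    block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=32, horizon=128)
    x = torch.randn(4, 14, 128)  # [batch_size, inp_channels, horizon]
    t = torch.randn(4, 32)       # [batch_size, embed_dim]
    out = block(x, t)            # expected [batch_size, out_channels, horizon] = (4, 32, 128)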
# HELPER Modules
......@@ -617,18 +849,6 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
......
......@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample
from .resnet import ResnetBlock as ResnetBlockNew
from .resnet import ResnetBlockGradTTS as ResnetBlock
from .resnet import Upsample
......@@ -81,13 +82,20 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.ups = torch.nn.ModuleList([])
num_resolutions = len(in_out)
# num_groups = 8
# self.pre_norm = False
# eps = 1e-5
# non_linearity = "mish"
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
torch.nn.ModuleList(
[
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
# ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
# ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
ResnetBlockNew(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlockNew(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
]
......@@ -95,16 +103,20 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
)
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
# self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
# self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
self.mid_block1 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
self.mid_block2 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(
torch.nn.ModuleList(
[
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
# ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
# ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
ResnetBlockNew(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
ResnetBlockNew(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in, use_conv_transpose=True),
]
......