"vscode:/vscode.git/clone" did not exist on "7e2d18877e81b0d35eee9f97ab96b7fa6d69f608"
Commit 0926dc24 authored by Patrick von Platen

save intermediate grad tts

parent 814133ec
[The diff for one file is collapsed and not shown.]
@@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs):
    raise ValueError(f"unsupported dimensions: {dims}")

-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, num_groups=32, eps=1e-6):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)

 def nonlinearity(x, swish=1.0):
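Note: a minimal sketch of what the parameterized Normalize enables (the channel counts below are illustrative, not from the diff):

    import torch

    def Normalize(in_channels, num_groups=32, eps=1e-6):
        return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)

    # The defaults reproduce the previously hard-coded DDPM behaviour ...
    norm_ddpm = Normalize(128)  # GroupNorm(32, 128, eps=1e-6)
    # ... while grad-TTS style blocks can now request their own settings.
    norm_grad_tts = Normalize(64, num_groups=8, eps=1e-5)
    assert norm_grad_tts.num_groups == 8 and norm_grad_tts.eps == 1e-5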
@@ -341,7 +341,7 @@ class ResBlock(TimestepBlock):

 # unet.py
-class ResnetBlock(nn.Module):
+class OLD_ResnetBlock(nn.Module):
     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
         super().__init__()
         self.in_channels = in_channels

@@ -383,11 +383,129 @@ class ResnetBlock(nn.Module):
         return x + h
+class ResnetBlock(nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
+        super().__init__()
+        self.pre_norm = pre_norm
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        if self.pre_norm:
+            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+        else:
+            self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+        if non_linearity == "swish":
+            self.nonlinearity = nonlinearity
+        elif non_linearity == "mish":
+            self.nonlinearity = Mish()
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+        self.is_overwritten = False
+        self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        if self.overwrite_for_grad_tts:
+            dim = in_channels
+            dim_out = out_channels
+            time_emb_dim = temb_channels
+
+            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+            self.pre_norm = pre_norm
+            self.block1 = Block(dim, dim_out, groups=groups)
+            self.block2 = Block(dim_out, dim_out, groups=groups)
+            if dim != dim_out:
+                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
+            else:
+                self.res_conv = torch.nn.Identity()
+
+        # num_groups = 8
+        # self.pre_norm = False
+        # eps = 1e-5
+        # non_linearity = "mish"
+    def set_weights_grad_tts(self):
+        self.conv1.weight.data = self.block1.block[0].weight.data
+        self.conv1.bias.data = self.block1.block[0].bias.data
+        self.norm1.weight.data = self.block1.block[1].weight.data
+        self.norm1.bias.data = self.block1.block[1].bias.data
+
+        self.conv2.weight.data = self.block2.block[0].weight.data
+        self.conv2.bias.data = self.block2.block[0].bias.data
+        self.norm2.weight.data = self.block2.block[1].weight.data
+        self.norm2.bias.data = self.block2.block[1].bias.data
+
+        self.temb_proj.weight.data = self.mlp[1].weight.data
+        self.temb_proj.bias.data = self.mlp[1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.res_conv.weight.data
+            self.nin_shortcut.bias.data = self.res_conv.bias.data
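For context, a minimal standalone sketch of what these `.data` assignments do: after the rebind, the new module's parameters view the same storage as the grad-TTS modules, so the values (including later in-place updates) stay in sync, while the Parameter objects and their gradients remain distinct. The module names below are illustrative, not from the diff:

    import torch

    src = torch.nn.Conv2d(4, 4, 3, padding=1)
    dst = torch.nn.Conv2d(4, 4, 3, padding=1)

    dst.weight.data = src.weight.data  # rebind: both parameters now share storage
    dst.bias.data = src.bias.data

    x = torch.randn(1, 4, 8, 8)
    assert torch.allclose(src(x), dst(x))

    # In-place updates on one side are visible on the other (shared storage).
    with torch.no_grad():
        src.weight.add_(1.0)
    assert torch.allclose(src(x), dst(x))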
+    def forward(self, x, temb, mask=None):
+        if not self.pre_norm:
+            temp = mask
+            mask = temb
+            temb = temp
+
+        if self.overwrite_for_grad_tts and not self.is_overwritten:
+            self.set_weights_grad_tts()
+            self.is_overwritten = True
+
+        h = x
+        h = h * mask if mask is not None else h
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+        h = h * mask if mask is not None else h
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+        h = h * mask if mask is not None else h
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+            h = self.dropout(h)
+
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+        h = h * mask if mask is not None else h
+
+        x = x * mask if mask is not None else x
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
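Two things worth noting about this forward. First, pre_norm toggles the operation order: pre_norm=True normalizes and activates before each conv (the DDPM convention), while pre_norm=False runs conv first and then norm and activation, matching the grad-TTS Block below. Second, the argument swap at the top appears to let grad-TTS callers keep their (x, mask, time_emb) calling convention, since pre_norm=False coincides with the grad-TTS configuration. A self-contained sketch of the two orderings (shapes illustrative):

    import torch

    conv = torch.nn.Conv2d(8, 8, 3, padding=1)
    norm = torch.nn.GroupNorm(8, 8)
    act = torch.nn.Mish()
    x = torch.randn(1, 8, 16, 16)

    h_pre = conv(act(norm(x)))   # pre_norm=True: normalize/activate, then conv
    h_post = act(norm(conv(x)))  # pre_norm=False: conv first, as in the grad-TTS Block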
 # unet_grad_tts.py
 class ResnetBlockGradTTS(torch.nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+    def __init__(self, dim, dim_out, time_emb_dim, groups=8, eps=1e-6, overwrite=True, conv_shortcut=False, pre_norm=True):
         super(ResnetBlockGradTTS, self).__init__()
         self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+        self.pre_norm = pre_norm
         self.block1 = Block(dim, dim_out, groups=groups)
         self.block2 = Block(dim_out, dim_out, groups=groups)

@@ -396,45 +514,126 @@ class ResnetBlockGradTTS(torch.nn.Module):
         else:
             self.res_conv = torch.nn.Identity()
+        self.overwrite = overwrite
+        if self.overwrite:
+            in_channels = dim
+            out_channels = dim_out
+            temb_channels = time_emb_dim
+
+            # To set via init
+            self.pre_norm = False
+            eps = 1e-5
+
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            dropout = 0.0
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+            if self.in_channels != self.out_channels:
+                if self.use_conv_shortcut:
+                    self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+                else:
+                    self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+            self.nonlinearity = Mish()
+
+        self.is_overwritten = False
+    def set_weights(self):
+        self.conv1.weight.data = self.block1.block[0].weight.data
+        self.conv1.bias.data = self.block1.block[0].bias.data
+        self.norm1.weight.data = self.block1.block[1].weight.data
+        self.norm1.bias.data = self.block1.block[1].bias.data
+
+        self.conv2.weight.data = self.block2.block[0].weight.data
+        self.conv2.bias.data = self.block2.block[0].bias.data
+        self.norm2.weight.data = self.block2.block[1].weight.data
+        self.norm2.bias.data = self.block2.block[1].bias.data
+
+        self.temb_proj.weight.data = self.mlp[1].weight.data
+        self.temb_proj.bias.data = self.mlp[1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.res_conv.weight.data
+            self.nin_shortcut.bias.data = self.res_conv.bias.data
     def forward(self, x, mask, time_emb):
         h = self.block1(x, mask)
         h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
         h = self.block2(h, mask)
         output = h + self.res_conv(x * mask)
+        output = self.forward_2(x, time_emb, mask=mask)
         return output
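This looks like the parity-check pattern the "save intermediate grad tts" commit message refers to: the original grad-TTS path is still executed, and its result is then discarded in favour of forward_2, which runs the ported implementation on the copied weights, so the two paths can be compared while the port is in flight. A generic standalone sketch of the idea (function handles are illustrative, not from the diff):

    import torch

    def check_parity(old_fn, new_fn, *inputs, atol=1e-6):
        # Run the reference and the ported implementation side by side during a refactor.
        out_old = old_fn(*inputs)
        out_new = new_fn(*inputs)
        assert torch.allclose(out_old, out_new, atol=atol), "refactored path diverged"
        return out_new

    # e.g. check_parity(block_old.forward, block_new.forward_2, x, temb)  # hypothetical handles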
+    def forward_2(self, x, temb, mask=None):
+        if not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        if mask is None:
+            mask = torch.ones_like(x)
+
+        h = x
+        h = h * mask
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+        h = h * mask
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+        h = h * mask
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+            h = self.dropout(h)
+
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+        h = h * mask
+
+        x = x * mask
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h

-# unet_rl.py
-class ResidualTemporalBlock(nn.Module):
-    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
-        super().__init__()
-
-        self.blocks = nn.ModuleList(
-            [
-                Conv1dBlock(inp_channels, out_channels, kernel_size),
-                Conv1dBlock(out_channels, out_channels, kernel_size),
-            ]
-        )
-
-        self.time_mlp = nn.Sequential(
-            nn.Mish(),
-            nn.Linear(embed_dim, out_channels),
-            RearrangeDim(),
-            # Rearrange("batch t -> batch t 1"),
-        )
-
-        self.residual_conv = (
-            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
-        )
-
-    def forward(self, x, t):
-        """
-        x : [ batch_size x inp_channels x horizon ]
-        t : [ batch_size x embed_dim ]
-        returns:
-        out : [ batch_size x out_channels x horizon ]
-        """
-        out = self.blocks[0](x) + self.time_mlp(t)
-        out = self.blocks[1](out)
-        return out + self.residual_conv(x)

+class Block(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super(Block, self).__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
+        )
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask
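A small usage sketch of the moved Block module (definition copied from the hunk above; the shapes and the mel-spectrogram reading are illustrative). Padded positions are zeroed both before the conv and on the output, so they stay zero through the stack:

    import torch

    class Mish(torch.nn.Module):
        def forward(self, x):
            return x * torch.tanh(torch.nn.functional.softplus(x))

    class Block(torch.nn.Module):
        def __init__(self, dim, dim_out, groups=8):
            super(Block, self).__init__()
            self.block = torch.nn.Sequential(
                torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
            )

        def forward(self, x, mask):
            output = self.block(x * mask)
            return output * mask

    x = torch.randn(2, 8, 80, 100)       # e.g. (batch, channels, mel bins, frames)
    mask = torch.ones(2, 1, 1, 100)      # per-frame validity mask, broadcast over channels/bins
    mask[:, :, :, 60:] = 0               # trailing frames are padding
    out = Block(8, 8)(x, mask)
    assert out[:, :, :, 60:].abs().sum() == 0  # padded frames are zeroed on the way out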
 # unet_score_estimation.py

@@ -570,6 +769,39 @@ class ResnetBlockDDPMpp(nn.Module):
         return (x + h) / np.sqrt(2.0)
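The /np.sqrt(2.0) here is the variance-preserving skip connection used in the score-based (DDPM++) blocks: if x and h are roughly independent with equal variance, x + h doubles the variance, and dividing by sqrt(2) restores it. A quick numerical illustration:

    import torch

    x = torch.randn(1_000_000)
    h = torch.randn(1_000_000)
    print((x + h).var())               # ~2.0
    print(((x + h) / 2 ** 0.5).var())  # ~1.0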
+# unet_rl.py
+class ResidualTemporalBlock(nn.Module):
+    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
+        super().__init__()
+
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )
+
+        self.time_mlp = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(embed_dim, out_channels),
+            RearrangeDim(),
+            # Rearrange("batch t -> batch t 1"),
+        )
+
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )
+
+    def forward(self, x, t):
+        """
+        x : [ batch_size x inp_channels x horizon ]
+        t : [ batch_size x embed_dim ]
+        returns:
+        out : [ batch_size x out_channels x horizon ]
+        """
+        out = self.blocks[0](x) + self.time_mlp(t)
+        out = self.blocks[1](out)
+        return out + self.residual_conv(x)
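The time-embedding broadcast in this block is worth spelling out: time_mlp maps t from [batch, embed_dim] to [batch, out_channels, 1] (that is what the commented Rearrange / RearrangeDim does), which then broadcasts over the horizon axis of the conv features. A minimal shape check:

    import torch

    batch, out_channels, horizon, embed_dim = 2, 8, 16, 32
    feat = torch.randn(batch, out_channels, horizon)  # output of the first Conv1dBlock
    temb = torch.randn(batch, embed_dim)

    proj = torch.nn.Linear(embed_dim, out_channels)
    t = proj(temb)[:, :, None]                        # "batch t -> batch t 1"
    assert (feat + t).shape == (batch, out_channels, horizon)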
 # HELPER Modules

@@ -617,18 +849,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))
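For reference, Mish(x) = x * tanh(softplus(x)); the helper module above matches PyTorch's built-in torch.nn.Mish up to floating-point tolerance:

    import torch

    x = torch.linspace(-3, 3, 7)
    mish_manual = x * torch.tanh(torch.nn.functional.softplus(x))
    assert torch.allclose(mish_manual, torch.nn.Mish()(x), atol=1e-6)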
-class Block(torch.nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super(Block, self).__init__()
-        self.block = torch.nn.Sequential(
-            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
-        )
-
-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
 class Conv1dBlock(nn.Module):
     """
     Conv1d --> GroupNorm --> Mish
...
@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample
+from .resnet import ResnetBlock as ResnetBlockNew
 from .resnet import ResnetBlockGradTTS as ResnetBlock
 from .resnet import Upsample

@@ -81,13 +82,20 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.ups = torch.nn.ModuleList([])
         num_resolutions = len(in_out)
+        # num_groups = 8
+        # self.pre_norm = False
+        # eps = 1e-5
+        # non_linearity = "mish"
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)
             self.downs.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                        # ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
+                        # ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                        ResnetBlockNew(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlockNew(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
                         Residual(Rezero(LinearAttention(dim_out))),
                         Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                     ]

@@ -95,16 +103,20 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )
         mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        # self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        # self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_block1 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
         self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_block2 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True)
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
                 torch.nn.ModuleList(
                     [
-                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                        # ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
+                        # ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                        ResnetBlockNew(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
+                        ResnetBlockNew(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True),
                         Residual(Rezero(LinearAttention(dim_in))),
                         Upsample(dim_in, use_conv_transpose=True),
                     ]
...
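To summarize the migration staged in these hunks, the old positional grad-TTS constructor maps onto the shared block roughly as follows (a sketch of the correspondence shown above, not an API guarantee):

    # old (unet_grad_tts.py style):
    #     ResnetBlock(dim_in, dim_out, time_emb_dim=dim)
    # new (shared resnet.py block, grad-TTS configuration):
    #     ResnetBlockNew(
    #         in_channels=dim_in,           # dim
    #         out_channels=dim_out,         # dim_out
    #         temb_channels=dim,            # time_emb_dim
    #         groups=8,                     # grad-TTS GroupNorm groups
    #         pre_norm=False,               # grad-TTS Block order: conv -> norm -> Mish
    #         eps=1e-5,                     # grad-TTS GroupNorm eps
    #         non_linearity="mish",
    #         overwrite_for_grad_tts=True,  # copy weights from block1/block2/mlp on first forward
    #     )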