Commit 7e0fd19f authored by anton-l

Merge remote-tracking branch 'origin/main'

parents 21aac1ac b65eb377
@@ -162,7 +162,7 @@ class Downsample(nn.Module):
 # RESNETS
-# unet_glide.py & unet_ldm.py
+# unet_glide.py
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
+        overwrite=False,  # TODO(Patrick) - use for glide at later stage
     ):
         super().__init__()
         self.channels = channels
@@ -236,6 +237,65 @@ class ResBlock(TimestepBlock):
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
+        self.overwrite = overwrite
+        self.is_overwritten = False
+        if self.overwrite:
+            in_channels = channels
+            out_channels = self.out_channels
+            conv_shortcut = False
+            dropout = 0.0
+            temb_channels = emb_channels
+            groups = 32
+            pre_norm = True
+            eps = 1e-5
+            non_linearity = "silu"
+
+            self.pre_norm = pre_norm
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+            if non_linearity == "swish":
+                self.nonlinearity = nonlinearity
+            elif non_linearity == "mish":
+                self.nonlinearity = Mish()
+            elif non_linearity == "silu":
+                self.nonlinearity = nn.SiLU()
+
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_weights(self):
+        # TODO(Patrick): use for glide at later stage
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
     def forward(self, x, emb):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
@@ -243,6 +303,10 @@ class ResBlock(TimestepBlock):
         :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        if self.overwrite:
+            # TODO(Patrick): use for glide at later stage
+            self.set_weights()
+
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -251,6 +315,7 @@ class ResBlock(TimestepBlock):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
+
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ class ResBlock(TimestepBlock):
         else:
             h = h + emb_out
             h = self.out_layers(h)
-        return self.skip_connection(x) + h
+
+        result = self.skip_connection(x) + h
+
+        # TODO(Patrick) Use for glide at later stage
+        # result = self.forward_2(x, emb)
+        return result
+
+    def forward_2(self, x, temb, mask=1.0):
+        if self.overwrite and not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        h = x
+
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
 
 
 # unet.py and unet_grad_tts.py
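Note on the pattern above: the `overwrite` path registers a second, explicitly named set of modules (`norm1`, `conv1`, `temb_proj`, `norm2`, `conv2`, `nin_shortcut`) next to the existing `in_layers` / `emb_layers` / `out_layers` / `skip_connection`, and `set_weights()` copies the parameter tensors across so both parameterizations compute the same function. Below is a minimal, self-contained sketch of that weight-copy idea using plain torch modules rather than this file's helpers (`Normalize`, `conv_nd`); the names and shapes are illustrative only, not the library's API.

import torch
import torch.nn.functional as F
from torch import nn

channels, out_channels = 32, 64

# "old" parameterization: a Sequential like ResBlock.in_layers (norm -> act -> conv)
in_layers = nn.Sequential(
    nn.GroupNorm(32, channels),
    nn.SiLU(),
    nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
)

# "new" parameterization: separately named modules, as the overwrite branch builds
norm1 = nn.GroupNorm(32, channels)
conv1 = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

# the set_weights() pattern: copy the .data tensors from the source modules
norm1.weight.data = in_layers[0].weight.data
norm1.bias.data = in_layers[0].bias.data
conv1.weight.data = in_layers[-1].weight.data
conv1.bias.data = in_layers[-1].bias.data

# both paths now agree on the same input
x = torch.randn(2, channels, 8, 8)
assert torch.allclose(in_layers(x), conv1(F.silu(norm1(x))), atol=1e-6)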
@@ -280,6 +388,7 @@ class ResnetBlock(nn.Module):
         eps=1e-6,
         non_linearity="swish",
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ class ResnetBlock(nn.Module):
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 
         self.is_overwritten = False
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm
+
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +437,39 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+            # eps = 1e-5
+            # non_linearity = "silu"
+            # overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+            # elif use_conv:
+            #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
@@ -343,13 +489,36 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data
 
-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True
 
         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
@@ -359,11 +528,11 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-        h = h * mask if mask is not None else h
+        h = h * mask
 
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
@@ -374,9 +543,9 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask
 
-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
...
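For readers following the `ResnetBlock` changes above: in the pre-norm configuration, and with the mask left at its new default of `1.0`, the forward pass amounts to norm -> SiLU -> conv, add the projected timestep embedding, norm -> SiLU -> dropout -> conv, then add a (possibly 1x1-projected) skip connection. The class below is a compact restatement of that computation in plain torch for orientation; it is a sketch, not the repository's implementation, and all names in it are illustrative.

import torch
from torch import nn

class TinyResnetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, temb_ch, groups=32, eps=1e-5, dropout=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch, eps=eps)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.temb_proj = nn.Linear(temb_ch, out_ch)
        self.norm2 = nn.GroupNorm(groups, out_ch, eps=eps)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.act = nn.SiLU()
        # 1x1 projection on the skip path only when the channel count changes
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, temb):
        h = self.conv1(self.act(self.norm1(x)))
        # broadcast the projected timestep embedding over the spatial dims
        h = h + self.temb_proj(self.act(temb))[:, :, None, None]
        h = self.conv2(self.dropout(self.act(self.norm2(h))))
        return self.skip(x) + h

x = torch.randn(2, 32, 16, 16)
temb = torch.randn(2, 128)
out = TinyResnetBlock(32, 64, 128)(x, temb)
assert out.shape == (2, 64, 16, 16)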
@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
+
+# from .resnet import ResBlock
 
 
 def exists(val):
@@ -364,7 +367,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
@@ -559,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch,
                         out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                     )
                 ]
                 ch = mult * model_channels
@@ -599,16 +602,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -629,13 +633,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -646,13 +651,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             )
             if not use_spatial_transformer
            else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
@@ -662,15 +668,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch + ich,
                         out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = model_channels * mult
                 if ds in attention_resolutions:
@@ -698,16 +704,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     up=True,
+                        # )
+                        None
                        if resblock_updown
                         else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                     )
@@ -842,15 +849,15 @@ class EncoderUNetModel(nn.Module):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = mult * model_channels
                 if ds in attention_resolutions:
@@ -870,16 +877,17 @@ class EncoderUNetModel(nn.Module):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+                        # ResBlock(
+                        #     ch,
+                        #     time_embed_dim,
+                        #     dropout,
+                        #     out_channels=out_ch,
+                        #     dims=dims,
+                        #     use_checkpoint=use_checkpoint,
+                        #     use_scale_shift_norm=use_scale_shift_norm,
+                        #     down=True,
+                        # )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -892,13 +900,14 @@ class EncoderUNetModel(nn.Module):
         self._feature_size += ch
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -907,13 +916,14 @@ class EncoderUNetModel(nn.Module):
                 num_head_channels=num_head_channels,
                 use_new_attention_order=use_new_attention_order,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
...
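One structural point worth spelling out: TimestepEmbedSequential only passes the timestep embedding to layers it recognizes, so swapping ResBlock for ResnetBlock also requires the isinstance check above to accept ResnetBlock; otherwise the new blocks would be called as layer(x) and fail. A small stand-alone sketch of that dispatch pattern follows (class names here are stand-ins, not the repository's):

import torch
from torch import nn

class EmbBlock(nn.Module):
    # stand-in for a block that consumes the timestep embedding
    def __init__(self, ch, emb_ch):
        super().__init__()
        self.proj = nn.Linear(emb_ch, ch)

    def forward(self, x, emb):
        return x + self.proj(emb)[:, :, None, None]

class EmbSequential(nn.Sequential):
    # stand-in for TimestepEmbedSequential: route (x, emb) only to blocks that take it
    def forward(self, x, emb):
        for layer in self:
            x = layer(x, emb) if isinstance(layer, EmbBlock) else layer(x)
        return x

block = EmbSequential(EmbBlock(8, 4), nn.Conv2d(8, 8, 3, padding=1))
out = block(torch.randn(2, 8, 16, 16), torch.randn(2, 4))
assert out.shape == (2, 8, 16, 16)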