"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "cffa4e001a08054eb67c5550367fb03e771fed33"
Commit efe1e60e authored by Patrick von Platen

merge glide into resnets

parent fd6f93b2
@@ -161,229 +161,7 @@ class Downsample(nn.Module):
 # RESNETS
-# unet_glide.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
-class ResBlock(TimestepBlock):
-    """
-    A residual block that can optionally change the number of channels.
-
-    :param channels: the number of input channels.
-    :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout.
-    :param out_channels: if specified, the number of out channels.
-    :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
-        convolution to change the channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
-    :param up: if True, use this block for upsampling.
-    :param down: if True, use this block for downsampling.
-    """
-
-    def __init__(
-        self,
-        channels,
-        emb_channels,
-        dropout,
-        out_channels=None,
-        use_conv=False,
-        use_scale_shift_norm=False,
-        dims=2,
-        use_checkpoint=False,
-        up=False,
-        down=False,
-        overwrite=True,  # TODO(Patrick) - use for glide at later stage
-    ):
-        super().__init__()
-        self.channels = channels
-        self.emb_channels = emb_channels
-        self.dropout = dropout
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
-        self.use_scale_shift_norm = use_scale_shift_norm
-
-        self.in_layers = nn.Sequential(
-            normalization(channels, swish=1.0),
-            nn.Identity(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
-        )
-
-        self.updown = up or down
-
-        if up:
-            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
-            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
-        elif down:
-            self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-            self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
-        else:
-            self.h_upd = self.x_upd = nn.Identity()
-
-        self.emb_layers = nn.Sequential(
-            nn.SiLU(),
-            linear(
-                emb_channels,
-                2 * self.out_channels,
-            ),
-        )
-        self.out_layers = nn.Sequential(
-            normalization(self.out_channels, swish=0.0),
-            nn.SiLU(),
-            nn.Dropout(p=dropout),
-            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
-        )
-
-        if self.out_channels == channels:
-            self.skip_connection = nn.Identity()
-        elif use_conv:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
-        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
-        self.overwrite = overwrite
-        self.is_overwritten = False
-        if self.overwrite:
-            in_channels = channels
-            out_channels = self.out_channels
-            conv_shortcut = False
-            dropout = 0.0
-            temb_channels = emb_channels
-            groups = 32
-            pre_norm = True
-            eps = 1e-5
-            non_linearity = "silu"
-
-            self.pre_norm = pre_norm
-            self.in_channels = in_channels
-            out_channels = in_channels if out_channels is None else out_channels
-            self.out_channels = out_channels
-            self.use_conv_shortcut = conv_shortcut
-
-            # Add to init
-            self.time_embedding_norm = "scale_shift"
-
-            if self.pre_norm:
-                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
-            else:
-                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
-
-            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
-            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
-            self.dropout = torch.nn.Dropout(dropout)
-            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
-            if non_linearity == "swish":
-                self.nonlinearity = nonlinearity
-            elif non_linearity == "mish":
-                self.nonlinearity = Mish()
-            elif non_linearity == "silu":
-                self.nonlinearity = nn.SiLU()
-
-            if self.in_channels != self.out_channels:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-            self.up, self.down = up, down
-            # if self.up:
-            #     self.h_upd = Upsample(in_channels, use_conv=False, dims=dims)
-            #     self.x_upd = Upsample(in_channels, use_conv=False, dims=dims)
-            # elif self.down:
-            #     self.h_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
-            #     self.x_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op")
-
-    def set_weights(self):
-        # TODO(Patrick): use for glide at later stage
-        self.norm1.weight.data = self.in_layers[0].weight.data
-        self.norm1.bias.data = self.in_layers[0].bias.data
-
-        self.conv1.weight.data = self.in_layers[-1].weight.data
-        self.conv1.bias.data = self.in_layers[-1].bias.data
-
-        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
-        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
-
-        self.norm2.weight.data = self.out_layers[0].weight.data
-        self.norm2.bias.data = self.out_layers[0].bias.data
-
-        self.conv2.weight.data = self.out_layers[-1].weight.data
-        self.conv2.bias.data = self.out_layers[-1].bias.data
-
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut.weight.data = self.skip_connection.weight.data
-            self.nin_shortcut.bias.data = self.skip_connection.bias.data
-
-    def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        if self.overwrite:
-            # TODO(Patrick): use for glide at later stage
-            self.set_weights()
-
-        orig_x = x
-        if self.updown:
-            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
-            h = self.h_upd(h)
-            x = self.x_upd(x)
-            h = in_conv(h)
-        else:
-            h = self.in_layers(x)
-
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
-            emb_out = emb_out[..., None]
-
-        if self.use_scale_shift_norm:
-            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
-        else:
-            h = h + emb_out
-            h = self.out_layers(h)
-
-        result = self.skip_connection(x) + h
-
-        # TODO(Patrick) Use for glide at later stage
-        result = self.forward_2(orig_x, emb)
-        return result
-
-    def forward_2(self, x, temb):
-        if self.overwrite and not self.is_overwritten:
-            self.set_weights()
-            self.is_overwritten = True
-
-        h = x
-        h = self.norm1(h)
-        h = self.nonlinearity(h)
-
-        if self.up or self.down:
-            x = self.x_upd(x)
-            h = self.h_upd(h)
-
-        h = self.conv1(h)
-
-        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-        if self.time_embedding_norm == "scale_shift":
-            scale, shift = torch.chunk(temb, 2, dim=1)
-            h = self.norm2(h)
-            h = h + h * scale + shift
-            h = self.nonlinearity(h)
-        else:
-            h = h + temb
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x + h
-
-# unet.py, unet_grad_tts.py, unet_ldm.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
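For context on the merge: both the deleted ResBlock and the surviving ResnetBlock condition hidden states on the timestep embedding with a scale-shift transform. ResBlock.forward writes it as norm(h) * (1 + scale) + shift, while forward_2 (and ResnetBlock's "scale_shift" mode) writes norm(h) + norm(h) * scale + shift. The two forms are algebraically identical, which is what makes the consolidation safe. A minimal standalone sketch, assuming arbitrary shapes (plain PyTorch, not the diffusers API):

import torch
import torch.nn as nn

batch, channels = 2, 32
h = torch.randn(batch, channels, 8, 8)     # hidden states
temb = torch.randn(batch, 128)             # timestep embedding

norm = nn.GroupNorm(num_groups=32, num_channels=channels)
proj = nn.Linear(128, 2 * channels)        # projects temb to (scale, shift)

emb = proj(nn.SiLU()(temb))[:, :, None, None]
scale, shift = torch.chunk(emb, 2, dim=1)

# ResBlock.forward (use_scale_shift_norm) formulation:
out_a = norm(h) * (1 + scale) + shift

# ResBlock.forward_2 / ResnetBlock "scale_shift" formulation:
n = norm(h)
out_b = n + n * scale + shift

assert torch.allclose(out_a, out_b, atol=1e-6)  # same conditioning, two spellings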
@@ -445,12 +223,9 @@ class ResnetBlock(nn.Module):
             self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

         if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                # TODO(Patrick) - this branch is never used I think => can be deleted!
-                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

-        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
         self.is_overwritten = False
         self.overwrite_for_glide = overwrite_for_glide
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
@@ -497,8 +272,6 @@ class ResnetBlock(nn.Module):
            )

            if self.out_channels == in_channels:
                self.skip_connection = nn.Identity()
-            # elif use_conv:
-            #     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@@ -541,6 +314,8 @@ class ResnetBlock(nn.Module):
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

     def forward(self, x, temb, mask=1.0):
+        # TODO(Patrick) eventually this class should be split into multiple classes
+        # too many if else statements
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
@@ -566,6 +341,7 @@ class ResnetBlock(nn.Module):
         h = h * mask

         temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
         if self.time_embedding_norm == "scale_shift":
             scale, shift = torch.chunk(temb, 2, dim=1)
@@ -589,9 +365,6 @@ class ResnetBlock(nn.Module):
         x = x * mask

         if self.in_channels != self.out_channels:
-            # if self.use_conv_shortcut:
-            #     x = self.conv_shortcut(x)
-            # else:
             x = self.nin_shortcut(x)

         return x + h
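A note on the mask parameter visible in forward(x, temb, mask=1.0): the scalar default is a no-op, while a broadcastable tensor (the Grad-TTS style of sequence masking, presumably) zeroes padded positions after each convolution. A standalone sketch, with shapes chosen for illustration only:

import torch

x = torch.randn(2, 16, 8, 8)
mask = torch.ones(2, 1, 8, 8)
mask[:, :, :, 4:] = 0.0  # pretend the right half of each row is padding

assert torch.equal(x * 1.0, x)  # the default scalar mask leaves x unchanged
assert torch.equal((x * mask)[..., 4:], torch.zeros(2, 16, 8, 4))  # padding zeroed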
@@ -605,10 +378,6 @@ class Block(torch.nn.Module):
             torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
         )

-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
-
 # unet_score_estimation.py
 class ResnetBlockBigGANpp(nn.Module):
...
@@ -6,8 +6,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
-from .resnet import ResnetBlock
+from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample

 def convert_module_to_f16(l):
@@ -191,15 +190,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    # ResBlock(
-                    #     ch,
-                    #     time_embed_dim,
-                    #     dropout,
-                    #     out_channels=int(mult * model_channels),
-                    #     dims=dims,
-                    #     use_checkpoint=use_checkpoint,
-                    #     use_scale_shift_norm=use_scale_shift_norm,
-                    # )
                     ResnetBlock(
                         in_channels=ch,
                         out_channels=mult * model_channels,
@@ -207,7 +197,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         temb_channels=time_embed_dim,
                         eps=1e-5,
                         non_linearity="silu",
-                        time_embedding_norm="scale_shift",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                         overwrite_for_glide=True,
                     )
                 ]
@@ -229,16 +219,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        # ResBlock(
-                        #     ch,
-                        #     time_embed_dim,
-                        #     dropout,
-                        #     out_channels=out_ch,
-                        #     dims=dims,
-                        #     use_checkpoint=use_checkpoint,
-                        #     use_scale_shift_norm=use_scale_shift_norm,
-                        #     down=True,
-                        # )
                         ResnetBlock(
                             in_channels=ch,
                             out_channels=out_ch,
@@ -246,9 +226,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             temb_channels=time_embed_dim,
                             eps=1e-5,
                             non_linearity="silu",
-                            time_embedding_norm="scale_shift",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                             overwrite_for_glide=True,
-                            down=True
+                            down=True,
                         )
                         if resblock_updown
                         else Downsample(
@@ -262,21 +242,13 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            # ResBlock(
-            #     ch,
-            #     time_embed_dim,
-            #     dropout,
-            #     dims=dims,
-            #     use_checkpoint=use_checkpoint,
-            #     use_scale_shift_norm=use_scale_shift_norm,
-            # ),
             ResnetBlock(
                 in_channels=ch,
                 dropout=dropout,
                 temb_channels=time_embed_dim,
                 eps=1e-5,
                 non_linearity="silu",
-                time_embedding_norm="scale_shift",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                 overwrite_for_glide=True,
             ),
             AttentionBlock(
@@ -286,23 +258,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 num_head_channels=num_head_channels,
                 encoder_channels=transformer_dim,
             ),
-            # ResBlock(
-            #     ch,
-            #     time_embed_dim,
-            #     dropout,
-            #     dims=dims,
-            #     use_checkpoint=use_checkpoint,
-            #     use_scale_shift_norm=use_scale_shift_norm,
-            # ),
             ResnetBlock(
                 in_channels=ch,
                 dropout=dropout,
                 temb_channels=time_embed_dim,
                 eps=1e-5,
                 non_linearity="silu",
-                time_embedding_norm="scale_shift",
+                time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                 overwrite_for_glide=True,
-            )
+            ),
         )
         self._feature_size += ch
@@ -311,15 +275,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    # ResBlock(
-                    #     ch + ich,
-                    #     time_embed_dim,
-                    #     dropout,
-                    #     out_channels=int(model_channels * mult),
-                    #     dims=dims,
-                    #     use_checkpoint=use_checkpoint,
-                    #     use_scale_shift_norm=use_scale_shift_norm,
-                    # )
                     ResnetBlock(
                         in_channels=ch + ich,
                         out_channels=model_channels * mult,
@@ -327,7 +282,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                         temb_channels=time_embed_dim,
                         eps=1e-5,
                         non_linearity="silu",
-                        time_embedding_norm="scale_shift",
+                        time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                         overwrite_for_glide=True,
                     ),
                 ]
@@ -345,16 +300,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        # ResBlock(
-                        #     ch,
-                        #     time_embed_dim,
-                        #     dropout,
-                        #     out_channels=out_ch,
-                        #     dims=dims,
-                        #     use_checkpoint=use_checkpoint,
-                        #     use_scale_shift_norm=use_scale_shift_norm,
-                        #     up=True,
-                        # )
                         ResnetBlock(
                             in_channels=ch,
                             out_channels=out_ch,
@@ -362,7 +307,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             temb_channels=time_embed_dim,
                             eps=1e-5,
                             non_linearity="silu",
-                            time_embedding_norm="scale_shift",
+                            time_embedding_norm="scale_shift" if use_scale_shift_norm else "default",
                             overwrite_for_glide=True,
                             up=True,
                         )
...
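Every GlideUNetModel call site in this file changes the same way: the boolean use_scale_shift_norm flag from the GLIDE codebase is translated into the string-valued time_embedding_norm mode of the merged ResnetBlock. A standalone sketch of that mapping (the helper name is hypothetical; only the two mode strings come from the diff):

def resolve_time_embedding_norm(use_scale_shift_norm: bool) -> str:
    # GLIDE's boolean flag maps onto the merged block's string-valued mode:
    #   True  -> "scale_shift" (AdaGN-style scale/shift conditioning)
    #   False -> "default"     (plain additive timestep embedding)
    return "scale_shift" if use_scale_shift_norm else "default"

assert resolve_time_embedding_norm(True) == "scale_shift"
assert resolve_time_embedding_norm(False) == "default"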
@@ -795,7 +795,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         sizes = (32, 32)

         noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [9.]).to(torch_device)
+        time_step = torch.tensor(batch_size * [9.0]).to(torch_device)

         with torch.no_grad():
             output = model(noise, time_step)
...
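The test change is purely cosmetic (9. becomes 9.0), but the idiom is easy to misread: multiplying a one-element list by batch_size repeats the timestep once per batch element before it is turned into a tensor. A standalone sketch:

import torch

batch_size = 4
# list repetition builds one (identical) timestep per batch element
time_step = torch.tensor(batch_size * [9.0])  # tensor([9., 9., 9., 9.])
assert time_step.shape == (batch_size,)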