Commit f15ab901 authored by Patrick von Platen

fix comments

parent d1f2e3e4
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
...@@ -96,16 +94,6 @@ def zero_module(module):
return module
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument.
# """
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings.
# """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
...@@ -122,101 +110,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
# class ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels.
#
# :param channels: the number of input channels.
# :param emb_channels: the number of timestep embedding channels.
# :param dropout: the rate of dropout.
# :param out_channels: if specified, the number of out channels.
# :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
# :param dims: determines if the signal is 1D, 2D, or 3D.
# :param use_checkpoint: if True, use gradient checkpointing on this module.
# :param up: if True, use this block for upsampling.
# :param down: if True, use this block for downsampling.
# """
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels, swish=1.0),
# nn.Identity(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
# nn.SiLU() if use_scale_shift_norm else nn.Identity(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# """
# Apply the block to a Tensor, conditioned on a timestep embedding.
#
# :param x: an [N x C x ...] Tensor of features.
# :param emb: an [N x emb_channels] Tensor of timestep embeddings.
# :return: an [N x C x ...] Tensor of outputs.
# """
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
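Note on the block above: the core of its forward pass is the `use_scale_shift_norm` branch, a FiLM-style conditioning step in which the timestep embedding is projected to a per-channel scale and shift that modulate the normalized features. A minimal, standalone sketch of just that step; the module name `ScaleShiftNorm` and the fixed group count of 32 are illustrative assumptions, not part of this repository:

import torch
import torch.nn as nn

class ScaleShiftNorm(nn.Module):
    # Hypothetical helper: reproduces only the scale/shift conditioning step.
    def __init__(self, channels, emb_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.proj = nn.Linear(emb_channels, 2 * channels)  # one scale and one shift per channel

    def forward(self, h, emb):
        scale, shift = torch.chunk(self.proj(emb), 2, dim=1)
        # Broadcast [N, C] over the spatial dimensions of h ([N, C, H, W]).
        return self.norm(h) * (1 + scale[..., None, None]) + shift[..., None, None]

h = torch.randn(2, 64, 8, 8)
emb = torch.randn(2, 128)
print(ScaleShiftNorm(64, 128)(h, emb).shape)  # torch.Size([2, 64, 8, 8])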
class GlideUNetModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.
...
...@@ -36,26 +36,6 @@ class Block(torch.nn.Module):
return output * mask
# class ResnetBlock(torch.nn.Module):
# def __init__(self, dim, dim_out, time_emb_dim, groups=8):
# super(ResnetBlock, self).__init__()
# self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
#
# self.block1 = Block(dim, dim_out, groups=groups)
# self.block2 = Block(dim_out, dim_out, groups=groups)
# if dim != dim_out:
# self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
# else:
# self.res_conv = torch.nn.Identity()
#
# def forward(self, x, mask, time_emb):
# h = self.block1(x, mask)
# h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
# h = self.block2(h, mask)
# output = h + self.res_conv(x * mask)
# return output
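The removed block's skip path uses the standard channel-matching trick: a 1x1 convolution when the input and output widths differ, identity otherwise. A short hedged sketch of that pattern on its own; the shapes are arbitrary and `block2`'s output is faked with random data:

import torch
import torch.nn as nn

dim, dim_out = 32, 64
res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

x = torch.randn(2, dim, 16, 16)      # block input
h = torch.randn(2, dim_out, 16, 16)  # stands in for the output of block2
print((h + res_conv(x)).shape)       # torch.Size([2, 64, 16, 16])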
class Residual(torch.nn.Module):
def __init__(self, fn):
super(Residual, self).__init__()
...
import math
from abc import abstractmethod
from inspect import isfunction
import numpy as np
...@@ -328,7 +327,6 @@ def normalization(channels, swish=0.0):
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
...@@ -359,16 +357,6 @@ class AttentionPool2d(nn.Module):
return x[:, :, 0]
# class TimestepBlock(nn.Module):
# """
# Any module where forward() takes timestep embeddings as a second argument.
# """
#
# @abstractmethod
# def forward(self, x, emb):
# """
# Apply the module to `x` given `emb` timestep embeddings.
# """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
...@@ -385,99 +373,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
# class A_ResBlock(TimestepBlock):
# """
# A residual block that can optionally change the number of channels.
#
# :param channels: the number of input channels.
# :param emb_channels: the number of timestep embedding channels.
# :param dropout: the rate of dropout.
# :param out_channels: if specified, the number of out channels.
# :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
# :param dims: determines if the signal is 1D, 2D, or 3D.
# :param use_checkpoint: if True, use gradient checkpointing on this module.
# :param up: if True, use this block for upsampling.
# :param down: if True, use this block for downsampling.
# """
#
# def __init__(
# self,
# channels,
# emb_channels,
# dropout,
# out_channels=None,
# use_conv=False,
# use_scale_shift_norm=False,
# dims=2,
# use_checkpoint=False,
# up=False,
# down=False,
# ):
# super().__init__()
# self.channels = channels
# self.emb_channels = emb_channels
# self.dropout = dropout
# self.out_channels = out_channels or channels
# self.use_conv = use_conv
# self.use_checkpoint = use_checkpoint
# self.use_scale_shift_norm = use_scale_shift_norm
#
# self.in_layers = nn.Sequential(
# normalization(channels),
# nn.SiLU(),
# conv_nd(dims, channels, self.out_channels, 3, padding=1),
# )
#
# self.updown = up or down
#
# if up:
# self.h_upd = Upsample(channels, use_conv=False, dims=dims)
# self.x_upd = Upsample(channels, use_conv=False, dims=dims)
# elif down:
# self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op")
# else:
# self.h_upd = self.x_upd = nn.Identity()
#
# self.emb_layers = nn.Sequential(
# nn.SiLU(),
# linear(
# emb_channels,
# 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
# ),
# )
# self.out_layers = nn.Sequential(
# normalization(self.out_channels),
# nn.SiLU(),
# nn.Dropout(p=dropout),
# zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
# )
#
# if self.out_channels == channels:
# self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
# else:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
#
# def forward(self, x, emb):
# if self.updown:
# in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
# h = in_rest(x)
# h = self.h_upd(h)
# x = self.x_upd(x)
# h = in_conv(h)
# else:
# h = self.in_layers(x)
# emb_out = self.emb_layers(emb).type(h.dtype)
# while len(emb_out.shape) < len(h.shape):
# emb_out = emb_out[..., None]
# if self.use_scale_shift_norm:
# out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
# scale, shift = torch.chunk(emb_out, 2, dim=1)
# h = out_norm(h) * (1 + scale) + shift
# h = out_rest(h)
# else:
# h = h + emb_out
# h = self.out_layers(h)
# return self.skip_connection(x) + h
#
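One detail in the up/down branch above is worth calling out: `nn.Sequential` supports slicing, which is how the block applies normalization and activation before resampling while deferring the final convolution until afterwards. A hedged sketch of that pattern in isolation; the interpolate call is a stand-in for the repository's Upsample module, and the layer sizes are arbitrary:

import torch
import torch.nn as nn
import torch.nn.functional as F

in_layers = nn.Sequential(nn.GroupNorm(8, 16), nn.SiLU(), nn.Conv2d(16, 32, 3, padding=1))
in_rest, in_conv = in_layers[:-1], in_layers[-1]  # slicing a Sequential yields a Sequential

x = torch.randn(1, 16, 8, 8)
h = in_rest(x)                        # norm + activation only
h = F.interpolate(h, scale_factor=2)  # stand-in for self.h_upd (nearest-neighbour upsampling)
print(in_conv(h).shape)               # torch.Size([1, 32, 16, 16])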
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
...
...@@ -73,37 +73,6 @@ class Conv1dBlock(nn.Module):
return self.block(x)
# class ResidualTemporalBlock(nn.Module):
# def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
# super().__init__()
#
# self.blocks = nn.ModuleList(
# [
# Conv1dBlock(inp_channels, out_channels, kernel_size),
# Conv1dBlock(out_channels, out_channels, kernel_size),
# ]
# )
#
# self.time_mlp = nn.Sequential(
# nn.Mish(),
# nn.Linear(embed_dim, out_channels),
# RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
# )
#
# self.residual_conv = (
# nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
# )
#
# def forward(self, x, t):
# """
# x : [ batch_size x inp_channels x horizon ]
# t : [ batch_size x embed_dim ]
# returns:
# out : [ batch_size x out_channels x horizon ]
# """
# out = self.blocks[0](x) + self.time_mlp(t)
# out = self.blocks[1](out)
# return out + self.residual_conv(x)
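The `Rearrange("batch t -> batch t 1")` step above only appends a trailing length-1 axis so the projected time features broadcast across the horizon dimension of the 1D convolution output; the same thing in plain PyTorch, as a dependency-free, hedged sketch with arbitrary sizes:

import torch

t_feat = torch.randn(8, 32)   # [batch, out_channels] from the time MLP
t_feat = t_feat[:, :, None]   # equivalent of Rearrange("batch t -> batch t 1")
h = torch.randn(8, 32, 16)    # [batch, out_channels, horizon] from the first Conv1dBlock
print((h + t_feat).shape)     # torch.Size([8, 32, 16])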
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__(
self,
...
...@@ -491,137 +491,6 @@ class Downsample(nn.Module):
return x
# class ResnetBlockDDPMpp(nn.Module):
# """ResBlock adapted from DDPM."""
#
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# conv_shortcut=False,
# dropout=0.1,
# skip_rescale=False,
# init_scale=0.0,
# ):
# super().__init__()
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
# nn.init.zeros_(self.Dense_0.bias)
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch:
# if conv_shortcut:
# self.Conv_2 = conv3x3(in_ch, out_ch)
# else:
# self.NIN_0 = NIN(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.out_ch = out_ch
# self.conv_shortcut = conv_shortcut
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
# h = self.Conv_0(h)
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
# if x.shape[1] != self.out_ch:
# if self.conv_shortcut:
# x = self.Conv_2(x)
# else:
# x = self.NIN_0(x)
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
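The `skip_rescale` flag above divides the residual sum by sqrt(2): for two roughly independent, unit-variance branches this keeps the output variance near 1 instead of letting it double at every block. A quick hedged numerical check on random data (not the model's actual activations):

import torch

x, h = torch.randn(100_000), torch.randn(100_000)  # two independent unit-variance branches
print((x + h).var())               # ~2.0: a plain residual sum doubles the variance
print(((x + h) / 2 ** 0.5).var())  # ~1.0: the rescaled sum stays close to unit variance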
# class ResnetBlockBigGANpp(nn.Module):
# def __init__(
# self,
# act,
# in_ch,
# out_ch=None,
# temb_dim=None,
# up=False,
# down=False,
# dropout=0.1,
# fir=False,
# fir_kernel=(1, 3, 3, 1),
# skip_rescale=True,
# init_scale=0.0,
# ):
# super().__init__()
#
# out_ch = out_ch if out_ch else in_ch
# self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
# self.up = up
# self.down = down
# self.fir = fir
# self.fir_kernel = fir_kernel
#
# self.Conv_0 = conv3x3(in_ch, out_ch)
# if temb_dim is not None:
# self.Dense_0 = nn.Linear(temb_dim, out_ch)
# self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
# nn.init.zeros_(self.Dense_0.bias)
#
# self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
# self.Dropout_0 = nn.Dropout(dropout)
# self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
# if in_ch != out_ch or up or down:
# self.Conv_2 = conv1x1(in_ch, out_ch)
#
# self.skip_rescale = skip_rescale
# self.act = act
# self.in_ch = in_ch
# self.out_ch = out_ch
#
# def forward(self, x, temb=None):
# h = self.act(self.GroupNorm_0(x))
#
# if self.up:
# if self.fir:
# h = upsample_2d(h, self.fir_kernel, factor=2)
# x = upsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_upsample_2d(h, factor=2)
# x = naive_upsample_2d(x, factor=2)
# elif self.down:
# if self.fir:
# h = downsample_2d(h, self.fir_kernel, factor=2)
# x = downsample_2d(x, self.fir_kernel, factor=2)
# else:
# h = naive_downsample_2d(h, factor=2)
# x = naive_downsample_2d(x, factor=2)
#
# h = self.Conv_0(h)
# # Add bias to each feature map conditioned on the time embedding
# if temb is not None:
# h += self.Dense_0(self.act(temb))[:, :, None, None]
# h = self.act(self.GroupNorm_1(h))
# h = self.Dropout_0(h)
# h = self.Conv_1(h)
#
# if self.in_ch != self.out_ch or self.up or self.down:
# x = self.Conv_2(x)
#
# if not self.skip_rescale:
# return x + h
# else:
# return (x + h) / np.sqrt(2.0)
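For context on the resampling branch above: in most score-SDE style implementations the `naive_upsample_2d` / `naive_downsample_2d` helpers amount to nearest-neighbour upsampling and 2x2 average pooling, while the FIR path additionally applies a small smoothing kernel. A hedged stand-in for the naive path using torch built-ins, not this repository's helpers:

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)
up = F.interpolate(x, scale_factor=2, mode="nearest")  # stand-in for naive_upsample_2d
down = F.avg_pool2d(x, kernel_size=2)                  # stand-in for naive_downsample_2d
print(up.shape, down.shape)  # torch.Size([1, 8, 32, 32]) torch.Size([1, 8, 8, 8])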
class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model"""
...