Unverified Commit 06c79730 authored by Patrick von Platen, committed by GitHub

Add unconditional image generation (#79)

* uP

* finish downsampling layers

* finish major refactor

* remove bugus file
parent ea8d58ea
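For context, the `UNetUnconditionalModel` introduced by this commit exposes a `forward(sample, timesteps)` interface and, in the updated tests below, is loaded from the dummy LDM checkpoint. A minimal usage sketch (assuming that checkpoint stays available and shape-compatible; not part of the commit itself):

import torch
from diffusers import UNetUnconditionalModel

# Dummy checkpoint id taken from the updated tests below.
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
model.eval()

# The refactored models take `sample` and `timesteps` instead of `x`.
sample = torch.randn(1, model.in_channels, model.image_size, model.image_size)
timesteps = torch.tensor([10])
with torch.no_grad():
    noise_pred = model(sample, timesteps)
print(noise_pred.shape)  # the tests assert that input and output shapes match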
...@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
BDDMPipeline,
......
...@@ -22,4 +22,5 @@ from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
from .unet_sde_score_estimation import NCSNpp
from .unet_unconditional import UNetUnconditionalModel
from .vae import AutoencoderKL, VQModel
...@@ -93,6 +93,7 @@ class AttentionBlock(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
else:
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
self.set_weights(self)
self.is_overwritten = False
...@@ -123,11 +124,11 @@ class AttentionBlock(nn.Module):
self.norm.weight.data = self.GroupNorm_0.weight.data
self.norm.bias.data = self.GroupNorm_0.bias.data
else:
self.proj.weight.data = self.proj_out.weight.data
self.proj.bias.data = self.proj_out.bias.data
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
...@@ -164,6 +165,77 @@ class AttentionBlock(nn.Module):
return result
class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_head_channels=1,
num_groups=32,
encoder_channels=None,
rescale_output_factor=1.0,
):
super().__init__()
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = channels // num_head_channels
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
def set_weight(self, attn_layer):
self.norm.weight.data = attn_layer.norm.weight.data
self.norm.bias.data = attn_layer.norm.bias.data
self.qkv.weight.data = attn_layer.qkv.weight.data
self.qkv.bias.data = attn_layer.qkv.bias.data
self.proj.weight.data = attn_layer.proj.weight.data
self.proj.bias.data = attn_layer.proj.bias.data
def forward(self, x, encoder_out=None):
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
return result
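The inline comment above notes that scaling `q` and `k` each by `1 / sqrt(sqrt(ch))` before the einsum is more stable in fp16 than dividing the attention logits afterwards. A quick sketch (hypothetical tensor sizes, not part of the commit) of why the two are numerically equivalent in exact arithmetic:

import math
import torch

bs_heads, ch, length = 2, 8, 16  # hypothetical sizes
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
pre_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)   # as in AttentionBlockNew
post_scaled = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)  # naive variant
assert torch.allclose(pre_scaled, post_scaled, atol=1e-5)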
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
......
...@@ -461,6 +461,7 @@ class ResnetBlock2D(nn.Module):
self.skip_connection = nn.Identity()
else:
self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
self.set_weights_ldm()
elif self.overwrite_for_score_vde:
in_ch = in_channels
out_ch = out_channels
...@@ -550,9 +551,6 @@ class ResnetBlock2D(nn.Module):
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
elif self.overwrite_for_score_vde and not self.is_overwritten:
self.set_weights_score_vde()
self.is_overwritten = True
...@@ -610,6 +608,162 @@ class ResnetBlock2D(nn.Module):
return (x + h) / self.output_scale_factor
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
overwrite_for_glide=False,
overwrite_for_score_vde=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
if self.pre_norm:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if time_embedding_norm == "default" and temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
elif time_embedding_norm == "scale_shift" and temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.nin_shortcut = None
if self.use_nin_shortcut:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
if temb is not None:
temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
else:
temb = 0
if self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h + h * scale + shift
h = self.nonlinearity(h)
elif self.time_embedding_norm == "default":
h = h + temb
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
if self.nin_shortcut is not None:
x = self.nin_shortcut(x)
return (x + h) / self.output_scale_factor
def set_weight(self, resnet):
self.norm1.weight.data = resnet.norm1.weight.data
self.norm1.bias.data = resnet.norm1.bias.data
self.conv1.weight.data = resnet.conv1.weight.data
self.conv1.bias.data = resnet.conv1.bias.data
self.temb_proj.weight.data = resnet.temb_proj.weight.data
self.temb_proj.bias.data = resnet.temb_proj.bias.data
self.norm2.weight.data = resnet.norm2.weight.data
self.norm2.bias.data = resnet.norm2.bias.data
self.conv2.weight.data = resnet.conv2.weight.data
self.conv2.bias.data = resnet.conv2.bias.data
if self.use_nin_shortcut:
self.nin_shortcut.weight.data = resnet.nin_shortcut.weight.data
self.nin_shortcut.bias.data = resnet.nin_shortcut.bias.data
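For reference, the `scale_shift` branch in `ResnetBlock.forward` above is a FiLM-style modulation: the projected time embedding is chunked into `scale` and `shift`, and `h + h * scale + shift` is simply `h * (1 + scale) + shift` applied after `norm2`. A minimal sketch with hypothetical shapes (not part of the commit):

import torch

h = torch.randn(2, 4, 8, 8)                 # (batch, channels, height, width)
temb = torch.randn(2, 8)[:, :, None, None]  # time embedding projected to 2 * channels

scale, shift = torch.chunk(temb, 2, dim=1)
out = h + h * scale + shift                 # as written in ResnetBlock
assert torch.allclose(out, h * (1 + scale) + shift)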
# TODO(Patrick) - just there to convert the weights; can delete afterward
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
......
...@@ -114,9 +114,7 @@ class UNetModel(ModelMixin, ConfigMixin):
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid_new.resnets[0] = self.mid.block_1
self.mid_new.attentions[0] = self.mid.attn_1
self.mid_new.resnets[1] = self.mid.block_2
...@@ -154,7 +152,8 @@ class UNetModel(ModelMixin, ConfigMixin):
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, sample, timesteps):
x = sample
assert x.shape[2] == x.shape[3] == self.resolution
if not torch.is_tensor(timesteps):
......
...@@ -438,7 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, sample, timesteps, transformer_out=None):
x = sample
hs = []
emb = self.time_embed(
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
...@@ -528,7 +529,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown=resblock_updown,
)
def forward(self, sample, timesteps, low_res=None):
x = sample
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
x = torch.cat([x, upsampled], dim=1)
......
...@@ -180,7 +180,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.final_block = Block(dim, dim)
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
def forward(self, sample, timesteps, mu, mask, spk=None):
x = sample
if self.n_spks > 1:
# Get speaker embedding
spk = self.spk_emb(spk)
......
...@@ -301,6 +301,12 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self.down_in_conv = self.input_blocks[0][0]
self.downsample_blocks = nn.ModuleList([])
self.upsample_blocks = nn.ModuleList([])
# ========================= Down (OLD) =================== #
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
...@@ -354,6 +360,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
ds *= 2
self._feature_size += ch
input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
output_channels = [model_channels * mult for mult in channel_mult]
if num_head_channels == -1:
dim_head = ch // num_heads
else:
...@@ -365,6 +374,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
if dim_head < 0:
dim_head = None
# ========================= MID (New) =================== #
self.mid = UNetMidBlock2D(
in_channels=ch,
dropout=dropout,
...@@ -414,6 +425,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self._feature_size += ch
# ========================= Up (Old) =================== #
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
...@@ -462,28 +474,6 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
...@@ -505,18 +495,18 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
emb = emb + self.label_emb(y)
h = x.type(self.dtype_)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.mid(h, emb, context)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
class SpatialTransformer(nn.Module):
......
...@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from .attention import AttentionBlockNew
from .resnet import Downsample2D, ResnetBlock, Upsample2D
class UNetMidBlock2D(nn.Module):
...@@ -24,30 +25,25 @@ class UNetMidBlock2D(nn.Module):
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
**kwargs,
):
super().__init__()
# there is always at least one resnet
resnets = [
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
...@@ -58,36 +54,20 @@ class UNetMidBlock2D(nn.Module):
]
attentions = []
for _ in range(num_layers):
attentions.append(
AttentionBlockNew(
in_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
)
)
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
...@@ -100,37 +80,283 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_states=None, mask=None):
if mask is not None:
hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
else:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_states)
if mask is not None:
hidden_states = resnet(hidden_states, temb, mask=mask)
else:
hidden_states = resnet(hidden_states, temb)
return hidden_states
class UNetResAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
add_downsample=True,
):
super().__init__()
resnets = []
attentions = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
AttentionBlockNew(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
)
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class UNetResDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
)
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class UNetResAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
next_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_layer_type: str = "self",
attn_num_head_channels=1,
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []
for i in range(num_layers):
resnet_channels = in_channels if i < num_layers - 1 else next_channels
resnets.append(
ResnetBlock(
in_channels=in_channels + resnet_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
AttentionBlockNew(
in_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetResUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
next_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_layer_type: str = "self",
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
resnets = []
for i in range(num_layers):
resnet_channels = in_channels if i < num_layers - 1 else next_channels
resnets.append(
ResnetBlock(
in_channels=in_channels + resnet_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
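The new down and up blocks hand skip connections around explicitly: a down block returns `(hidden_states, output_states)` with one residual per resnet, and the matching up block concatenates one popped residual onto its input before each resnet. A minimal sketch of a single resolution level (the module path and sizes are assumptions, and `add_downsample`/`add_upsample` are disabled so shapes line up without the rest of the UNet):

import torch
from diffusers.models.unet_new import UNetResAttnDownBlock2D, UNetResAttnUpBlock2D  # assumed module path

down = UNetResAttnDownBlock2D(
    in_channels=32, out_channels=32, temb_channels=128, num_layers=2, add_downsample=False
)
up = UNetResAttnUpBlock2D(
    in_channels=32, next_channels=32, temb_channels=128, num_layers=2, add_upsample=False
)

x = torch.randn(1, 32, 16, 16)
temb = torch.randn(1, 128)
hidden_states, res_samples = down(x, temb)  # two residuals, one per resnet
out = up(hidden_states, res_samples, temb)  # each resnet consumes one residual via torch.cat
print(out.shape)  # expected: torch.Size([1, 32, 16, 16])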
...@@ -129,10 +129,11 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, sample, timesteps):
"""
x : [ batch x horizon x transition ]
"""
x = sample
x = x.permute(0, 2, 1)
...@@ -212,7 +213,6 @@ class TemporalValue(nn.Module):
"""
x : [ batch x horizon x transition ]
"""
x = x.permute(0, 2, 1)
t = self.time_mlp(time)
......
...@@ -323,7 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.all_modules = nn.ModuleList(modules)
def forward(self, sample, timesteps, sigmas=None):
x = sample
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
......
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import (
UNetMidBlock2D,
UNetResAttnDownBlock2D,
UNetResAttnUpBlock2D,
UNetResDownBlock2D,
UNetResUpBlock2D,
)
class UNetUnconditionalModel(ModelMixin, ConfigMixin):
"""
The full UNet model with attention and timestep embedding.

:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially increased efficiency.
"""
def init_for_ldm(
self,
dims,
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
use_spatial_transformer,
transformer_depth,
context_dim,
conv_resample,
out_channels,
):
# TODO(PVP) - delete after weight conversion
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
pass
# TODO(PVP) - delete after weight conversion
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=dim_head,
),
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
if dim_head < 0:
dim_head = None
# TODO(Patrick) - delete after weight conversion
# init to be able to overwrite `self.mid`
self.middle_block = TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
AttentionBlock(
ch,
num_heads=num_heads,
num_head_channels=dim_head,
),
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
AttentionBlock(
ch,
num_heads=-1,
num_head_channels=dim_head,
),
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for i, input_layer in enumerate(self.input_blocks[1:]):
block_id = i // (num_res_blocks + 1)
layer_in_block_id = i % (num_res_blocks + 1)
if layer_in_block_id == 2:
self.downsample_blocks[block_id].downsamplers[0].op.weight.data = input_layer[0].op.weight.data
self.downsample_blocks[block_id].downsamplers[0].op.bias.data = input_layer[0].op.bias.data
elif len(input_layer) > 1:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.mid.resnets[0].set_weight(self.middle_block[0])
self.mid.resnets[1].set_weight(self.middle_block[2])
self.mid.attentions[0].set_weight(self.middle_block[1])
for i, input_layer in enumerate(self.output_blocks):
block_id = i // (num_res_blocks + 1)
layer_in_block_id = i % (num_res_blocks + 1)
if len(input_layer) > 2:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
elif len(input_layer) > 1:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()
# register all __init__ params with self.register
self.register_to_config(
image_size=image_size,
in_channels=in_channels,
model_channels=model_channels,
out_channels=out_channels,
num_res_blocks=num_res_blocks,
attention_resolutions=attention_resolutions,
dropout=dropout,
channel_mult=channel_mult,
conv_resample=conv_resample,
dims=dims,
num_classes=num_classes,
use_fp16=use_fp16,
num_heads=num_heads,
num_heads_upsample=num_heads_upsample,
num_head_channels=num_head_channels,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_depth=transformer_depth,
context_dim=context_dim,
n_embed=n_embed,
legacy=legacy,
)
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.dtype_ = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
# ======================== Input ===================
self.conv_in = nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=(1, 1))
# ======================== Time ====================
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
# ======================== Down ====================
input_channels = [model_channels * mult for mult in [1] + list(channel_mult[:-1])]
output_channels = [model_channels * mult for mult in channel_mult]
ds_new = 1
self.downsample_blocks = nn.ModuleList([])
for i, (input_channel, output_channel) in enumerate(zip(input_channels, output_channels)):
is_final_block = i == len(input_channels) - 1
if ds_new in attention_resolutions:
down_block = UNetResAttnDownBlock2D(
num_layers=num_res_blocks,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=1e-5,
resnet_act_fn="silu",
attn_num_head_channels=num_head_channels,
)
else:
down_block = UNetResDownBlock2D(
num_layers=num_res_blocks,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=1e-5,
resnet_act_fn="silu",
)
self.downsample_blocks.append(down_block)
ds_new *= 2
ds_new = ds_new / 2
# ======================== Mid ====================
self.mid = UNetMidBlock2D(
in_channels=output_channels[-1],
dropout=dropout,
temb_channels=time_embed_dim,
resnet_eps=1e-5,
resnet_act_fn="silu",
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
attn_num_head_channels=num_head_channels,
)
# ======================== Up =====================
# input_channels = [model_channels * mult for mult in channel_mult]
# output_channels = [model_channels * mult for mult in channel_mult]
self.upsample_blocks = nn.ModuleList([])
for i, (input_channel, output_channel) in enumerate(zip(reversed(input_channels), reversed(output_channels))):
is_final_block = i == len(input_channels) - 1
if ds_new in attention_resolutions:
up_block = UNetResAttnUpBlock2D(
num_layers=num_res_blocks + 1,
in_channels=output_channel,
next_channels=input_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=1e-5,
resnet_act_fn="silu",
attn_num_head_channels=num_head_channels,
)
else:
up_block = UNetResUpBlock2D(
num_layers=num_res_blocks + 1,
in_channels=output_channel,
next_channels=input_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=1e-5,
resnet_act_fn="silu",
)
self.upsample_blocks.append(up_block)
ds_new /= 2
# ======================== Out ====================
self.out = nn.Sequential(
nn.GroupNorm(num_channels=output_channels[0], num_groups=32, eps=1e-5),
nn.SiLU(),
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
# =========== TO DELETE AFTER CONVERSION ==========
self.init_for_ldm(
dims,
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
False,
transformer_depth,
context_dim,
conv_resample,
out_channels,
)
def forward(self, sample, timesteps=None):
# 1. time step embeddings
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
emb = self.time_embed(t_emb)
# 2. pre-process sample
sample = sample.type(self.dtype_)
sample = self.conv_in(sample)
# 3. down blocks
down_block_res_samples = (sample,)
for downsample_block in self.downsample_blocks:
sample, res_samples = downsample_block(sample, emb)
# append to tuple
down_block_res_samples += res_samples
# 4. mid block
sample = self.mid(sample, emb)
# 5. up blocks
for upsample_block in self.upsample_blocks:
# pop from tuple
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
sample = upsample_block(sample, res_samples, emb)
# 6. post-process sample
sample = self.out(sample)
return sample
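Step 1 of the forward pass above relies on `get_timestep_embedding` (imported from `.embeddings` at the top of the file) to turn integer timesteps into sinusoidal features before the two-layer `time_embed` MLP. A small sketch of that preprocessing in isolation (embedding width chosen arbitrarily, not part of the commit):

import torch
from diffusers.models.embeddings import get_timestep_embedding  # module as imported above

timesteps = torch.tensor([10])
t_emb = get_timestep_embedding(timesteps, 32, flip_sin_to_cos=True, downscale_freq_shift=0)
print(t_emb.shape)  # one sinusoidal embedding per timestep, here (1, 32)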
...@@ -470,7 +470,8 @@ class VQModel(ModelMixin, ConfigMixin):
dec = self.decoder(quant)
return dec
def forward(self, sample):
x = sample
h = self.encode(x)
dec = self.decode(h)
return dec
...@@ -561,7 +562,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
dec = self.decoder(z)
return dec
def forward(self, sample, sample_posterior=False):
x = sample
posterior = self.encode(x)
if sample_posterior:
z = posterior.sample()
......
...@@ -46,6 +46,7 @@ from diffusers import (
UNetGradTTSModel,
UNetLDMModel,
UNetModel,
UNetUnconditionalModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin
...@@ -146,7 +147,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self):
...@@ -157,7 +158,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["sample", "timesteps"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self):
...@@ -194,7 +195,7 @@ class ModelTesterMixin:
model.to(torch_device)
model.train()
output = model(**inputs_dict)
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
...@@ -207,7 +208,7 @@ class ModelTesterMixin:
ema_model = EMAModel(model, device=torch_device)
output = model(**inputs_dict)
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model)
...@@ -225,7 +226,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timesteps": time_step}
@property
def input_shape(self):
...@@ -291,7 +292,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"sample": noise, "timesteps": time_step, "low_res": low_res}
@property
def input_shape(self):
...@@ -330,7 +331,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
output, _ = torch.split(output, 3, dim=1)
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self):
...@@ -382,7 +383,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"sample": noise, "timesteps": time_step, "transformer_out": emb}
@property
def input_shape(self):
...@@ -422,7 +423,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
output, _ = torch.split(output, 3, dim=1)
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self):
...@@ -463,7 +464,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetUnconditionalModel
@property
def dummy_input(self):
...@@ -474,7 +475,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timesteps": time_step}
@property
def input_shape(self):
...@@ -493,14 +494,14 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
"num_res_blocks": 2,
"attention_resolutions": (16,),
"channel_mult": (1, 2),
"num_head_channels": 32,
"conv_resample": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -510,7 +511,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy")
model.eval()
torch.manual_seed(0)
...@@ -567,7 +568,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timesteps": time_step, "mu": condition, "mask": mask}
@property
def input_shape(self):
...@@ -637,7 +638,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timesteps": time_step}
@property
def input_shape(self):
...@@ -708,7 +709,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
return {"sample": noise, "timesteps": time_step}
@property
def input_shape(self):
...@@ -834,7 +835,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
...@@ -909,7 +910,7 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
......