Unverified Commit c1efda70 authored by Patrick von Platen, committed by GitHub

[Clean up] Clean unused code (#245)

* CleanResNet

* refactor more

* correct
parent 47893164
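
Most of the diff below is two mechanical renames, AttentionBlockNew becomes AttentionBlock and ResnetBlock becomes ResnetBlock2D, plus deletion of weight-conversion helpers that are no longer used. A minimal sketch of what the renames mean for importing code; the absolute module paths and constructor arguments are assumptions based on the relative imports and call sites in the diff, not something this page states:

# Hypothetical usage after this commit; paths and arguments are illustrative only.
from diffusers.models.attention import AttentionBlock   # was AttentionBlockNew
from diffusers.models.resnet import ResnetBlock2D       # was ResnetBlock

attn = AttentionBlock(64, num_head_channels=32)                          # mirrors the call sites in the diff
resnet = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)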
@@ -390,7 +390,7 @@ class ModelMixin(torch.nn.Module):
             )
         except EntryNotFoundError:
             raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
+                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
             )
         except HTTPError as err:
             raise EnvironmentError(
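
For reference, the updated message interpolates the module-level weights-filename constant rather than the local model_file variable. A small illustration of how the new message renders, assuming WEIGHTS_NAME holds the library's usual value (an assumption; the constant's value is not shown in this diff):

WEIGHTS_NAME = "diffusion_pytorch_model.bin"            # assumed value of the constant
pretrained_model_name_or_path = "some-org/some-model"   # hypothetical repo id
print(f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}.")
# some-org/some-model does not appear to have a file named diffusion_pytorch_model.bin.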
 import math
-from inspect import isfunction

 import torch
 import torch.nn.functional as F
 from torch import nn


-class AttentionBlockNew(nn.Module):
+class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
@@ -82,55 +81,6 @@ class AttentionBlockNew(nn.Module):
         hidden_states = (hidden_states + residual) / self.rescale_output_factor

         return hidden_states
-    def set_weight(self, attn_layer):
-        self.group_norm.weight.data = attn_layer.norm.weight.data
-        self.group_norm.bias.data = attn_layer.norm.bias.data
-
-        if hasattr(attn_layer, "q"):
-            self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
-            self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
-            self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
-
-            self.query.bias.data = attn_layer.q.bias.data
-            self.key.bias.data = attn_layer.k.bias.data
-            self.value.bias.data = attn_layer.v.bias.data
-
-            self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
-            self.proj_attn.bias.data = attn_layer.proj_out.bias.data
-        elif hasattr(attn_layer, "NIN_0"):
-            self.query.weight.data = attn_layer.NIN_0.W.data.T
-            self.key.weight.data = attn_layer.NIN_1.W.data.T
-            self.value.weight.data = attn_layer.NIN_2.W.data.T
-
-            self.query.bias.data = attn_layer.NIN_0.b.data
-            self.key.bias.data = attn_layer.NIN_1.b.data
-            self.value.bias.data = attn_layer.NIN_2.b.data
-
-            self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
-            self.proj_attn.bias.data = attn_layer.NIN_3.b.data
-
-            self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
-            self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
-        else:
-            qkv_weight = attn_layer.qkv.weight.data.reshape(
-                self.num_heads, 3 * self.channels // self.num_heads, self.channels
-            )
-            qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-
-            q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-            q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-
-            self.query.weight.data = q_w.reshape(-1, self.channels)
-            self.key.weight.data = k_w.reshape(-1, self.channels)
-            self.value.weight.data = v_w.reshape(-1, self.channels)
-
-            self.query.bias.data = q_b.reshape(-1)
-            self.key.bias.data = k_b.reshape(-1)
-            self.value.bias.data = v_b.reshape(-1)
-
-            self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-            self.proj_attn.bias.data = attn_layer.proj.bias.data
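
The removed set_weight above copies 1x1 convolution weights into linear query/key/value layers by slicing off the trailing spatial dimensions ([:, :, 0, 0]). That is valid because a 1x1 Conv2d is a per-position linear map over channels; a small self-contained sketch of the equivalence (not part of the diff):

import torch
from torch import nn

conv = nn.Conv2d(8, 8, kernel_size=1)
linear = nn.Linear(8, 8)
# mirror the removed conversion: drop the 1x1 spatial dims of the conv kernel
linear.weight.data = conv.weight.data[:, :, 0, 0]
linear.bias.data = conv.bias.data

x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)
out_conv = conv(x)
# flatten spatial positions, put channels last, apply the linear layer, restore the layout
out_lin = linear(x.flatten(2).transpose(1, 2)).transpose(1, 2).reshape(out_conv.shape)
assert torch.allclose(out_conv, out_lin, atol=1e-5)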

 class SpatialTransformer(nn.Module):
     """
@@ -170,12 +120,6 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

-    def set_weight(self, layer):
-        self.norm = layer.norm
-        self.proj_in = layer.proj_in
-        self.transformer_blocks = layer.transformer_blocks
-        self.proj_out = layer.proj_out

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
+        context_dim = context_dim if context_dim is not None else query_dim

         self.scale = dim_head**-0.5
         self.heads = heads
@@ -234,7 +178,7 @@ class CrossAttention(nn.Module):
         h = self.heads

         q = self.to_q(x)
-        context = default(context, x)
+        context = context if context is not None else x
         k = self.to_k(context)
         v = self.to_v(context)
@@ -244,7 +188,7 @@ class CrossAttention(nn.Module):
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

-        if exists(mask):
+        if mask is not None:
             mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
             mask = mask[:, None, :].repeat(h, 1, 1)
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
         super().__init__()
         inner_dim = int(dim * mult)
-        dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        dim_out = dim_out if dim_out is not None else dim
+        project_in = GEGLU(dim, inner_dim)

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
@@ -280,155 +224,3 @@ class GEGLU(nn.Module):
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)
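
As context for the GEGLU forward shown above: the projection output is split in half along the last dimension and one half gates the other through GELU. A standalone sketch with illustrative dimensions (the doubling of the projection width is inferred from the FeedForward call above, not shown here):

import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 4, 8
proj = nn.Linear(dim_in, dim_out * 2)   # assumes GEGLU projects to twice dim_out before chunking
x = torch.randn(2, 3, dim_in)           # (batch, sequence, features)
value, gate = proj(x).chunk(2, dim=-1)
out = value * F.gelu(gate)              # shape (2, 3, dim_out)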
-# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-
-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-# the main attention block that is used for all models
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=None,
-        num_groups=32,
-        encoder_channels=None,
-        overwrite_qkv=False,
-        overwrite_linear=False,
-        rescale_output_factor=1.0,
-        eps=1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels is None:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-
-        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-        self.qkv = nn.Conv1d(channels, channels * 3, 1)
-        self.n_heads = self.num_heads
-        self.rescale_output_factor = rescale_output_factor
-
-        if encoder_channels is not None:
-            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-
-        self.proj = nn.Conv1d(channels, channels, 1)
-
-        self.overwrite_qkv = overwrite_qkv
-        self.overwrite_linear = overwrite_linear
-
-        if overwrite_qkv:
-            in_channels = channels
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.overwrite_linear:
-            num_groups = min(channels // 4, 32)
-            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-            self.NIN_0 = NIN(channels, channels)
-            self.NIN_1 = NIN(channels, channels)
-            self.NIN_2 = NIN(channels, channels)
-            self.NIN_3 = NIN(channels, channels)
-
-            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-        else:
-            self.proj_out = nn.Conv1d(channels, channels, 1)
-            self.set_weights(self)
-
-        self.is_overwritten = False
-
-    def set_weights(self, module):
-        if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
-                :, :, :, 0
-            ]
-            qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-
-            self.qkv.weight.data = qkv_weight
-            self.qkv.bias.data = qkv_bias
-
-            proj_out = nn.Conv1d(self.channels, self.channels, 1)
-            proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-            proj_out.bias.data = module.proj_out.bias.data
-
-            self.proj = proj_out
-        elif self.overwrite_linear:
-            self.qkv.weight.data = torch.concat(
-                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
-            )[:, :, None]
-            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-
-            self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
-            self.proj.bias.data = self.NIN_3.b.data
-
-            self.norm.weight.data = self.GroupNorm_0.weight.data
-            self.norm.bias.data = self.GroupNorm_0.bias.data
-        else:
-            self.proj.weight.data = self.proj_out.weight.data
-            self.proj.bias.data = self.proj_out.bias.data
-
-    def forward(self, x, encoder_out=None):
-        if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
-            self.set_weights(self)
-            self.is_overwritten = True
-
-        b, c, *spatial = x.shape
-        hid_states = self.norm(x).view(b, c, -1)
-
-        qkv = self.qkv(hid_states)
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-
-        if encoder_out is not None:
-            encoder_kv = self.encoder_kv(encoder_out)
-            assert encoder_kv.shape[1] == self.n_heads * ch * 2
-            ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = torch.cat([ek, k], dim=-1)
-            v = torch.cat([ev, v], dim=-1)
-
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-
-        a = torch.einsum("bts,bcs->bct", weight, v)
-        h = a.reshape(bs, -1, length)
-
-        h = self.proj(h)
-        h = h.reshape(b, c, *spatial)
-
-        result = x + h
-
-        result = result / self.rescale_output_factor
-
-        return result
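
The removed AttentionBlock.forward above is plain multi-head self-attention over flattened spatial positions, with the heads folded into the batch dimension. A minimal standalone sketch of the core einsum math with illustrative shapes (the numbers are not from the diff):

import math
import torch

bs, n_heads, ch, length = 2, 4, 8, 16              # batch, heads, per-head channels, spatial positions
qkv = torch.randn(bs * n_heads, 3 * ch, length)    # packed query/key/value with heads folded into batch
q, k, v = qkv.split(ch, dim=1)

scale = 1 / math.sqrt(math.sqrt(ch))               # split 1/sqrt(ch) between q and k for fp16 stability
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v)      # (bs * n_heads, ch, length)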
@@ -248,7 +248,7 @@ class FirDownsample2D(nn.Module):
         return x

-class ResnetBlock(nn.Module):
+class ResnetBlock2D(nn.Module):
     def __init__(
         self,
         *,
@@ -17,8 +17,8 @@ import numpy as np
 import torch
 from torch import nn

-from .attention import AttentionBlockNew, SpatialTransformer
-from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D


 def get_down_block(
@@ -219,7 +219,7 @@ class UNetMidBlock2D(nn.Module):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,
@@ -236,7 +236,7 @@ class UNetMidBlock2D(nn.Module):
         for _ in range(num_layers):
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     in_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -245,7 +245,7 @@ class UNetMidBlock2D(nn.Module):
                 )
             )
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     temb_channels=temb_channels,
@@ -299,7 +299,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         # there is always at least one resnet
         resnets = [
-            ResnetBlock(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,
@@ -325,7 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 )
             )
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=in_channels,
                     temb_channels=temb_channels,
@@ -379,7 +379,7 @@ class AttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -393,7 +393,7 @@ class AttnDownBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -461,7 +461,7 @@ class CrossAttnDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -537,7 +537,7 @@ class DownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -602,7 +602,7 @@ class DownEncoderBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -664,7 +664,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -678,7 +678,7 @@ class AttnDownEncoderBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -740,7 +740,7 @@ class AttnSkipDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -755,7 +755,7 @@ class AttnSkipDownBlock2D(nn.Module):
                 )
            )
             self.attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -764,7 +764,7 @@ class AttnSkipDownBlock2D(nn.Module):
             )

         if add_downsample:
-            self.resnet_down = ResnetBlock(
+            self.resnet_down = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
@@ -828,7 +828,7 @@ class SkipDownBlock2D(nn.Module):
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=in_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -844,7 +844,7 @@ class SkipDownBlock2D(nn.Module):
             )

         if add_downsample:
-            self.resnet_down = ResnetBlock(
+            self.resnet_down = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
@@ -915,7 +915,7 @@ class AttnUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -929,7 +929,7 @@ class AttnUpBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -995,7 +995,7 @@ class CrossAttnUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1068,7 +1068,7 @@ class UpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1128,7 +1128,7 @@ class UpDecoderBlock2D(nn.Module):
             input_channels = in_channels if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -1184,7 +1184,7 @@ class AttnUpDecoderBlock2D(nn.Module):
             input_channels = in_channels if i == 0 else out_channels

             resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=input_channels,
                     out_channels=out_channels,
                     temb_channels=None,
@@ -1198,7 +1198,7 @@ class AttnUpDecoderBlock2D(nn.Module):
                 )
             )
             attentions.append(
-                AttentionBlockNew(
+                AttentionBlock(
                     out_channels,
                     num_head_channels=attn_num_head_channels,
                     rescale_output_factor=output_scale_factor,
@@ -1257,7 +1257,7 @@ class AttnSkipUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1273,7 +1273,7 @@ class AttnSkipUpBlock2D(nn.Module):
             )

         self.attentions.append(
-            AttentionBlockNew(
+            AttentionBlock(
                 out_channels,
                 num_head_channels=attn_num_head_channels,
                 rescale_output_factor=output_scale_factor,
@@ -1283,7 +1283,7 @@ class AttnSkipUpBlock2D(nn.Module):
         self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

         if add_upsample:
-            self.resnet_up = ResnetBlock(
+            self.resnet_up = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,
@@ -1363,7 +1363,7 @@ class SkipUpBlock2D(nn.Module):
             resnet_in_channels = prev_output_channel if i == 0 else out_channels

             self.resnets.append(
-                ResnetBlock(
+                ResnetBlock2D(
                     in_channels=resnet_in_channels + res_skip_channels,
                     out_channels=out_channels,
                     temb_channels=temb_channels,
@@ -1380,7 +1380,7 @@ class SkipUpBlock2D(nn.Module):
         self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

         if add_upsample:
-            self.resnet_up = ResnetBlock(
+            self.resnet_up = ResnetBlock2D(
                 in_channels=out_channels,
                 out_channels=out_channels,
                 temb_channels=temb_channels,