Commit 185347e4 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent c1c4dea9
...@@ -207,9 +207,6 @@ class ResBlock(TimestepBlock): ...@@ -207,9 +207,6 @@ class ResBlock(TimestepBlock):
self.updown = up or down self.updown = up or down
# if self.updown:
# import ipdb; ipdb.set_trace()
if up: if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims) self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims) self.x_upd = Upsample(channels, use_conv=False, dims=dims)
...@@ -227,8 +224,10 @@ class ResBlock(TimestepBlock): ...@@ -227,8 +224,10 @@ class ResBlock(TimestepBlock):
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0), # normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(), # nn.SiLU() if use_scale_shift_norm else nn.Identity(),
normalization(self.out_channels, swish=0.0),
nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
) )
...@@ -322,6 +321,7 @@ class ResBlock(TimestepBlock): ...@@ -322,6 +321,7 @@ class ResBlock(TimestepBlock):
emb_out = self.emb_layers(emb).type(h.dtype) emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape): while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None] emb_out = emb_out[..., None]
if self.use_scale_shift_norm: if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:] out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1) scale, shift = torch.chunk(emb_out, 2, dim=1)
...@@ -338,35 +338,31 @@ class ResBlock(TimestepBlock): ...@@ -338,35 +338,31 @@ class ResBlock(TimestepBlock):
return result return result
def forward_2(self, x, temb, mask=1.0): def forward_2(self, x, temb):
if self.overwrite and not self.is_overwritten: if self.overwrite and not self.is_overwritten:
self.set_weights() self.set_weights()
self.is_overwritten = True self.is_overwritten = True
h = x h = x
if self.pre_norm:
h = self.norm1(h) h = self.norm1(h)
h = self.nonlinearity(h) h = self.nonlinearity(h)
h = self.conv1(h) h = self.conv1(h)
if not self.pre_norm: temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = self.norm1(h)
h = self.nonlinearity(h)
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None] scale, shift = torch.chunk(temb, 2, dim=1)
h = self.norm2(h)
h = h * scale + shift
if self.pre_norm:
h = self.norm2(h) h = self.norm2(h)
h = self.nonlinearity(h) h = self.nonlinearity(h)
h = self.dropout(h) h = self.dropout(h)
h = self.conv2(h) h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
x = self.conv_shortcut(x) x = self.conv_shortcut(x)
...@@ -376,7 +372,7 @@ class ResBlock(TimestepBlock): ...@@ -376,7 +372,7 @@ class ResBlock(TimestepBlock):
return x + h return x + h
# unet.py and unet_grad_tts.py # unet.py, unet_grad_tts.py, unet_ldm.py
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__( def __init__(
self, self,
...@@ -410,6 +406,7 @@ class ResnetBlock(nn.Module): ...@@ -410,6 +406,7 @@ class ResnetBlock(nn.Module):
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps) self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish": if non_linearity == "swish":
self.nonlinearity = nonlinearity self.nonlinearity = nonlinearity
elif non_linearity == "mish": elif non_linearity == "mish":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment