Commit c9bd4d43 authored by patil-suraj's avatar patil-suraj
Browse files

remove if fir from resent block and upsample, downsample for sde unet

parent 7e0fd19f
...@@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
h = self.act(self.GroupNorm_0(x)) h = self.act(self.GroupNorm_0(x))
if self.up: if self.up:
if self.fir:
h = upsample_2d(h, self.fir_kernel, factor=2) h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2) x = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_upsample_2d(h, factor=2)
x = naive_upsample_2d(x, factor=2)
elif self.down: elif self.down:
if self.fir:
h = downsample_2d(h, self.fir_kernel, factor=2) h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2) x = downsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_downsample_2d(h, factor=2)
x = naive_downsample_2d(x, factor=2)
h = self.Conv_0(h) h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding # Add bias to each feature map conditioned on the time embedding
......
...@@ -417,10 +417,6 @@ class Upsample(nn.Module): ...@@ -417,10 +417,6 @@ class Upsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__() super().__init__()
out_ch = out_ch if out_ch else in_ch out_ch = out_ch if out_ch else in_ch
if not fir:
if with_conv:
self.Conv_0 = conv3x3(in_ch, out_ch)
else:
if with_conv: if with_conv:
self.Conv2d_0 = Conv2d( self.Conv2d_0 = Conv2d(
in_ch, in_ch,
...@@ -438,11 +434,6 @@ class Upsample(nn.Module): ...@@ -438,11 +434,6 @@ class Upsample(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
if not self.fir:
h = F.interpolate(x, (H * 2, W * 2), "nearest")
if self.with_conv:
h = self.Conv_0(h)
else:
if not self.with_conv: if not self.with_conv:
h = upsample_2d(x, self.fir_kernel, factor=2) h = upsample_2d(x, self.fir_kernel, factor=2)
else: else:
...@@ -455,10 +446,6 @@ class Downsample(nn.Module): ...@@ -455,10 +446,6 @@ class Downsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__() super().__init__()
out_ch = out_ch if out_ch else in_ch out_ch = out_ch if out_ch else in_ch
if not fir:
if with_conv:
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
else:
if with_conv: if with_conv:
self.Conv2d_0 = Conv2d( self.Conv2d_0 = Conv2d(
in_ch, in_ch,
...@@ -476,13 +463,6 @@ class Downsample(nn.Module): ...@@ -476,13 +463,6 @@ class Downsample(nn.Module):
def forward(self, x): def forward(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
if not self.fir:
if self.with_conv:
x = F.pad(x, (0, 1, 0, 1))
x = self.Conv_0(x)
else:
x = F.avg_pool2d(x, 2, stride=2)
else:
if not self.with_conv: if not self.with_conv:
x = downsample_2d(x, self.fir_kernel, factor=2) x = downsample_2d(x, self.fir_kernel, factor=2)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment