"vscode:/vscode.git/clone" did not exist on "b25ae2e6abe8eb2101a515e6e2ec7e2efd139856"
Commit 663393e2 authored by patil-suraj's avatar patil-suraj
Browse files

remove fir option

parent c50d9975
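In short: the `fir` flag is deleted from ResnetBlockBigGANpp, Upsample, and Downsample, leaving FIR-kernel resampling (upsample_2d / downsample_2d with the default (1, 3, 3, 1) kernel) as the only path. NCSNpp keeps the argument for now, with a TODO to drop it from the signature and from pre-trained model configs.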
@@ -579,7 +579,6 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
-        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -590,7 +589,6 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
-        self.fir = fir
         self.fir_kernel = fir_kernel
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
@@ -334,7 +334,7 @@ class Combine(nn.Module):
 class Upsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
@@ -347,13 +347,11 @@ class Upsample(nn.Module):
                 use_bias=True,
                 kernel_init=variance_scaling(),
             )
-        self.fir = fir
         self.with_conv = with_conv
         self.fir_kernel = fir_kernel
         self.out_ch = out_ch

     def forward(self, x):
         B, C, H, W = x.shape
         if not self.with_conv:
             h = upsample_2d(x, self.fir_kernel, factor=2)
         else:
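The now-unconditional `upsample_2d` call above is the path this commit standardizes on. For intuition, here is a hedged sketch of what an upfirdn2d-style FIR upsample computes, assuming zero-stuffing followed by a normalized separable low-pass filter; the helper name `fir_upsample_2d` and the padding arithmetic are illustrative assumptions, not the actual library function.

import torch
import torch.nn.functional as F

def fir_upsample_2d(x, kernel=(1, 3, 3, 1), factor=2, gain=1.0):
    # Build a 2D filter as the outer product of the 1D FIR kernel, normalized
    # to sum to 1, then scaled by factor**2 to compensate for the zeros
    # inserted during upsampling.
    k = torch.tensor(kernel, dtype=torch.float32)
    k = torch.outer(k, k)
    k = k / k.sum() * (gain * factor**2)

    B, C, H, W = x.shape
    # Zero-stuff: place the input samples on a factor-strided grid.
    up = torch.zeros(B, C, H * factor, W * factor, dtype=x.dtype, device=x.device)
    up[:, :, ::factor, ::factor] = x

    # Asymmetric padding (assumed upfirdn2d convention) so the output is
    # exactly (H * factor, W * factor).
    p = k.shape[0] - factor
    up = F.pad(up, [(p + 1) // 2 + factor - 1, p // 2] * 2)

    # Depthwise convolution: the same low-pass filter applied to every channel.
    weight = k.to(x).expand(C, 1, *k.shape).contiguous()
    return F.conv2d(up, weight, groups=C)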
@@ -363,7 +361,7 @@ class Upsample(nn.Module):
 class Downsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
         out_ch = out_ch if out_ch else in_ch
         if with_conv:
@@ -376,13 +374,11 @@ class Downsample(nn.Module):
                 use_bias=True,
                 kernel_init=variance_scaling(),
             )
-        self.fir = fir
         self.fir_kernel = fir_kernel
         self.with_conv = with_conv
         self.out_ch = out_ch

     def forward(self, x):
         B, C, H, W = x.shape
         if not self.with_conv:
             x = downsample_2d(x, self.fir_kernel, factor=2)
         else:
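`downsample_2d` is the mirror image: low-pass filter first, then keep every factor-th sample, with no extra gain since nothing is zero-stuffed. Same caveats as above; `fir_downsample_2d` is an illustrative stand-in, not the library function.

def fir_downsample_2d(x, kernel=(1, 3, 3, 1), factor=2, gain=1.0):
    # Low-pass filter with the normalized FIR kernel, then subsample by striding.
    k = torch.tensor(kernel, dtype=torch.float32)
    k = torch.outer(k, k)
    k = k / k.sum() * gain  # no factor**2 here: no zeros are inserted

    # Symmetric-ish padding (assumed upfirdn2d convention) so the output is
    # exactly (H // factor, W // factor).
    p = k.shape[0] - factor
    x = F.pad(x, [(p + 1) // 2, p // 2] * 2)

    C = x.shape[1]
    weight = k.to(x).expand(C, 1, *k.shape).contiguous()
    return F.conv2d(x, weight, stride=factor, groups=C)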
@@ -404,7 +400,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,
+        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -428,7 +424,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
-            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
@@ -483,25 +478,24 @@ class NCSNpp(ModelMixin, ConfigMixin):
         nn.init.zeros_(modules[-1].bias)

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

-        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)

-        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+        Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)

         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
-            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
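A quick shape check of the two sketches above, mimicking how the simplified forward passes call the real helpers when `with_conv` is False:

x = torch.randn(2, 8, 16, 16)
up = fir_upsample_2d(x)      # stand-in for upsample_2d(x, (1, 3, 3, 1), factor=2)
down = fir_downsample_2d(x)  # stand-in for downsample_2d(x, (1, 3, 3, 1), factor=2)
assert up.shape == (2, 8, 32, 32)  # spatial dims double
assert down.shape == (2, 8, 8, 8)  # spatial dims halve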