Commit 4c293e0e authored by patil-suraj's avatar patil-suraj
Browse files

fix bias when using fir up/down sample

parent 516cb9e7
......@@ -17,7 +17,6 @@
import functools
import math
from unicodedata import name
import numpy as np
import torch
......@@ -197,6 +196,7 @@ class FirUpsample(nn.Module):
def forward(self, x):
if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
h = upsample_2d(x, self.fir_kernel, factor=2)
......@@ -216,6 +216,7 @@ class FirDownsample(nn.Module):
def forward(self, x):
if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else:
x = downsample_2d(x, self.fir_kernel, factor=2)
......@@ -313,7 +314,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel)
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample, name="Conv2d_0")
......@@ -323,9 +324,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if self.fir:
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel)
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
print("fir false")
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment