Commit 4c293e0e authored by patil-suraj's avatar patil-suraj
Browse files

fix bias when using fir up/down sample

parent 516cb9e7
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import functools import functools
import math import math
from unicodedata import name
import numpy as np import numpy as np
import torch import torch
...@@ -197,6 +196,7 @@ class FirUpsample(nn.Module): ...@@ -197,6 +196,7 @@ class FirUpsample(nn.Module):
def forward(self, x): def forward(self, x):
if self.use_conv: if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
h = upsample_2d(x, self.fir_kernel, factor=2) h = upsample_2d(x, self.fir_kernel, factor=2)
...@@ -216,6 +216,7 @@ class FirDownsample(nn.Module): ...@@ -216,6 +216,7 @@ class FirDownsample(nn.Module):
def forward(self, x): def forward(self, x):
if self.use_conv: if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
x = downsample_2d(x, self.fir_kernel, factor=2) x = downsample_2d(x, self.fir_kernel, factor=2)
...@@ -313,7 +314,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -313,7 +314,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if self.fir: if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel) Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else: else:
Up_sample = functools.partial(Upsample, name="Conv2d_0") Up_sample = functools.partial(Upsample, name="Conv2d_0")
...@@ -323,9 +324,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -323,9 +324,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
pyramid_upsample = functools.partial(Up_sample, use_conv=True) pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if self.fir: if self.fir:
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel) Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else: else:
print("fir false")
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0") Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
if progressive_input == "input_skip": if progressive_input == "input_skip":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment