Unverified Commit c691bb2f authored by Suraj Patil, committed by GitHub

Merge pull request #60 from huggingface/add-fir-back

fix unet sde for vp model.
parents abedfb08 4c293e0e
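For context (this sketch is not part of the diff): the VP/DDPM++ configurations do not use FIR resampling, so the blocks need a plain nearest-neighbor / average-pool fallback when `fir=False`. A minimal, self-contained sketch of that fallback using only standard torch ops (the FIR branch itself appears in the diff below):

```python
from functools import partial

import torch
import torch.nn.functional as F

# Non-FIR fallback selected when fir=False: nearest-neighbor upsampling and
# 2x2 average pooling, matching the partials wired up in ResnetBlockBigGANpp.
upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

x = torch.randn(1, 3, 8, 8)
print(upsample(x).shape)    # torch.Size([1, 3, 16, 16])
print(downsample(x).shape)  # torch.Size([1, 3, 4, 4])
```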
 from abc import abstractmethod
+from functools import partial

 import numpy as np
 import torch
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name

+        conv = None
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv

     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")

         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)

         return x
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
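The `name="Conv2d_0"` switch above presumably keeps the parameter names aligned with existing NCSN++ checkpoints, since the attribute name determines the `state_dict` keys. A toy sketch (not the diffusers classes) illustrating the effect:

```python
# Sketch: storing the same conv under a different attribute name changes the
# state_dict keys, which is what the `name` switch in Upsample/Downsample buys.
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self, name="conv"):
        super().__init__()
        conv = nn.Conv2d(4, 4, 3, padding=1)
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv


print(list(Toy(name="conv").state_dict()))      # ['conv.weight', 'conv.bias']
print(list(Toy(name="Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']
```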
@@ -390,6 +405,7 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
+        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -400,8 +416,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
+        self.fir = fir
         self.fir_kernel = fir_kernel

+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -424,11 +452,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))

         if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)

         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
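The "bias conditioned on the time embedding" mentioned in the comment above is a per-channel shift broadcast over the spatial dimensions. A small sketch with assumed shapes (not the exact library code):

```python
# Time-embedding bias: project the embedding to one scalar per output channel
# and broadcast it over height and width.
import torch
import torch.nn as nn

temb_dim, out_ch = 128, 64
dense = nn.Linear(temb_dim, out_ch)

h = torch.randn(2, out_ch, 32, 32)     # feature maps after Conv_0
temb = torch.randn(2, temb_dim)        # time embedding
h = h + dense(temb)[:, :, None, None]  # one bias per channel, per sample
print(h.shape)  # torch.Size([2, 64, 32, 32])
```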
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import ResnetBlockBigGANpp, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d


 def _setup_kernel(k):
@@ -184,18 +184,19 @@ class Combine(nn.Module):

 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)
@@ -203,18 +204,19 @@ class FirUpsample(nn.Module):

 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels

     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
+            x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
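The bias additions above follow from `_upsample_conv_2d` / `_conv_downsample_2d` consuming only `Conv2d_0.weight`, so the conv bias has to be applied separately. A standalone check of the general pattern (plain `F.conv2d`, not the FIR path itself):

```python
# When a conv's weight is used without its bias, adding bias.reshape(1, -1, 1, 1)
# afterwards reproduces the full module output.
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)

y_full = conv(x)
y_manual = F.conv2d(x, conv.weight, bias=None, padding=1) + conv.bias.reshape(1, -1, 1, 1)
print(torch.allclose(y_full, y_manual, atol=1e-6))  # True
```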
@@ -228,13 +230,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self,
         image_size=1024,
         num_channels=3,
+        centered=False,
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -252,12 +255,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.register_to_config(
             image_size=image_size,
             num_channels=num_channels,
+            centered=centered,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
@@ -307,24 +312,32 @@ class NCSNpp(ModelMixin, ConfigMixin):
             modules.append(Linear(nf * 4, nf * 4))

         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")

         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)

-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")

         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)

         ResnetBlock = functools.partial(
             ResnetBlockBigGANpp,
             act=act,
             dropout=dropout,
+            fir=fir,
             fir_kernel=fir_kernel,
             init_scale=init_scale,
             skip_rescale=skip_rescale,
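The hunk above relies on `functools.partial` over the two alternative module classes, so the FIR and non-FIR branches can be constructed with the same call sites later. A toy sketch of that pattern (hypothetical class and names, not the library code):

```python
# functools.partial over an nn.Module class pre-binds constructor kwargs;
# the remaining arguments are supplied at the (shared) call site.
import functools

import torch.nn as nn


class PlainUpsample(nn.Module):
    def __init__(self, channels=None, use_conv=False, name="conv"):
        super().__init__()
        self.name = name


use_fir = False
if use_fir:
    Up_sample = None  # would be functools.partial(FirUpsample, fir_kernel=..., use_conv=...)
else:
    Up_sample = functools.partial(PlainUpsample, name="Conv2d_0")

layer = Up_sample(channels=None, use_conv=False)  # same call shape as the FIR branch
print(layer.name)  # Conv2d_0
```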
@@ -361,7 +374,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                        in_ch *= 2

                elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                    input_pyramid_ch = in_ch

            hs_c.append(in_ch)
@@ -402,7 +415,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                    )
                    pyramid_ch = channels
                elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                    pyramid_ch = in_ch
                else:
                    raise ValueError(f"{progressive} is not a valid name")
@@ -446,7 +459,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         temb = None

         # If input data is in [0, 1]
-        x = 2 * x - 1.0
+        if not self.config.centered:
+            x = 2 * x - 1.0

         # Downsampling block
         input_pyramid = None
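The new `centered` config flag guards the input rescaling above: pixel data in [0, 1] is mapped to [-1, 1] only when it is not already centered. A minimal sketch of the switch:

```python
# Map [0, 1] data to [-1, 1] unless the config says the data is already centered.
import torch

x = torch.rand(2, 3, 8, 8)  # pretend pixel data in [0, 1]
centered = False            # hypothetical config value

if not centered:
    x = 2 * x - 1.0
print(x.min().item() >= -1.0, x.max().item() <= 1.0)  # True True
```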