Commit 1e7e23a9 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Merge branch 'fuse_final_resnets' of https://github.com/huggingface/diffusers...

Merge branch 'fuse_final_resnets' of https://github.com/huggingface/diffusers into fuse_final_resnets
parents b8415bb4 3a15afac
import argparse
import torch
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    """Convert an original BDDM checkpoint into a diffusers ``BDDMPipeline``.

    Args:
        checkpoint_path: path to the original DiffWave model checkpoint
            (expects a ``"model_state_dict"`` entry).
        noise_scheduler_checkpoint_path: path to the checkpoint holding the
            noise-schedule tensors; unpacks as ``(timesteps, _, betas, _)``.
        output_path: directory where the converted pipeline is saved.
    """
    model_state = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    schedule_state = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    diffwave = DiffWave()
    # strict=False: the original checkpoint may contain keys that the
    # diffusers DiffWave implementation does not define — TODO confirm
    # nothing essential is silently dropped.
    diffwave.load_state_dict(model_state, strict=False)

    timesteps, _, betas, _ = schedule_state
    timesteps = timesteps.numpy().tolist()
    betas = betas.numpy().tolist()

    scheduler = DDPMScheduler(
        timesteps=12,
        trained_betas=betas,
        timestep_values=timesteps,
        clip_sample=False,
        tensor_format="np",
    )

    BDDMPipeline(diffwave, scheduler).save_pretrained(output_path)
if __name__ == "__main__":
    # CLI entry point: all three paths are mandatory.
    arg_parser = argparse.ArgumentParser()
    for flag in ("--checkpoint_path", "--noise_scheduler_checkpoint_path", "--output_path"):
        arg_parser.add_argument(flag, type=str, required=True)
    cli_args = arg_parser.parse_args()
    convert_bddm_orginal(
        cli_args.checkpoint_path,
        cli_args.noise_scheduler_checkpoint_path,
        cli_args.output_path,
    )
import argparse

import torch

# Fix: `import OmegaConf` raises ModuleNotFoundError — OmegaConf is a class
# inside the `omegaconf` package, not a top-level module.
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, LatentDiffusionUncondPipeline, UNetLDMModel, VQModel
def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original latent-diffusion checkpoint into a diffusers
    ``LatentDiffusionUncondPipeline`` and save it.

    Args:
        checkpoint_path: path to the original combined model checkpoint
            (expects a ``"model"`` entry holding the state dict).
        config_path: path to the OmegaConf YAML config of the original model.
        output_path: directory where the converted pipeline is saved.
    """
    config = OmegaConf.load(config_path)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    def _sub_state_dict(prefix):
        # Collect the weights belonging to one sub-module, dropping the
        # module prefix so the keys match the standalone module's names.
        return {
            key.replace(prefix, ""): state_dict[key]
            for key in state_dict
            if key.startswith(prefix)
        }

    # extract state_dict for VQVAE
    first_stage_dict = _sub_state_dict("first_stage_model.")
    # extract state_dict for UNetLDM
    unet_state_dict = _sub_state_dict("model.diffusion_model.")

    vqvae = VQModel(**config.model.params.first_stage_config.params).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**config.model.params.unet_config.params).eval()
    unet.load_state_dict(unet_state_dict)

    scheduler = DDIMScheduler(
        timesteps=config.model.params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=config.model.params.linear_start,
        beta_end=config.model.params.linear_end,
        clip_sample=False,
    )

    LatentDiffusionUncondPipeline(vqvae, unet, scheduler).save_pretrained(output_path)
if __name__ == "__main__":
    # CLI entry point for the LDM conversion script.
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint_path", required=True, type=str)
    cli.add_argument("--config_path", required=True, type=str)
    cli.add_argument("--output_path", required=True, type=str)
    parsed = cli.parse_args()
    convert_ldm_original(parsed.checkpoint_path, parsed.config_path, parsed.output_path)
from abc import abstractmethod from abc import abstractmethod
from functools import partial
import functools
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -79,18 +79,25 @@ class Upsample(nn.Module): ...@@ -79,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None): def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
self.use_conv_transpose = use_conv_transpose self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose: if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv: elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
...@@ -103,7 +110,10 @@ class Upsample(nn.Module): ...@@ -103,7 +110,10 @@ class Upsample(nn.Module):
x = F.interpolate(x, scale_factor=2.0, mode="nearest") x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv: if self.use_conv:
if self.name == "conv":
x = self.conv(x) x = self.conv(x)
else:
x = self.Conv2d_0(x)
return x return x
...@@ -135,6 +145,8 @@ class Downsample(nn.Module): ...@@ -135,6 +145,8 @@ class Downsample(nn.Module):
if name == "conv": if name == "conv":
self.conv = conv self.conv = conv
elif name == "Conv2d_0":
self.Conv2d_0 = conv
else: else:
self.op = conv self.op = conv
...@@ -146,6 +158,8 @@ class Downsample(nn.Module): ...@@ -146,6 +158,8 @@ class Downsample(nn.Module):
if self.name == "conv": if self.name == "conv":
return self.conv(x) return self.conv(x)
elif self.name == "Conv2d_0":
return self.Conv2d_0(x)
else: else:
return self.op(x) return self.op(x)
...@@ -160,211 +174,7 @@ class Downsample(nn.Module): ...@@ -160,211 +174,7 @@ class Downsample(nn.Module):
# return self.conv(x) # return self.conv(x)
# RESNETS # unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
# unet_score_estimation.py
class ResnetBlockBigGANppNew(nn.Module):
    """BigGAN++-style residual block (score-SDE / NCSN++ variant), transitional.

    This class carries TWO parallel parameterizations of the same block:

    * the legacy score-SDE naming (``GroupNorm_0``, ``Conv_0``, ``Dense_0``, ...),
      built unconditionally in ``__init__`` and used by the first half of
      ``forward``;
    * the diffusers-``ResnetBlock``-style naming (``norm1``, ``conv1``,
      ``temb_proj``, ...), built only when ``overwrite=True``.

    With ``overwrite=True`` the legacy weights are copied into the new-style
    modules on the first forward call (``set_weights``), and the output is
    produced by the new-style path ``forward_2``.  The legacy computation in
    ``forward`` is still executed but its result is discarded — this looks
    like scaffolding for verifying a refactor, not production behavior.
    """

    def __init__(
        self,
        act,                        # non-linearity callable used by the legacy path
        in_ch,                      # number of input channels
        out_ch=None,                # number of output channels; defaults to in_ch
        temb_dim=None,              # time-embedding dim; None disables conditioning
        up=False,                   # upsample both h and x inside the block
        down=False,                 # downsample both h and x inside the block
        dropout=0.1,
        fir_kernel=(1, 3, 3, 1),    # FIR kernel for up/downsampling
        skip_rescale=True,          # divide the residual sum by sqrt(2)
        init_scale=0.0,             # init scale for the second conv
        overwrite=True,             # also build the new-style modules and use forward_2
    ):
        super().__init__()
        out_ch = out_ch if out_ch else in_ch
        # --- legacy (score-SDE) parameterization ---
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir_kernel = fir_kernel
        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            # Variance-scaling (DDPM-style) init for the time-embedding projection.
            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
        if in_ch != out_ch or up or down:
            # 1x1 convolution with DDPM initialization.
            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

        self.is_overwritten = False
        self.overwrite = overwrite
        if overwrite:
            # --- new (diffusers ResnetBlock-style) parameterization ---
            # Fixed settings: pre-norm ordering, SiLU, "default" temb injection.
            self.output_scale_factor = np.sqrt(2.0)
            self.in_channels = in_channels = in_ch
            self.out_channels = out_channels = out_ch
            groups = min(in_ch // 4, 32)
            out_groups = min(out_ch // 4, 32)
            eps = 1e-6
            self.pre_norm = True
            temb_channels = temb_dim
            non_linearity = "silu"
            self.time_embedding_norm = time_embedding_norm = "default"

            if self.pre_norm:
                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
            else:
                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            if time_embedding_norm == "default":
                self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
            elif time_embedding_norm == "scale_shift":
                # scale_shift needs twice the channels (split into scale and shift).
                self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
            self.norm2 = Normalize(out_channels, num_groups=out_groups, eps=eps)
            self.dropout = torch.nn.Dropout(dropout)
            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            if non_linearity == "swish":
                self.nonlinearity = nonlinearity
            elif non_linearity == "mish":
                self.nonlinearity = Mish()
            elif non_linearity == "silu":
                self.nonlinearity = nn.SiLU()
            # h_upd/x_upd are constructed but forward_2 currently calls
            # upsample_2d/downsample_2d directly (see commented lines there).
            if up:
                self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
                self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
            elif down:
                self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
                self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
            if self.in_channels != self.out_channels or self.up or self.down:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def set_weights(self):
        """Copy the legacy (score-SDE) weights into the new-style modules in place."""
        self.conv1.weight.data = self.Conv_0.weight.data
        self.conv1.bias.data = self.Conv_0.bias.data
        self.norm1.weight.data = self.GroupNorm_0.weight.data
        self.norm1.bias.data = self.GroupNorm_0.bias.data

        self.conv2.weight.data = self.Conv_1.weight.data
        self.conv2.bias.data = self.Conv_1.bias.data
        self.norm2.weight.data = self.GroupNorm_1.weight.data
        self.norm2.bias.data = self.GroupNorm_1.bias.data

        self.temb_proj.weight.data = self.Dense_0.weight.data
        self.temb_proj.bias.data = self.Dense_0.bias.data

        if self.in_channels != self.out_channels or self.up or self.down:
            self.nin_shortcut.weight.data = self.Conv_2.weight.data
            self.nin_shortcut.bias.data = self.Conv_2.bias.data

    def forward(self, x, temb=None):
        # Lazily migrate weights the first time the block is run.
        if self.overwrite and not self.is_overwritten:
            self.set_weights()
            self.is_overwritten = True

        orig_x = x

        # --- legacy path (result computed but ultimately discarded below) ---
        h = self.act(self.GroupNorm_0(x))
        if self.up:
            h = upsample_2d(h, self.fir_kernel, factor=2)
            x = upsample_2d(x, self.fir_kernel, factor=2)
        elif self.down:
            h = downsample_2d(h, self.fir_kernel, factor=2)
            x = downsample_2d(x, self.fir_kernel, factor=2)
        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if not self.skip_rescale:
            raise ValueError("Is this branch run?!")
            # import ipdb; ipdb.set_trace()
            result = x + h
        else:
            result = (x + h) / np.sqrt(2.0)

        # NOTE(review): only the new-style path's output is returned; the
        # legacy `result` above is dead once the refactor is validated.
        result_2 = self.forward_2(orig_x, temb)
        return result_2

    def forward_2(self, x, temb, mask=1.0):
        """New-style forward, mirroring diffusers' ResnetBlock.

        `mask` defaults to 1.0 (a no-op multiply); presumably it exists for
        grad-tts-style sequence masking — TODO confirm against callers.
        """
        h = x
        h = h * mask
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        # if self.up or self.down:
        #     x = self.x_upd(x)
        #     h = self.h_upd(h)
        if self.up:
            h = upsample_2d(h, self.fir_kernel, factor=2)
            x = upsample_2d(x, self.fir_kernel, factor=2)
        elif self.down:
            h = downsample_2d(h, self.fir_kernel, factor=2)
            x = downsample_2d(x, self.fir_kernel, factor=2)

        h = self.conv1(h)
        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
        h = h * mask

        # Inject the time embedding as a per-channel bias (broadcast over H, W).
        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
        if self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            h = self.norm2(h)
            h = h + h * scale + shift
            h = self.nonlinearity(h)
        elif self.time_embedding_norm == "default":
            h = h + temb
            h = h * mask
            if self.pre_norm:
                h = self.norm2(h)
                h = self.nonlinearity(h)
        else:
            # Unreachable: time_embedding_norm is hard-coded to "default" in __init__.
            raise ValueError("Nananan nanana - don't go here!")

        h = self.dropout(h)
        h = self.conv2(h)
        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
        h = h * mask

        x = x * mask
        # if self.in_channels != self.out_channels:
        if self.in_channels != self.out_channels or self.up or self.down:
            x = self.nin_shortcut(x)

        result = x + h

        return result / self.output_scale_factor
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__( def __init__(
self, self,
...@@ -674,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -674,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
up=False, up=False,
down=False, down=False,
dropout=0.1, dropout=0.1,
fir=False,
fir_kernel=(1, 3, 3, 1), fir_kernel=(1, 3, 3, 1),
skip_rescale=True, skip_rescale=True,
init_scale=0.0, init_scale=0.0,
...@@ -684,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -684,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up self.up = up
self.down = down self.down = down
self.fir = fir
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
if self.up:
if self.fir:
self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
else:
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
elif self.down:
if self.fir:
self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
else:
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1) self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None: if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch) self.Dense_0 = nn.Linear(temb_dim, out_ch)
...@@ -708,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -708,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
h = self.act(self.GroupNorm_0(x)) h = self.act(self.GroupNorm_0(x))
if self.up: if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2) h = self.upsample(h)
x = upsample_2d(x, self.fir_kernel, factor=2) x = self.upsample(x)
elif self.down: elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2) h = self.downsample(h)
x = downsample_2d(x, self.fir_kernel, factor=2) x = self.downsample(x)
h = self.Conv_0(h) h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding # Add bias to each feature map conditioned on the time embedding
......
...@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin ...@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import downsample_2d, upfirdn2d, upsample_2d from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
from .resnet import ResnetBlock from .resnet import ResnetBlock
...@@ -185,18 +185,19 @@ class Combine(nn.Module): ...@@ -185,18 +185,19 @@ class Combine(nn.Module):
class FirUpsample(nn.Module): class FirUpsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__() super().__init__()
out_ch = out_ch if out_ch else in_ch out_channels = out_channels if out_channels else channels
if with_conv: if use_conv:
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.with_conv = with_conv self.use_conv = use_conv
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
self.out_ch = out_ch self.out_channels = out_channels
def forward(self, x): def forward(self, x):
if self.with_conv: if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
h = upsample_2d(x, self.fir_kernel, factor=2) h = upsample_2d(x, self.fir_kernel, factor=2)
...@@ -204,18 +205,19 @@ class FirUpsample(nn.Module): ...@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
class FirDownsample(nn.Module): class FirDownsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__() super().__init__()
out_ch = out_ch if out_ch else in_ch out_channels = out_channels if out_channels else channels
if with_conv: if use_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
self.with_conv = with_conv self.use_conv = use_conv
self.out_ch = out_ch self.out_channels = out_channels
def forward(self, x): def forward(self, x):
if self.with_conv: if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
x = downsample_2d(x, self.fir_kernel, factor=2) x = downsample_2d(x, self.fir_kernel, factor=2)
...@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self, self,
image_size=1024, image_size=1024,
num_channels=3, num_channels=3,
centered=False,
attn_resolutions=(16,), attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True, conditional=True,
conv_size=3, conv_size=3,
dropout=0.0, dropout=0.0,
embedding_type="fourier", embedding_type="fourier",
fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs fir=True,
fir_kernel=(1, 3, 3, 1), fir_kernel=(1, 3, 3, 1),
fourier_scale=16, fourier_scale=16,
init_scale=0.0, init_scale=0.0,
...@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.register_to_config( self.register_to_config(
image_size=image_size, image_size=image_size,
num_channels=num_channels, num_channels=num_channels,
centered=centered,
attn_resolutions=attn_resolutions, attn_resolutions=attn_resolutions,
ch_mult=ch_mult, ch_mult=ch_mult,
conditional=conditional, conditional=conditional,
conv_size=conv_size, conv_size=conv_size,
dropout=dropout, dropout=dropout,
embedding_type=embedding_type, embedding_type=embedding_type,
fir=fir,
fir_kernel=fir_kernel, fir_kernel=fir_kernel,
fourier_scale=fourier_scale, fourier_scale=fourier_scale,
init_scale=init_scale, init_scale=init_scale,
...@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules.append(Linear(nf * 4, nf * 4)) modules.append(Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample, name="Conv2d_0")
if progressive == "output_skip": if progressive == "output_skip":
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False) self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual": elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True) pyramid_upsample = functools.partial(Up_sample, use_conv=True)
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) if self.fir:
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
if progressive_input == "input_skip": if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False) self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual": elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True) pyramid_downsample = functools.partial(Down_sample, use_conv=True)
# Downsampling block
channels = num_channels channels = num_channels
if progressive_input != "none": if progressive_input != "none":
...@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch *= 2 in_ch *= 2
elif progressive_input == "residual": elif progressive_input == "residual":
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch input_pyramid_ch = in_ch
hs_c.append(in_ch) hs_c.append(in_ch)
...@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
) )
pyramid_ch = channels pyramid_ch = channels
elif progressive == "residual": elif progressive == "residual":
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch pyramid_ch = in_ch
else: else:
raise ValueError(f"{progressive} is not a valid name") raise ValueError(f"{progressive} is not a valid name")
...@@ -464,7 +474,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -464,7 +474,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
groups_out=min(out_ch // 4, 32), groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True, overwrite_for_score_vde=True,
up=True, up=True,
kernel="fir", kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
use_nin_shortcut=True, use_nin_shortcut=True,
) )
) )
...@@ -505,6 +515,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -505,6 +515,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb = None temb = None
# If input data is in [0, 1] # If input data is in [0, 1]
if not self.config.centered:
x = 2 * x - 1.0 x = 2 * x - 1.0
# Downsampling block # Downsampling block
......
...@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance image = pred_prev_image + variance
# scale and decode image with vae # decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image) image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image return image
...@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293]) expected_slice = torch.tensor(
[-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.