Commit 571e4062 authored by Patrick von Platen

merge from master

parents 14bd3567 c2bc59d2
import argparse

import torch

from diffusers import DDPMScheduler
from diffusers.pipelines.bddm import BDDMPipeline, DiffWave


def convert_bddm_original(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    model = DiffWave()
    # strict=False tolerates checkpoint keys that do not match the DiffWave module
    model.load_state_dict(sd, strict=False)

    # the scheduler checkpoint stores a tuple (timesteps, _, betas, _)
    ts, _, betas, _ = noise_scheduler_sd
    ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())

    noise_scheduler = DDPMScheduler(
        timesteps=12,
        trained_betas=betas,
        timestep_values=ts,
        clip_sample=False,
        tensor_format="np",
    )

    pipeline = BDDMPipeline(model, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    convert_bddm_original(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)
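To sanity-check a conversion, the saved pipeline can be reloaded from the output directory. A minimal sketch (the local path is a hypothetical --output_path, not something the script produces by default):

from diffusers.pipelines.bddm import BDDMPipeline

# "./bddm-converted" is a hypothetical --output_path from an earlier run;
# from_pretrained is inherited from DiffusionPipeline.
pipeline = BDDMPipeline.from_pretrained("./bddm-converted")
print(pipeline)  # should list the DiffWave model and the DDPM scheduler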
import argparse

import torch
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, LatentDiffusionUncondPipeline, UNetLDMModel, VQModel


def convert_ldm_original(checkpoint_path, config_path, output_path):
    config = OmegaConf.load(config_path)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    keys = list(state_dict.keys())

    # extract state_dict for VQVAE
    first_stage_dict = {}
    first_stage_key = "first_stage_model."
    for key in keys:
        if key.startswith(first_stage_key):
            first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]

    # extract state_dict for UNetLDM
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

    vqvae_init_args = config.model.params.first_stage_config.params
    unet_init_args = config.model.params.unet_config.params

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    noise_scheduler = DDIMScheduler(
        timesteps=config.model.params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=config.model.params.linear_start,
        beta_end=config.model.params.linear_end,
        clip_sample=False,
    )

    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
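The same round-trip check works for the LDM conversion. A minimal sketch (the path and the seeded-generator call are assumptions based on how the other pipelines in this repo are invoked):

import torch

from diffusers import LatentDiffusionUncondPipeline

# "./ldm-converted" is a hypothetical --output_path from an earlier run.
pipeline = LatentDiffusionUncondPipeline.from_pretrained("./ldm-converted")

generator = torch.manual_seed(0)
image = pipeline(generator=generator)  # returns a tensor in [0, 1], see the pipeline diff further below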
from abc import abstractmethod
from functools import partial
import numpy as np
import torch
@@ -78,18 +79,25 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         self.use_conv_transpose = use_conv_transpose
+        self.name = name
 
+        conv = None
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
+            conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+            conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv
 
     def forward(self, x):
         assert x.shape[1] == self.channels
@@ -102,7 +110,10 @@ class Upsample(nn.Module):
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")
 
         if self.use_conv:
-            x = self.conv(x)
+            if self.name == "conv":
+                x = self.conv(x)
+            else:
+                x = self.Conv2d_0(x)
 
         return x
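The point of the new name argument in the Upsample/Downsample diffs above is checkpoint compatibility: the attribute a submodule is registered under determines its state_dict keys, so ported weights only load when the names line up. A standalone illustration of that mechanism (not diffusers code):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, name="conv"):
        super().__init__()
        conv = nn.Conv2d(4, 4, 3, padding=1)
        # register the same layer under either attribute name
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

print(list(Block(name="conv").state_dict()))      # ['conv.weight', 'conv.bias']
print(list(Block(name="Conv2d_0").state_dict()))  # ['Conv2d_0.weight', 'Conv2d_0.bias']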
@@ -134,6 +145,8 @@ class Downsample(nn.Module):
 
         if name == "conv":
             self.conv = conv
+        elif name == "Conv2d_0":
+            self.Conv2d_0 = conv
         else:
             self.op = conv
@@ -145,6 +158,8 @@ class Downsample(nn.Module):
 
         if self.name == "conv":
             return self.conv(x)
+        elif self.name == "Conv2d_0":
+            return self.Conv2d_0(x)
         else:
             return self.op(x)
@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
+        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
+        self.fir = fir
         self.fir_kernel = fir_kernel
 
+        if self.up:
+            if self.fir:
+                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+        elif self.down:
+            if self.fir:
+                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
+            else:
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+
         self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))
 
         if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
+            h = self.upsample(h)
+            x = self.upsample(x)
         elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
+            h = self.downsample(h)
+            x = self.downsample(x)
 
         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
......
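Note how ResnetBlockBigGANpp now binds its resampling operator once in __init__ via functools.partial, so forward no longer hard-codes the FIR path. A standalone sketch of the same dispatch pattern, using the non-FIR fallbacks from the diff:

from functools import partial

import torch
import torch.nn.functional as F

# when fir=False, the diff falls back to nearest-neighbor / average-pool resampling
upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

x = torch.randn(1, 3, 8, 8)
print(upsample(x).shape)    # torch.Size([1, 3, 16, 16])
print(downsample(x).shape)  # torch.Size([1, 3, 4, 4])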
@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import downsample_2d, upfirdn2d, upsample_2d
+from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
 from .resnet import ResnetBlock
@@ -185,18 +185,19 @@ class Combine(nn.Module):
 
 class FirUpsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
-        self.with_conv = with_conv
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.use_conv = use_conv
         self.fir_kernel = fir_kernel
-        self.out_ch = out_ch
+        self.out_channels = out_channels
 
     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
             h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             h = upsample_2d(x, self.fir_kernel, factor=2)
@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
 
 class FirDownsample(nn.Module):
-    def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)):
+    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        if with_conv:
-            self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
+        out_channels = out_channels if out_channels else channels
+        if use_conv:
+            self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.fir_kernel = fir_kernel
-        self.with_conv = with_conv
-        self.out_ch = out_ch
+        self.use_conv = use_conv
+        self.out_channels = out_channels
 
     def forward(self, x):
-        if self.with_conv:
+        if self.use_conv:
             x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
             x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
         else:
             x = downsample_2d(x, self.fir_kernel, factor=2)
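For reference, fir_kernel=(1, 3, 3, 1) is a 1D binomial tap vector; in the StyleGAN2-style upfirdn2d resampling these layers are ported from, it is typically expanded into a separable 2D filter by an outer product and normalized. A sketch of that construction (an assumption about the upstream convention, not the diffusers code itself):

import numpy as np

k = np.asarray((1, 3, 3, 1), dtype=np.float64)
k2d = np.outer(k, k)  # separable 4x4 smoothing filter
k2d /= k2d.sum()      # normalize so resampling preserves overall intensity
print(k2d)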
@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self,
         image_size=1024,
         num_channels=3,
+        centered=False,
         attn_resolutions=(16,),
         ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
         conditional=True,
         conv_size=3,
         dropout=0.0,
         embedding_type="fourier",
-        fir=True,  # TODO (patil-suraj) remove this option from here and pre-trained model configs
+        fir=True,
         fir_kernel=(1, 3, 3, 1),
         fourier_scale=16,
         init_scale=0.0,
@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.register_to_config(
             image_size=image_size,
             num_channels=num_channels,
+            centered=centered,
             attn_resolutions=attn_resolutions,
             ch_mult=ch_mult,
             conditional=conditional,
             conv_size=conv_size,
             dropout=dropout,
             embedding_type=embedding_type,
+            fir=fir,
             fir_kernel=fir_kernel,
             fourier_scale=fourier_scale,
             init_scale=init_scale,
@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
             modules.append(Linear(nf * 4, nf * 4))
 
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
-        Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+
+        if self.fir:
+            Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Up_sample = functools.partial(Upsample, name="Conv2d_0")
 
         if progressive == "output_skip":
-            self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
         elif progressive == "residual":
-            pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_upsample = functools.partial(Up_sample, use_conv=True)
 
-        Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
+        if self.fir:
+            Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
+        else:
+            Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
 
         if progressive_input == "input_skip":
-            self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False)
+            self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
         elif progressive_input == "residual":
-            pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True)
+            pyramid_downsample = functools.partial(Down_sample, use_conv=True)
 
         # Downsampling block
         channels = num_channels
         if progressive_input != "none":
@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     in_ch *= 2
 
                 elif progressive_input == "residual":
-                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
                     input_pyramid_ch = in_ch
 
         hs_c.append(in_ch)
@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                     )
                     pyramid_ch = channels
                 elif progressive == "residual":
-                    modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                    modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
                     pyramid_ch = in_ch
                 else:
                     raise ValueError(f"{progressive} is not a valid name")
@@ -505,6 +515,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
             temb = None
 
         # If input data is in [0, 1]
-        x = 2 * x - 1.0
+        if not self.config.centered:
+            x = 2 * x - 1.0
 
         # Downsampling block
......
@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance
 
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-
+        # decode image with vae
         image = self.vqvae.decode(image)
 
         image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         return image
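Since the pipeline now returns a tensor clamped to [0, 1], turning a sample into a viewable image is plain post-processing. A minimal sketch (the random tensor stands in for real pipeline output; Pillow is assumed to be installed):

import torch
from PIL import Image

image = torch.rand(1, 3, 256, 256)  # stand-in for the pipeline output in [0, 1]
arr = (image.cpu().permute(0, 2, 3, 1).numpy() * 255).round().astype("uint8")
Image.fromarray(arr[0]).save("ldm_sample.png")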
@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()
 
         assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
+        expected_slice = torch.tensor(
+            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
     def test_module_from_pipeline(self):
......