Commit 571e4062 authored by Patrick von Platen
Browse files

merge from master

parents 14bd3567 c2bc59d2
import argparse
import torch
from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler
def convert_bddm_orginal(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    """Convert an original BDDM checkpoint into a saved ``BDDMPipeline``.

    Args:
        checkpoint_path: File read with ``torch.load``; its
            ``"model_state_dict"`` entry is loaded into a new ``DiffWave``.
        noise_scheduler_checkpoint_path: File read with ``torch.load``; the
            loaded object must unpack into exactly four items, of which the
            first (timestep values) and third (betas) are used.
        output_path: Destination passed to ``pipeline.save_pretrained``.
    """
    # NOTE(review): torch.load unpickles its input -- only feed it trusted files.
    model_weights = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    schedule_parts = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    diffwave = DiffWave()
    # strict=False: key mismatches between checkpoint and model do not raise.
    diffwave.load_state_dict(model_weights, strict=False)

    # Only the first and third entries of the schedule checkpoint are needed.
    raw_timesteps, _, raw_betas, _ = schedule_parts
    timestep_values = list(raw_timesteps.numpy().tolist())
    trained_betas = list(raw_betas.numpy().tolist())

    scheduler_kwargs = {
        "timesteps": 12,
        "trained_betas": trained_betas,
        "timestep_values": timestep_values,
        "clip_sample": False,
        "tensor_format": "np",
    }
    scheduler = DDPMScheduler(**scheduler_kwargs)

    BDDMPipeline(diffwave, scheduler).save_pretrained(output_path)
if __name__ == "__main__":
    # Command-line entry point: all three paths are mandatory string options.
    parser = argparse.ArgumentParser()
    for option in ("--checkpoint_path", "--noise_scheduler_checkpoint_path", "--output_path"):
        parser.add_argument(option, type=str, required=True)
    args = parser.parse_args()

    convert_bddm_orginal(
        args.checkpoint_path,
        args.noise_scheduler_checkpoint_path,
        args.output_path,
    )
import argparse

import torch
# `OmegaConf` is a class inside the `omegaconf` package; there is no top-level
# module called `OmegaConf`, so a bare `import OmegaConf` fails at import time.
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, LatentDiffusionUncondPipeline, UNetLDMModel, VQModel


def _extract_sub_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with that prefix removed."""
    # Slice instead of str.replace so only the leading prefix is stripped,
    # even if the same text happens to occur again later in the key.
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original latent-diffusion checkpoint into a saved ``LatentDiffusionUncondPipeline``.

    Args:
        checkpoint_path: File read with ``torch.load``; its ``"model"`` entry
            is the combined state dict holding both sub-models.
        config_path: Config file read with ``OmegaConf.load``. The values
            under ``model.params`` (``first_stage_config.params``,
            ``unet_config.params``, ``timesteps``, ``linear_start``,
            ``linear_end``) are used.
        output_path: Destination passed to ``pipeline.save_pretrained``.
    """
    config = OmegaConf.load(config_path)
    # NOTE(review): torch.load unpickles its input -- only feed it trusted files.
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # Split the combined checkpoint into its VQ-VAE and UNet parts.
    first_stage_dict = _extract_sub_state_dict(state_dict, "first_stage_model.")
    unet_state_dict = _extract_sub_state_dict(state_dict, "model.diffusion_model.")

    model_params = config.model.params
    vqvae_init_args = model_params.first_stage_config.params
    unet_init_args = model_params.unet_config.params

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    noise_scheduler = DDIMScheduler(
        timesteps=model_params.timesteps,
        beta_schedule="scaled_linear",
        beta_start=model_params.linear_start,
        beta_end=model_params.linear_end,
        clip_sample=False,
    )

    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
from abc import abstractmethod from abc import abstractmethod
from functools import partial
import numpy as np import numpy as np
import torch import torch
...@@ -78,18 +79,25 @@ class Upsample(nn.Module): ...@@ -78,18 +79,25 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions. upsampling occurs in the inner-two dimensions.
""" """
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None): def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None, name="conv"):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.out_channels = out_channels or channels self.out_channels = out_channels or channels
self.use_conv = use_conv self.use_conv = use_conv
self.dims = dims self.dims = dims
self.use_conv_transpose = use_conv_transpose self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose: if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1) conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv: elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(self, x): def forward(self, x):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
...@@ -102,7 +110,10 @@ class Upsample(nn.Module): ...@@ -102,7 +110,10 @@ class Upsample(nn.Module):
x = F.interpolate(x, scale_factor=2.0, mode="nearest") x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.use_conv: if self.use_conv:
x = self.conv(x) if self.name == "conv":
x = self.conv(x)
else:
x = self.Conv2d_0(x)
return x return x
...@@ -134,6 +145,8 @@ class Downsample(nn.Module): ...@@ -134,6 +145,8 @@ class Downsample(nn.Module):
if name == "conv": if name == "conv":
self.conv = conv self.conv = conv
elif name == "Conv2d_0":
self.Conv2d_0 = conv
else: else:
self.op = conv self.op = conv
...@@ -145,6 +158,8 @@ class Downsample(nn.Module): ...@@ -145,6 +158,8 @@ class Downsample(nn.Module):
if self.name == "conv": if self.name == "conv":
return self.conv(x) return self.conv(x)
elif self.name == "Conv2d_0":
return self.Conv2d_0(x)
else: else:
return self.op(x) return self.op(x)
...@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -469,6 +484,7 @@ class ResnetBlockBigGANpp(nn.Module):
up=False, up=False,
down=False, down=False,
dropout=0.1, dropout=0.1,
fir=False,
fir_kernel=(1, 3, 3, 1), fir_kernel=(1, 3, 3, 1),
skip_rescale=True, skip_rescale=True,
init_scale=0.0, init_scale=0.0,
...@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -479,8 +495,20 @@ class ResnetBlockBigGANpp(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
self.up = up self.up = up
self.down = down self.down = down
self.fir = fir
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
if self.up:
if self.fir:
self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
else:
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
elif self.down:
if self.fir:
self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
else:
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1) self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None: if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch) self.Dense_0 = nn.Linear(temb_dim, out_ch)
...@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module): ...@@ -503,11 +531,11 @@ class ResnetBlockBigGANpp(nn.Module):
h = self.act(self.GroupNorm_0(x)) h = self.act(self.GroupNorm_0(x))
if self.up: if self.up:
h = upsample_2d(h, self.fir_kernel, factor=2) h = self.upsample(h)
x = upsample_2d(x, self.fir_kernel, factor=2) x = self.upsample(x)
elif self.down: elif self.down:
h = downsample_2d(h, self.fir_kernel, factor=2) h = self.downsample(h)
x = downsample_2d(x, self.fir_kernel, factor=2) x = self.downsample(x)
h = self.Conv_0(h) h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding # Add bias to each feature map conditioned on the time embedding
......
...@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin ...@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import AttentionBlock from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import downsample_2d, upfirdn2d, upsample_2d from .resnet import downsample_2d, upfirdn2d, upsample_2d, Downsample, Upsample
from .resnet import ResnetBlock from .resnet import ResnetBlock
...@@ -185,18 +185,19 @@ class Combine(nn.Module): ...@@ -185,18 +185,19 @@ class Combine(nn.Module):
class FirUpsample(nn.Module): class FirUpsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__() super().__init__()
out_ch = out_ch if out_ch else in_ch out_channels = out_channels if out_channels else channels
if with_conv: if use_conv:
self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.with_conv = with_conv self.use_conv = use_conv
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
self.out_ch = out_ch self.out_channels = out_channels
def forward(self, x): def forward(self, x):
if self.with_conv: if self.use_conv:
h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) h = _upsample_conv_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
h = upsample_2d(x, self.fir_kernel, factor=2) h = upsample_2d(x, self.fir_kernel, factor=2)
...@@ -204,18 +205,19 @@ class FirUpsample(nn.Module): ...@@ -204,18 +205,19 @@ class FirUpsample(nn.Module):
class FirDownsample(nn.Module): class FirDownsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir_kernel=(1, 3, 3, 1)): def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__() super().__init__()
out_ch = out_ch if out_ch else in_ch out_channels = out_channels if out_channels else channels
if with_conv: if use_conv:
self.Conv2d_0 = self.Conv2d_0 = Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1) self.Conv2d_0 = self.Conv2d_0 = Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel self.fir_kernel = fir_kernel
self.with_conv = with_conv self.use_conv = use_conv
self.out_ch = out_ch self.out_channels = out_channels
def forward(self, x): def forward(self, x):
if self.with_conv: if self.use_conv:
x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel) x = _conv_downsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
else: else:
x = downsample_2d(x, self.fir_kernel, factor=2) x = downsample_2d(x, self.fir_kernel, factor=2)
...@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -229,13 +231,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self, self,
image_size=1024, image_size=1024,
num_channels=3, num_channels=3,
centered=False,
attn_resolutions=(16,), attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True, conditional=True,
conv_size=3, conv_size=3,
dropout=0.0, dropout=0.0,
embedding_type="fourier", embedding_type="fourier",
fir=True, # TODO (patil-suraj) remove this option from here and pre-trained model configs fir=True,
fir_kernel=(1, 3, 3, 1), fir_kernel=(1, 3, 3, 1),
fourier_scale=16, fourier_scale=16,
init_scale=0.0, init_scale=0.0,
...@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -253,12 +256,14 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.register_to_config( self.register_to_config(
image_size=image_size, image_size=image_size,
num_channels=num_channels, num_channels=num_channels,
centered=centered,
attn_resolutions=attn_resolutions, attn_resolutions=attn_resolutions,
ch_mult=ch_mult, ch_mult=ch_mult,
conditional=conditional, conditional=conditional,
conv_size=conv_size, conv_size=conv_size,
dropout=dropout, dropout=dropout,
embedding_type=embedding_type, embedding_type=embedding_type,
fir=fir,
fir_kernel=fir_kernel, fir_kernel=fir_kernel,
fourier_scale=fourier_scale, fourier_scale=fourier_scale,
init_scale=init_scale, init_scale=init_scale,
...@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -308,21 +313,26 @@ class NCSNpp(ModelMixin, ConfigMixin):
modules.append(Linear(nf * 4, nf * 4)) modules.append(Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0)) AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
if self.fir:
Up_sample = functools.partial(FirUpsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample, name="Conv2d_0")
if progressive == "output_skip": if progressive == "output_skip":
self.pyramid_upsample = Up_sample(fir_kernel=fir_kernel, with_conv=False) self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual": elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, fir_kernel=fir_kernel, with_conv=True) pyramid_upsample = functools.partial(Up_sample, use_conv=True)
Down_sample = functools.partial(FirDownsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel) if self.fir:
Down_sample = functools.partial(FirDownsample, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample, padding=0, name="Conv2d_0")
if progressive_input == "input_skip": if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(fir_kernel=fir_kernel, with_conv=False) self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual": elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, fir_kernel=fir_kernel, with_conv=True) pyramid_downsample = functools.partial(Down_sample, use_conv=True)
# Downsampling block
channels = num_channels channels = num_channels
if progressive_input != "none": if progressive_input != "none":
...@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -376,7 +386,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
in_ch *= 2 in_ch *= 2
elif progressive_input == "residual": elif progressive_input == "residual":
modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch input_pyramid_ch = in_ch
hs_c.append(in_ch) hs_c.append(in_ch)
...@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -448,7 +458,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
) )
pyramid_ch = channels pyramid_ch = channels
elif progressive == "residual": elif progressive == "residual":
modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
pyramid_ch = in_ch pyramid_ch = in_ch
else: else:
raise ValueError(f"{progressive} is not a valid name") raise ValueError(f"{progressive} is not a valid name")
...@@ -505,7 +515,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -505,7 +515,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
temb = None temb = None
# If input data is in [0, 1] # If input data is in [0, 1]
x = 2 * x - 1.0 if not self.config.centered:
x = 2 * x - 1.0
# Downsampling block # Downsampling block
input_pyramid = None input_pyramid = None
......
...@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -63,9 +63,6 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance image = pred_prev_image + variance
# scale and decode image with vae # decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image) image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
return image return image
...@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1159,7 +1159,9 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293]) expected_slice = torch.tensor(
[-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment