Unverified Commit 8c31925b authored by Patrick von Platen, committed by GitHub

Get diffusers ready 🚀🚀🚀 (#101)

* big purge

* more fixes

* finish for now
parent 33344ed9
#!/usr/bin/env python3
import json
import os

import torch
from huggingface_hub import hf_hub_download

from diffusers import UNetUnconditionalModel
from scripts.convert_ncsnpp_original_checkpoint_to_diffusers import convert_ncsnpp_checkpoint
def convert_checkpoint(model_id, subfolder=None, checkpoint="diffusion_model.pt", config="config.json"):
if subfolder is not None:
checkpoint = os.path.join(subfolder, checkpoint)
config = os.path.join(subfolder, config)
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint), map_location="cpu")
config_path = hf_hub_download(model_id, config)
with open(config_path) as f:
config = json.load(f)
checkpoint = convert_ncsnpp_checkpoint(original_checkpoint, config)
def current_codebase_conversion(path):
model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, sde=True)
model.eval()
model.config.sde=False
model.save_config(path)
model.config.sde=True
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
return model.state_dict()
path = f"{model_id}_converted"
currently_converted_checkpoint = current_codebase_conversion(path)
def diff_between_checkpoints(ch_0, ch_1):
all_layers_included = False
if set(ch_0.keys()) != set(ch_1.keys()):
print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
print(f"\t{key}")
print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
print(f"\t{key}")
else:
print("Keys are the same between the two checkpoints")
all_layers_included = True
keys = ch_0.keys()
non_equal_keys = []
if all_layers_included:
for key in keys:
try:
if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
except RuntimeError as e:
print(e)
non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
if len(non_equal_keys):
non_equal_keys = '\n\t'.join(non_equal_keys)
print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
else:
print("All keys are equal across checkpoints.")
diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
os.makedirs(f"{model_id}_converted", exist_ok=True)
torch.save(checkpoint, f"{model_id}_converted/diffusion_model.pt")
model_ids = [
    "fusing/ffhq_ncsnpp",
    "fusing/church_256-ncsnpp-ve",
    "fusing/celebahq_256-ncsnpp-ve",
    "fusing/bedroom_256-ncsnpp-ve",
    "fusing/ffhq_256-ncsnpp-ve",
    "fusing/ncsnpp-ffhq-ve-dummy",
]
for model in model_ids:
print(f"converting {model}")
try:
convert_checkpoint(model)
except Exception as e:
print(e)
from tests.test_modeling_utils import PipelineTesterMixin, NCSNppModelTests
tester1 = NCSNppModelTests()
tester2 = PipelineTesterMixin()
os.environ["RUN_SLOW"] = "1"
# Note: `os.system` runs in a subshell, so the export below cannot change this
# process's environment; the `os.environ` assignment above is what the testers see.
cmd = "export RUN_SLOW=1; echo $RUN_SLOW"
os.system(cmd)
tester2.test_score_sde_ve_pipeline(f"{model_ids[0]}_converted")
tester1.test_output_pretrained_ve_mid(f"{model_ids[2]}_converted")
tester1.test_output_pretrained_ve_large(f"{model_ids[-1]}_converted")
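# Hedged usage sketch: `convert_checkpoint` also accepts checkpoints stored in a
# repo subfolder; "unet" below is a hypothetical subfolder name, and the file
# names fall back to the defaults ("diffusion_model.pt", "config.json").
# convert_checkpoint("fusing/ffhq_ncsnpp", subfolder="unet")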
#!/usr/bin/env python3
import numpy as np
import PIL.Image
import torch

# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NewReverseDiffusionPredictor:
def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
super().__init__()
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.probability_flow = probability_flow
self.score_fn = score_fn
def discretize(self, x, t):
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
self.discrete_sigmas[timestep - 1].to(t.device))
f = torch.zeros_like(x)
G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
result = self.score_fn(x, labels)
rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
rev_G = torch.zeros_like(G) if self.probability_flow else G
return rev_f, rev_G
def update_fn(self, x, t):
f, G = self.discretize(x, t)
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
class NewLangevinCorrector:
def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
super().__init__()
self.score_fn = score_fn
self.snr = snr
self.n_steps = n_steps
self.sigma_min = sigma_min
self.sigma_max = sigma_max
def update_fn(self, x, t):
score_fn = self.score_fn
n_steps = self.n_steps
target_snr = self.snr
# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
# timestep = (t * (sde.N - 1) / sde.T).long()
# alpha = sde.alphas.to(t.device)[timestep]
# else:
alpha = torch.ones_like(t)
for i in range(n_steps):
labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
grad = score_fn(x, labels)
noise = torch.randn_like(x)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
x_mean = x + step_size[:, None, None, None] * grad
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
return x, x_mean
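# Hedged sanity check for the two samplers above. The analytic score of an
# isotropic Gaussian N(0, sigma^2 I) stands in for a trained model, and all
# shapes/constants are assumptions rather than values from the real checkpoints.
def _pc_step_smoke_test():
    def dummy_score_fn(x, sigmas):
        # score(x) = -x / sigma^2 for N(0, sigma^2 I)
        return -x / sigmas[:, None, None, None] ** 2

    predictor = NewReverseDiffusionPredictor(dummy_score_fn, sigma_min=0.01, sigma_max=50.0, N=10)
    corrector = NewLangevinCorrector(dummy_score_fn, snr=0.15, n_steps=1, sigma_min=0.01, sigma_max=50.0)
    x = torch.randn(2, 3, 8, 8) * 50.0
    t = torch.full((2,), 0.5)
    x, _ = corrector.update_fn(x, t)
    x, _ = predictor.update_fn(x, t)
    assert x.shape == (2, 3, 8, 8)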
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note: usually we would need to restore the EMA weights etc.
# The checkpoint used below already has the EMA weights restored.
N = 2
sigma_min = 0.01
sigma_max = 1348
sampling_eps = 1e-5
batch_size = 1
centered = False
from diffusers import NCSNpp
model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)
img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15
n_steps = 1
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
with torch.no_grad():
# Initial sample
x = torch.randn(*shape) * sigma_max
x = x.to(device)
timesteps = torch.linspace(1, sampling_eps, N, device=device)
for i in range(N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x = x_mean
if centered:
x = (x + 1.) / 2.
# save_image(x)
# Reference values; only the last assignment pair is active (N=2 at 1024x1024).
# for N=5 on CIFAR-10
x_sum = 106071.9922
x_mean = 34.52864456176758

# for N=1000 on CIFAR-10
x_sum = 461.9700
x_mean = 0.1504

# for N=2 at 1024x1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
### Convert the text encoder
config = CLIPTextConfig(
vocab_size=50257,
max_position_embeddings=128,
hidden_size=512,
intermediate_size=2048,
num_hidden_layers=16,
num_attention_heads=8,
use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for layer_idx in range(config.num_hidden_layers):
hf_layer = hf_encoder.encoder.layers[layer_idx]
hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]
hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
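# Hedged spot check that the per-layer copy above took effect; layer 0 is an
# arbitrary choice (the assignments share storage, so this should hold exactly).
with torch.no_grad():
    assert torch.allclose(
        hf_encoder.encoder.layers[0].self_attn.qkv_proj.weight,
        state_dict["transformer.resblocks.0.attn.c_qkv.weight"],
    ), "qkv weights of layer 0 were not copied correctly"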
### Convert the Text-to-Image UNet
text2im_model = GlideTextToImageUNetModel(
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0.1,
channel_mult=(1, 2, 3, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
transformer_dim=512,
)
text2im_model.load_state_dict(state_dict, strict=False)
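# Optional hedged check: a non-strict load returns the keys that did not match,
# which is a cheap way to confirm the conversion covered the UNet (re-running
# the same load is idempotent).
missing, unexpected = text2im_model.load_state_dict(state_dict, strict=False)
print(f"text2im UNet: {len(missing)} missing keys, {len(unexpected)} unexpected keys")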
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GlideSuperResUNetModel(
in_channels=6,
model_channels=192,
out_channels=6,
num_res_blocks=2,
attention_resolutions=(8, 16, 32),
dropout=0.1,
channel_mult=(1, 1, 2, 2, 4, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = DDIMScheduler(
timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)
glide = Glide(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base")
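# Hedged round-trip sketch; the prompt is made up and the exact __call__
# signature of this early Glide pipeline is an assumption, hence commented out.
# glide = Glide.from_pretrained("./glide-base")
# image = glide("an oil painting of a corgi")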
@@ -7,36 +7,13 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode

 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models import (
-    AutoencoderKL,
-    NCSNpp,
-    UNetConditionalModel,
-    UNetLDMModel,
-    UNetModel,
-    UNetUnconditionalModel,
-    VQModel,
-)
+from .models import AutoencoderKL, UNetConditionalModel, UNetUnconditionalModel, VQModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import (
-    DDIMPipeline,
-    DDPMPipeline,
-    LatentDiffusionUncondPipeline,
-    PNDMPipeline,
-    ScoreSdeVePipeline,
-    ScoreSdeVpPipeline,
-)
-from .schedulers import (
-    DDIMScheduler,
-    DDPMScheduler,
-    PNDMScheduler,
-    SchedulerMixin,
-    ScoreSdeVeScheduler,
-    ScoreSdeVpScheduler,
-)
+from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler

 if is_transformers_available():
-    from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
-    from .pipelines import GlidePipeline, LatentDiffusionPipeline
+    from .pipelines import LatentDiffusionPipeline
 else:
     from .utils.dummy_transformers_objects import *
@@ -16,10 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .unet import UNetModel
 from .unet_conditional import UNetConditionalModel
-from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
-from .unet_ldm import UNetLDMModel
-from .unet_sde_score_estimation import NCSNpp
 from .unet_unconditional import UNetUnconditionalModel
 from .vae import AutoencoderKL, VQModel
@@ -54,6 +54,43 @@ def get_timestep_embedding(
     return emb


+class TimestepEmbedding(nn.Module):
+    def __init__(self, channel, time_embed_dim, act_fn="silu"):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(channel, time_embed_dim)
+        self.act = None
+        if act_fn == "silu":
+            self.act = nn.SiLU()
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def forward(self, sample):
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+        return sample
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+        )
+        return t_emb
+
+
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""
......
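# Hedged usage sketch for the two modules added above (sizes are arbitrary):
#
#     time_proj = Timesteps(num_channels=32, flip_sin_to_cos=True, downscale_freq_shift=0)
#     time_embed = TimestepEmbedding(channel=32, time_embed_dim=128)
#     emb = time_embed(time_proj(torch.tensor([1, 10, 100])))  # shape (3, 128)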
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helper functions
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class UNetModel(ModelMixin, ConfigMixin):
def __init__(
self,
ch=128,
out_ch=3,
ch_mult=(1, 1, 2, 2, 4, 4),
num_res_blocks=2,
attn_resolutions=(16,),
dropout=0.0,
resamp_with_conv=True,
in_channels=3,
resolution=256,
):
super().__init__()
self.register_to_config(
ch=ch,
out_ch=out_ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=in_channels,
resolution=resolution,
)
ch_mult = tuple(ch_mult)
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock2D(
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample2D(block_in, use_conv=resamp_with_conv, padding=0)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid_new = UNetMidBlock2D(in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
self.mid_new.resnets[0] = self.mid.block_1
self.mid_new.attentions[0] = self.mid.attn_1
self.mid_new.resnets[1] = self.mid.block_2
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock2D(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(AttentionBlock(block_in, overwrite_qkv=True))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample2D(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, sample, timesteps):
x = sample
assert x.shape[2] == x.shape[3] == self.resolution
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
# timestep embedding
temb = get_timestep_embedding(timesteps, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = self.mid_new(hs[-1], temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
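# Hedged smoke test for the UNetModel above; this tiny config is made up purely
# to exercise the forward pass (32x32 input, two resolutions, attention at 16x16).
def _unet_model_smoke_test():
    model = UNetModel(
        ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=(16,), in_channels=3, resolution=32
    )
    sample = torch.randn(1, 3, 32, 32)
    out = model(sample, torch.tensor([10]))
    assert out.shape == (1, 3, 32, 32)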
-import functools
-import math
 from typing import Dict, Union

-import numpy as np
 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention import AttentionBlock, SpatialTransformer
-from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
-from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+from .embeddings import TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block


-class Combine(nn.Module):
-    """Combine information from skip connections."""
-
-    def __init__(self, dim1, dim2, method="cat"):
-        super().__init__()
-        # 1x1 convolution with DDPM initialization.
-        self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
-        self.method = method
-
-    # def forward(self, x, y):
-    #     h = self.Conv_0(x)
-    #     if self.method == "cat":
-    #         return torch.cat([h, y], dim=1)
-    #     elif self.method == "sum":
-    #         return h + y
-    #     else:
-    #         raise ValueError(f"Method {self.method} not recognized.")
-
-
-class TimestepEmbedding(nn.Module):
-    def __init__(self, channel, time_embed_dim, act_fn="silu"):
-        super().__init__()
-
-        self.linear_1 = nn.Linear(channel, time_embed_dim)
-        self.act = None
-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
-
-    def forward(self, sample):
-        sample = self.linear_1(sample)
-
-        if self.act is not None:
-            sample = self.act(sample)
-
-        sample = self.linear_2(sample)
-        return sample
-
-
-class Timesteps(nn.Module):
-    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
-        super().__init__()
-        self.num_channels = num_channels
-        self.flip_sin_to_cos = flip_sin_to_cos
-        self.downscale_freq_shift = downscale_freq_shift
-
-    def forward(self, timesteps):
-        t_emb = get_timestep_embedding(
-            timesteps,
-            self.num_channels,
-            flip_sin_to_cos=self.flip_sin_to_cos,
-            downscale_freq_shift=self.downscale_freq_shift,
-        )
-        return t_emb
-
-
 class UNetConditionalModel(ModelMixin, ConfigMixin):
@@ -124,38 +62,7 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         downscale_freq_shift=0,
         mid_block_scale_factor=1,
         center_input_sample=False,
-        # TODO(PVP) - to delete later at release
-        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
-        # ======================================
-        # LDM
-        attention_resolutions=(4, 2, 1),
-        # DDPM
-        out_ch=None,
-        resolution=None,
-        attn_resolutions=None,
-        resamp_with_conv=None,
-        ch_mult=None,
-        ch=None,
-        ddpm=False,
-        # SDE
-        sde=False,
-        nf=None,
-        fir=None,
-        progressive=None,
-        progressive_combine=None,
-        scale_by_sigma=None,
-        skip_rescale=None,
-        num_channels=None,
-        centered=False,
-        conditional=True,
-        conv_size=3,
-        fir_kernel=(1, 3, 3, 1),
-        fourier_scale=16,
-        init_scale=0.0,
-        progressive_input="input_skip",
-        resnet_num_groups=32,
-        continuous=True,
-        ldm=False,
+        resnet_num_groups=30,
     ):
         super().__init__()

         # register all __init__ params to be accessible via `self.config.<...>`
@@ -175,21 +82,13 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
             num_head_channels=num_head_channels,
             flip_sin_to_cos=flip_sin_to_cos,
             downscale_freq_shift=downscale_freq_shift,
-            attention_resolutions=attention_resolutions,
-            attn_resolutions=attn_resolutions,
             mid_block_scale_factor=mid_block_scale_factor,
             resnet_num_groups=resnet_num_groups,
             center_input_sample=center_input_sample,
         )
-        self.ldm = ldm
-
-        # TODO(PVP) - to delete later at release
-        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
-        # ======================================
         self.image_size = image_size
         time_embed_dim = block_channels[0] * 4
-        # ======================================

         # input
         self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))
@@ -264,57 +163,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
             prev_output_channel = output_channel

         # out
-        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
-        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
+        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=resnet_num_groups, eps=resnet_eps)
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

-        # ======================== Out ====================
-
-        # =========== TO DELETE AFTER CONVERSION ==========
-        # TODO(PVP) - to delete later at release
-        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
-        # ======================================
-        self.is_overwritten = False
-        if ldm:
-            num_heads = 8
-            num_head_channels = -1
-            transformer_depth = 1
-            use_spatial_transformer = True
-            context_dim = 1280
-            legacy = False
-            model_channels = block_channels[0]
-            channel_mult = tuple([x // model_channels for x in block_channels])
-            self.init_for_ldm(
-                in_channels,
-                model_channels,
-                channel_mult,
-                num_res_blocks,
-                dropout,
-                time_embed_dim,
-                attention_resolutions,
-                num_head_channels,
-                num_heads,
-                legacy,
-                False,
-                transformer_depth,
-                context_dim,
-                conv_resample,
-                out_channels,
-            )

     def forward(
         self,
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
     ) -> Dict[str, torch.FloatTensor]:
-        # TODO(PVP) - to delete later at release
-        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
-        # ======================================
-        if not self.is_overwritten:
-            self.set_weights()
-
+        # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -329,7 +189,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         emb = self.time_embedding(t_emb)

         # 2. pre-process
-        skip_sample = sample
         sample = self.conv_in(sample)

         # 3. down
@@ -349,7 +208,6 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)

         # 5. up
-        skip_sample = None
         for upsample_block in self.upsample_blocks:
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
@@ -374,259 +232,3 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
         output = {"sample": sample}

         return output
# !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
# =================================================================================================================================================
def set_weights(self):
self.is_overwritten = True
if self.ldm:
self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data
self.conv_in.weight.data = self.input_blocks[0][0].weight.data
self.conv_in.bias.data = self.input_blocks[0][0].bias.data
# ================ SET WEIGHTS OF ALL WEIGHTS ==================
for i, input_layer in enumerate(self.input_blocks[1:]):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if layer_in_block_id == 2:
self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
elif len(input_layer) > 1:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.mid.resnets[0].set_weight(self.middle_block[0])
self.mid.resnets[1].set_weight(self.middle_block[2])
self.mid.attentions[0].set_weight(self.middle_block[1])
for i, input_layer in enumerate(self.output_blocks):
block_id = i // (self.config.num_res_blocks + 1)
layer_in_block_id = i % (self.config.num_res_blocks + 1)
if len(input_layer) > 2:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
elif len(input_layer) > 1:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
else:
self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
self.conv_norm_out.weight.data = self.out[0].weight.data
self.conv_norm_out.bias.data = self.out[0].bias.data
self.conv_out.weight.data = self.out[2].weight.data
self.conv_out.bias.data = self.out[2].bias.data
self.remove_ldm()
def init_for_ldm(
self,
in_channels,
model_channels,
channel_mult,
num_res_blocks,
dropout,
time_embed_dim,
attention_resolutions,
num_head_channels,
num_heads,
legacy,
use_spatial_transformer,
transformer_depth,
context_dim,
conv_resample,
out_channels,
):
# TODO(PVP) - delete after weight conversion
class TimestepEmbedSequential(nn.Sequential):
"""
A sequential module that passes timestep embeddings to the children that support it as an extra input.
"""
pass
# TODO(PVP) - delete after weight conversion
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
dims = 2
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResnetBlock2D(
in_channels=ch,
out_channels=mult * model_channels,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
if dim_head < 0:
dim_head = None
# TODO(Patrick) - delete after weight conversion
# init to be able to overwrite `self.mid`
self.middle_block = TimestepEmbedSequential(
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
),
ResnetBlock2D(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResnetBlock2D(
in_channels=ch + ich,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
nn.SiLU(),
nn.Conv2d(model_channels, out_channels, 3, padding=1),
)
def remove_ldm(self):
del self.time_embed
del self.input_blocks
del self.middle_block
del self.output_blocks
del self.out
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helper functions
import functools
import math
import numpy as np
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D
class Combine(nn.Module):
"""Combine information from skip connections."""
def __init__(self, dim1, dim2, method="cat"):
super().__init__()
# 1x1 convolution with DDPM initialization.
self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
self.method = method
def forward(self, x, y):
h = self.Conv_0(x)
if self.method == "cat":
return torch.cat([h, y], dim=1)
elif self.method == "sum":
return h + y
else:
raise ValueError(f"Method {self.method} not recognized.")
class NCSNpp(ModelMixin, ConfigMixin):
"""NCSN++ model"""
def __init__(
self,
image_size=1024,
num_channels=3,
centered=False,
attn_resolutions=(16,),
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
conditional=True,
conv_size=3,
dropout=0.0,
embedding_type="fourier",
fir=True,
fir_kernel=(1, 3, 3, 1),
fourier_scale=16,
init_scale=0.0,
nf=16,
num_res_blocks=1,
progressive="output_skip",
progressive_combine="sum",
progressive_input="input_skip",
resamp_with_conv=True,
scale_by_sigma=True,
skip_rescale=True,
continuous=True,
):
super().__init__()
self.register_to_config(
image_size=image_size,
num_channels=num_channels,
centered=centered,
attn_resolutions=attn_resolutions,
ch_mult=ch_mult,
conditional=conditional,
conv_size=conv_size,
dropout=dropout,
embedding_type=embedding_type,
fir=fir,
fir_kernel=fir_kernel,
fourier_scale=fourier_scale,
init_scale=init_scale,
nf=nf,
num_res_blocks=num_res_blocks,
progressive=progressive,
progressive_combine=progressive_combine,
progressive_input=progressive_input,
resamp_with_conv=resamp_with_conv,
scale_by_sigma=scale_by_sigma,
skip_rescale=skip_rescale,
continuous=continuous,
)
self.act = nn.SiLU()
self.nf = nf
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.num_resolutions = len(ch_mult)
self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)]
self.conditional = conditional
self.skip_rescale = skip_rescale
self.progressive = progressive
self.progressive_input = progressive_input
self.embedding_type = embedding_type
assert progressive in ["none", "output_skip", "residual"]
assert progressive_input in ["none", "input_skip", "residual"]
assert embedding_type in ["fourier", "positional"]
combine_method = progressive_combine.lower()
combiner = functools.partial(Combine, method=combine_method)
modules = []
# timestep/noise_level embedding; only for continuous training
if embedding_type == "fourier":
# Gaussian Fourier features embeddings.
modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale))
embed_dim = 2 * nf
elif embedding_type == "positional":
embed_dim = nf
else:
raise ValueError(f"embedding type {embedding_type} unknown.")
modules.append(nn.Linear(embed_dim, nf * 4))
modules.append(nn.Linear(nf * 4, nf * 4))
AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
if fir:  # use the local arg; `self.fir` is never set as an instance attribute
Up_sample = functools.partial(FirUpsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Up_sample = functools.partial(Upsample2D, name="Conv2d_0")
if progressive == "output_skip":
self.pyramid_upsample = Up_sample(channels=None, use_conv=False)
elif progressive == "residual":
pyramid_upsample = functools.partial(Up_sample, use_conv=True)
if fir:
Down_sample = functools.partial(FirDownsample2D, fir_kernel=fir_kernel, use_conv=resamp_with_conv)
else:
Down_sample = functools.partial(Downsample2D, padding=0, name="Conv2d_0")
if progressive_input == "input_skip":
self.pyramid_downsample = Down_sample(channels=None, use_conv=False)
elif progressive_input == "residual":
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
channels = num_channels
if progressive_input != "none":
input_pyramid_ch = channels
modules.append(nn.Conv2d(channels, nf, kernel_size=3, padding=1))
hs_c = [nf]
in_ch = nf
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
hs_c.append(in_ch)
if i_level != self.num_resolutions - 1:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
down=True,
kernel="fir" if self.fir else "sde_vp",
use_nin_shortcut=True,
)
)
if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
if combine_method == "cat":
in_ch *= 2
elif progressive_input == "residual":
modules.append(pyramid_downsample(channels=input_pyramid_ch, out_channels=in_ch))
input_pyramid_ch = in_ch
hs_c.append(in_ch)
# mid
self.mid = UNetMidBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=math.sqrt(2.0),
resnet_act_fn="silu",
resnet_groups=min(in_ch // 4, 32),
dropout=dropout,
)
in_ch = hs_c[-1]
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
modules.append(AttnBlock(channels=in_ch))
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
# self.mid.resnets[0] = modules[len(modules) - 3]
# self.mid.attentions[0] = modules[len(modules) - 2]
# self.mid.resnets[1] = modules[len(modules) - 1]
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
in_ch = in_ch + hs_c.pop()
modules.append(
ResnetBlock2D(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
modules.append(AttnBlock(channels=in_ch))
if progressive != "none":
if i_level == self.num_resolutions - 1:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
pyramid_ch = channels
# elif progressive == "residual":
# modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
# modules.append(nn.Conv2d(in_ch, in_ch, bias=True, kernel_size=3, padding=1))
# pyramid_ch = in_ch
# else:
# raise ValueError(f"{progressive} is not a valid name.")
else:
if progressive == "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, bias=True, kernel_size=3, padding=1))
pyramid_ch = channels
# elif progressive == "residual":
# modules.append(pyramid_upsample(channels=pyramid_ch, out_channels=in_ch))
# pyramid_ch = in_ch
# else:
# raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(
ResnetBlock2D(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
up=True,
kernel="fir" if self.fir else "sde_vp",
use_nin_shortcut=True,
)
)
assert not hs_c
if progressive != "output_skip":
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
modules.append(nn.Conv2d(in_ch, channels, kernel_size=3, padding=1))
self.all_modules = nn.ModuleList(modules)
def forward(self, sample, timestep, sigmas=None):
timesteps = timestep
x = sample
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
if self.embedding_type == "fourier":
# Gaussian Fourier features embeddings.
used_sigmas = timesteps
temb = modules[m_idx](used_sigmas)
m_idx += 1
elif self.embedding_type == "positional":
# Sinusoidal positional embeddings.
used_sigmas = sigmas
temb = get_timestep_embedding(timesteps, self.nf)
else:
raise ValueError(f"embedding type {self.embedding_type} unknown.")
if self.conditional:
temb = modules[m_idx](temb)
m_idx += 1
temb = modules[m_idx](self.act(temb))
m_idx += 1
else:
temb = None
# If input data is in [0, 1]
if not self.config.centered:
x = 2 * x - 1.0
# Downsampling block
input_pyramid = None
if self.progressive_input != "none":
input_pyramid = x
hs = [modules[m_idx](x)]
m_idx += 1
for i_level in range(self.num_resolutions):
# Residual blocks for this resolution
for i_block in range(self.num_res_blocks):
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
hs.append(h)
if i_level != self.num_resolutions - 1:
h = modules[m_idx](hs[-1], temb)
m_idx += 1
if self.progressive_input == "input_skip":
input_pyramid = self.pyramid_downsample(input_pyramid)
h = modules[m_idx](input_pyramid, h)
m_idx += 1
elif self.progressive_input == "residual":
input_pyramid = modules[m_idx](input_pyramid)
m_idx += 1
if self.skip_rescale:
input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
else:
input_pyramid = input_pyramid + h
h = input_pyramid
hs.append(h)
h = hs[-1]
h = modules[m_idx](h, temb)
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
h = modules[m_idx](h, temb)
m_idx += 1
pyramid = None
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
m_idx += 1
if h.shape[-1] in self.attn_resolutions:
h = modules[m_idx](h)
m_idx += 1
if self.progressive != "none":
if i_level == self.num_resolutions - 1:
if self.progressive == "output_skip":
pyramid = self.act(modules[m_idx](h))
m_idx += 1
pyramid = modules[m_idx](pyramid)
m_idx += 1
# elif self.progressive == "residual":
# pyramid = self.act(modules[m_idx](h))
# m_idx += 1
# pyramid = modules[m_idx](pyramid)
# m_idx += 1
# else:
# raise ValueError(f"{self.progressive} is not a valid name.")
else:
if self.progressive == "output_skip":
pyramid_h = self.act(modules[m_idx](h))
m_idx += 1
pyramid_h = modules[m_idx](pyramid_h)
m_idx += 1
skip_sample = self.pyramid_upsample(pyramid)
pyramid = skip_sample + pyramid_h
# elif self.progressive == "residual":
# pyramid = modules[m_idx](pyramid)
# m_idx += 1
# if self.skip_rescale:
# pyramid = (pyramid + h) / np.sqrt(2.0)
# else:
# pyramid = pyramid + h
# h = pyramid
# else:
# raise ValueError(f"{self.progressive} is not a valid name")
if i_level != 0:
h = modules[m_idx](h, temb)
m_idx += 1
assert not hs
if self.progressive == "output_skip":
h = pyramid
else:
h = self.act(modules[m_idx](h))
m_idx += 1
h = modules[m_idx](h)
m_idx += 1
assert m_idx == len(modules)
if self.config.scale_by_sigma:
used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
h = h / used_sigmas
return h
@@ -4,9 +4,7 @@ from .ddpm import DDPMPipeline
 from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
 from .pndm import PNDMPipeline
 from .score_sde_ve import ScoreSdeVePipeline
-from .score_sde_vp import ScoreSdeVpPipeline

 if is_transformers_available():
-    from .glide import GlidePipeline
     from .latent_diffusion import LatentDiffusionPipeline
from ...utils import is_transformers_available
if is_transformers_available():
from .pipeline_glide import CLIPTextModel, GlidePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, num_inference_steps=1000, generator=None):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = self.model.to(device)
x = torch.randn(*shape).to(device)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.scheduler.timesteps:
t = t * torch.ones(shape[0], device=device)
scaled_t = t * (num_inference_steps - 1)
# TODO add corrector
with torch.no_grad():
result = model(x, scaled_t)
x, x_mean = self.scheduler.step_pred(result, x, t)
x_mean = (x_mean + 1.0) / 2.0
return x_mean
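# Hedged usage sketch; `score_model` and `vp_scheduler` are placeholders for a
# model exposing `config.image_size`/`config.num_channels` and a scheduler with
# the `set_timesteps`/`step_pred` methods used above.
# pipeline = ScoreSdeVpPipeline(model=score_model, scheduler=vp_scheduler)
# sample = pipeline(num_inference_steps=1000)  # tensor in [0, 1], shape (1, C, H, W)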
@@ -255,8 +255,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "ch": 32,
-            "ch_mult": (1, 2),
             "block_channels": (32, 64),
             "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
             "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
@@ -264,8 +262,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
             "out_channels": 3,
             "in_channels": 3,
             "num_res_blocks": 2,
-            "attn_resolutions": (16,),
-            "resolution": 32,
             "image_size": 32,
         }
         inputs_dict = self.dummy_input
@@ -322,13 +318,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
             "in_channels": 4,
             "out_channels": 4,
             "num_res_blocks": 2,
-            "attention_resolutions": (16,),
             "block_channels": (32, 64),
             "num_head_channels": 32,
             "conv_resample": True,
             "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
             "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
-            "ldm": True,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -529,8 +523,8 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
             "ch": 64,
             "out_ch": 3,
             "num_res_blocks": 1,
-            "attn_resolutions": [],
             "in_channels": 3,
+            "attn_resolutions": [],
             "resolution": 32,
             "z_channels": 3,
             "n_embed": 256,
@@ -605,11 +599,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
             "ch_mult": (1,),
             "embed_dim": 4,
             "in_channels": 3,
-            "attn_resolutions": [],
             "num_res_blocks": 1,
             "out_ch": 3,
             "resolution": 32,
             "z_channels": 4,
+            "attn_resolutions": [],
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -655,7 +649,6 @@ class PipelineTesterMixin(unittest.TestCase):
         model = UNetUnconditionalModel(
             block_channels=(32, 64),
             num_res_blocks=2,
-            attn_resolutions=(16,),
             image_size=32,
             in_channels=3,
             out_channels=3,
......