Commit 5a92374a authored by suily's avatar suily
Browse files

test commit

parents
vgg.pth
\ No newline at end of file
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
from collections import namedtuple
import torch
import torch.nn as nn
from torchvision import models
from ..util import get_ckpt_path
class LPIPS(nn.Module):
# Learned perceptual metric
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_from_pretrained()
for param in self.parameters():
param.requires_grad = False
def load_from_pretrained(self, name="vgg_lpips"):
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
self.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
print("loaded pretrained LPIPS loss from {}".format(ckpt))
@classmethod
def from_pretrained(cls, name="vgg_lpips"):
if name != "vgg_lpips":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
)
return model
def forward(self, input, target):
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
outs1[kk]
)
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
for kk in range(len(self.chns))
]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer(
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
)
self.register_buffer(
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
)
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv"""
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
def normalize_tensor(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
return x.mean([2, 3], keepdim=keepdim)
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License
For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
import functools
import torch.nn as nn
from ..util import ActNorm
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator as in Pix2Pix
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
"""
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if not use_actnorm:
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if (
type(norm_layer) == functools.partial
): # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True),
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(
ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
padding=padw,
bias=use_bias,
),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True),
]
sequence += [
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
] # output 1 channel prediction map
self.main = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.main(input)
import hashlib
import os
import requests
import torch
import torch.nn as nn
from tqdm import tqdm
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)
def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()
def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path
class ActNorm(nn.Module):
def __init__(
self, num_features, logdet=False, affine=True, allow_reverse_init=False
):
assert affine
super().__init__()
self.logdet = logdet
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.allow_reverse_init = allow_reverse_init
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
def initialize(self, input):
with torch.no_grad():
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = (
flatten.mean(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
std = (
flatten.std(1)
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.permute(1, 0, 2, 3)
)
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input, reverse=False):
if reverse:
return self.reverse(input)
if len(input.shape) == 2:
input = input[:, :, None, None]
squeeze = True
else:
squeeze = False
_, _, height, width = input.shape
if self.training and self.initialized.item() == 0:
self.initialize(input)
self.initialized.fill_(1)
h = self.scale * (input + self.loc)
if squeeze:
h = h.squeeze(-1).squeeze(-1)
if self.logdet:
log_abs = torch.log(torch.abs(self.scale))
logdet = height * width * torch.sum(log_abs)
logdet = logdet * torch.ones(input.shape[0]).to(input)
return h, logdet
return h
def reverse(self, output):
if self.training and self.initialized.item() == 0:
if not self.allow_reverse_init:
raise RuntimeError(
"Initializing ActNorm in reverse direction is "
"disabled by default. Use allow_reverse_init=True to enable."
)
else:
self.initialize(output)
self.initialized.fill_(1)
if len(output.shape) == 2:
output = output[:, :, None, None]
squeeze = True
else:
squeeze = False
h = output / self.scale - self.loc
if squeeze:
h = h.squeeze(-1).squeeze(-1)
return h
import torch
import torch.nn.functional as F
def hinge_d_loss(logits_real, logits_fake):
loss_real = torch.mean(F.relu(1.0 - logits_real))
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real))
+ torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ....modules.distributions.distributions import DiagonalGaussianDistribution
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
def measure_perplexity(predicted_indices, num_centroids):
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
from .denoiser import Denoiser
from .discretizer import Discretization
from .loss import StandardDiffusionLoss
from .model import Decoder, Encoder, Model
from .openaimodel import UNetModel
from .sampling import BaseDiffusionSampler
from .wrappers import OpenAIWrapper
from typing import Dict, Union
import torch
import torch.nn as nn
from ...util import append_dims, instantiate_from_config
class Denoiser(nn.Module):
def __init__(self, weighting_config, scaling_config):
super().__init__()
self.weighting = instantiate_from_config(weighting_config)
self.scaling = instantiate_from_config(scaling_config)
def possibly_quantize_sigma(self, sigma):
return sigma
def possibly_quantize_c_noise(self, c_noise):
return c_noise
def w(self, sigma):
return self.weighting(sigma)
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
**additional_model_inputs,
) -> torch.Tensor:
sigma = self.possibly_quantize_sigma(sigma)
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
return (
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+ input * c_skip
)
class DiscreteDenoiser(Denoiser):
def __init__(
self,
weighting_config,
scaling_config,
num_idx,
discretization_config,
do_append_zero=False,
quantize_c_noise=True,
flip=True,
):
super().__init__(weighting_config, scaling_config)
sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
)
self.sigmas = sigmas
# self.register_buffer("sigmas", sigmas)
self.quantize_c_noise = quantize_c_noise
def sigma_to_idx(self, sigma):
dists = sigma - self.sigmas.to(sigma.device)[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def idx_to_sigma(self, idx):
return self.sigmas.to(idx.device)[idx]
def possibly_quantize_sigma(self, sigma):
return self.idx_to_sigma(self.sigma_to_idx(sigma))
def possibly_quantize_c_noise(self, c_noise):
if self.quantize_c_noise:
return self.sigma_to_idx(c_noise)
else:
return c_noise
from abc import ABC, abstractmethod
from typing import Any, Tuple
import torch
class DenoiserScaling(ABC):
@abstractmethod
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
pass
class EDMScaling:
def __init__(self, sigma_data: float = 0.5):
self.sigma_data = sigma_data
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise
class EpsScaling:
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = torch.ones_like(sigma, device=sigma.device)
c_out = -sigma
c_in = 1 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise
class VScaling:
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = sigma.clone()
return c_skip, c_out, c_in, c_noise
class VScalingWithEDMcNoise(DenoiserScaling):
def __call__(
self, sigma: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
c_noise = 0.25 * sigma.log()
return c_skip, c_out, c_in, c_noise
class ZeroSNRScaling: # similar to VScaling
def __call__(
self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
c_skip = alphas_cumprod_sqrt
c_out = - (1 - alphas_cumprod_sqrt**2) ** 0.5
c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
c_noise = additional_model_inputs['idx'].clone()
return c_skip, c_out, c_in, c_noise
\ No newline at end of file
import torch
class UnitWeighting:
def __call__(self, sigma):
return torch.ones_like(sigma, device=sigma.device)
class EDMWeighting:
def __init__(self, sigma_data=0.5):
self.sigma_data = sigma_data
def __call__(self, sigma):
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
class VWeighting(EDMWeighting):
def __init__(self):
super().__init__(sigma_data=1.0)
class EpsWeighting:
def __call__(self, sigma):
return sigma**-2.0
from abc import abstractmethod
from functools import partial
import numpy as np
import torch
from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero
def generate_roughly_equally_spaced_steps(
num_substeps: int, max_step: int
) -> np.ndarray:
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
def sub_generate_roughly_equally_spaced_steps(
num_substeps_1: int, num_substeps_2: int, max_step: int
) -> np.ndarray:
substeps_2 = np.linspace(max_step - 1, 0, num_substeps_2, endpoint=False).astype(int)[::-1]
substeps_1 = np.linspace(num_substeps_2 - 1, 0, num_substeps_1, endpoint=False).astype(int)[::-1]
return substeps_2[substeps_1]
class Discretization:
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
sigmas = self.get_sigmas(n, device=device)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
return sigmas if not flip else torch.flip(sigmas, (0,))
@abstractmethod
def get_sigmas(self, n, device):
pass
class EDMDiscretization(Discretization):
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
def get_sigmas(self, n, device="cpu"):
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
return sigmas
class LegacyDDPMDiscretization(Discretization):
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
):
super().__init__()
self.num_timesteps = num_timesteps
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
)
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
def get_sigmas(self, n, device="cpu"):
if n < self.num_timesteps:
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
else:
raise ValueError
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029
class ZeroSNRDDPMDiscretization(Discretization):
def __init__(
self,
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
shift_scale=1.,
):
super().__init__()
self.num_timesteps = num_timesteps
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
)
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1-shift_scale) * self.alphas_cumprod)
def get_sigmas(self, n, device="cpu", return_idx=False):
if n < self.num_timesteps:
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
else:
raise ValueError
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
alphas_cumprod = to_torch(alphas_cumprod)
alphas_cumprod_sqrt = alphas_cumprod.sqrt()
alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
if return_idx:
return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
else:
return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99
def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
if return_idx:
sigmas, idx = self.get_sigmas(n, device=device, return_idx=True)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
return (sigmas, idx) if not flip else (torch.flip(sigmas, (0,)), torch.flip(idx, (0,)))
else:
sigmas = self.get_sigmas(n, device=device)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
return sigmas if not flip else torch.flip(sigmas, (0,))
\ No newline at end of file
from omegaconf import DictConfig
from functools import partial
from einops import rearrange
import numpy as np
import torch
from torch import nn
import torch.distributed
from sat.model.base_model import BaseModel
from sat.model.mixins import BaseMixin
from sat.ops.layernorm import LayerNorm
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
from sat.mpu.utils import split_tensor_along_last_dim
from sgm.util import (
disabled_train,
instantiate_from_config,
)
from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import (
conv_nd,
linear,
timestep_embedding,
)
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def unpatchify(x, channels, patch_size, height, width):
x = rearrange(
x,
"b (h w) (c p1 p2) -> b c (h p1) (w p2)",
h=height // patch_size,
w=width // patch_size,
p1=patch_size,
p2=patch_size,
)
return x
class ImagePatchEmbeddingMixin(BaseMixin):
def __init__(
self,
in_channels,
hidden_size,
patch_size,
text_hidden_size=None,
do_rearrange=True,
):
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
self.patch_size = patch_size
self.text_hidden_size = text_hidden_size
self.do_rearrange = do_rearrange
self.proj = nn.Linear(in_channels * patch_size ** 2, hidden_size)
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
def word_embedding_forward(self, input_ids, images, encoder_outputs, **kwargs):
# images: B x C x H x W
if self.do_rearrange:
patches_images = rearrange(
images, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=self.patch_size, p2=self.patch_size
)
else:
patches_images = images
emb = self.proj(patches_images)
if self.text_hidden_size is not None:
text_emb = self.text_proj(encoder_outputs)
emb = torch.cat([text_emb, emb], dim=1)
return emb
def reinit(self, parent_model=None):
w = self.proj.weight.data
nn.init.xavier_uniform_(self.proj.weight)
nn.init.constant_(self.proj.bias, 0)
del self.transformer.word_embeddings
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_height, dtype=np.float32)
grid_w = np.arange(grid_width, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class PositionEmbeddingMixin(BaseMixin):
def __init__(
self,
max_height,
max_width,
hidden_size,
text_length=0,
block_size=16,
**kwargs,
):
super().__init__()
self.max_height = max_height
self.max_width = max_width
self.hidden_size = hidden_size
self.text_length = text_length
self.block_size = block_size
self.image_pos_embedding = nn.Parameter(
torch.zeros(self.max_height, self.max_width, hidden_size), requires_grad=False
)
def position_embedding_forward(self, position_ids, target_size, **kwargs):
ret = []
for h, w in target_size:
h, w = h // self.block_size, w // self.block_size
image_pos_embed = self.image_pos_embedding[:h, :w].reshape(h * w, -1)
pos_embed = torch.cat(
[
torch.zeros(
(self.text_length, self.hidden_size),
dtype=image_pos_embed.dtype,
device=image_pos_embed.device,
),
image_pos_embed,
],
dim=0,
)
ret.append(pos_embed[None, ...])
return torch.cat(ret, dim=0)
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.image_pos_embedding.shape[-1], self.max_height, self.max_width)
pos_embed = pos_embed.reshape(self.max_height, self.max_width, -1)
self.image_pos_embedding.data.copy_(torch.from_numpy(pos_embed).float())
class FinalLayerMixin(BaseMixin):
def __init__(
self,
hidden_size,
time_embed_dim,
patch_size,
block_size,
out_channels,
elementwise_affine=False,
eps=1e-6,
do_unpatchify=True,
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.block_size = block_size
self.out_channels = out_channels
self.do_unpatchify = do_unpatchify
self.norm_final = nn.LayerNorm(
hidden_size,
elementwise_affine=elementwise_affine,
eps=eps,
)
self.adaln = nn.Sequential(
nn.SiLU(),
nn.Linear(time_embed_dim, 2 * hidden_size),
)
self.linear = nn.Linear(hidden_size, out_channels * patch_size ** 2)
def final_forward(self, logits, emb, text_length, target_size=None, **kwargs):
x = logits[:, text_length:]
shift, scale = self.adaln(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
if self.do_unpatchify:
target_height, target_width = target_size[0]
assert (
target_height % self.block_size == 0 and target_width % self.block_size == 0
), "target size must be divisible by block size"
out_height, out_width = (
target_height // self.block_size * self.patch_size,
target_width // self.block_size * self.patch_size,
)
x = unpatchify(
x, channels=self.out_channels, patch_size=self.patch_size, height=out_height, width=out_width
)
return x
def reinit(self, parent_model=None):
nn.init.xavier_uniform_(self.linear.weight)
nn.init.constant_(self.linear.bias, 0)
class AdalnAttentionMixin(BaseMixin):
def __init__(
self,
hidden_size,
num_layers,
time_embed_dim,
qk_ln=True,
hidden_size_head=None,
elementwise_affine=False,
eps=1e-6,
):
super().__init__()
self.adaln_modules = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
)
self.qk_ln = qk_ln
if qk_ln:
self.query_layernorms = nn.ModuleList(
[
LayerNorm(hidden_size_head, elementwise_affine=elementwise_affine, eps=eps)
for _ in range(num_layers)
]
)
self.key_layernorms = nn.ModuleList(
[
LayerNorm(hidden_size_head, elementwise_affine=elementwise_affine, eps=eps)
for _ in range(num_layers)
]
)
def layer_forward(
self,
hidden_states,
mask,
text_length,
layer_id,
emb,
*args,
**kwargs,
):
layer = self.transformer.layers[layer_id]
adaln_module = self.adaln_modules[layer_id]
(
shift_msa_img,
scale_msa_img,
gate_msa_img,
shift_mlp_img,
scale_mlp_img,
gate_mlp_img,
shift_msa_txt,
scale_msa_txt,
gate_msa_txt,
shift_mlp_txt,
scale_mlp_txt,
gate_mlp_txt,
) = adaln_module(emb).chunk(12, dim=1)
gate_msa_img, gate_mlp_img, gate_msa_txt, gate_mlp_txt = (
gate_msa_img.unsqueeze(1),
gate_mlp_img.unsqueeze(1),
gate_msa_txt.unsqueeze(1),
gate_mlp_txt.unsqueeze(1),
)
attention_input = layer.input_layernorm(hidden_states)
text_attention_input = modulate(attention_input[:, :text_length], shift_msa_txt, scale_msa_txt)
image_attention_input = modulate(attention_input[:, text_length:], shift_msa_img, scale_msa_img)
attention_input = torch.cat((text_attention_input, image_attention_input), dim=1)
attention_output = layer.attention(attention_input, mask, layer_id=layer_id, **kwargs)
if self.transformer.layernorm_order == "sandwich":
attention_output = layer.third_layernorm(attention_output)
text_hidden_states, image_hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
text_attention_output, image_attention_output = (
attention_output[:, :text_length],
attention_output[:, text_length:],
)
text_hidden_states = text_hidden_states + gate_msa_txt * text_attention_output
image_hidden_states = image_hidden_states + gate_msa_img * image_attention_output
hidden_states = torch.cat((text_hidden_states, image_hidden_states), dim=1)
mlp_input = layer.post_attention_layernorm(hidden_states)
text_mlp_input = modulate(mlp_input[:, :text_length], shift_mlp_txt, scale_mlp_txt)
image_mlp_input = modulate(mlp_input[:, text_length:], shift_mlp_img, scale_mlp_img)
mlp_input = torch.cat((text_mlp_input, image_mlp_input), dim=1)
mlp_output = layer.mlp(mlp_input, layer_id=layer_id, **kwargs)
if self.transformer.layernorm_order == "sandwich":
mlp_output = layer.fourth_layernorm(mlp_output)
text_hidden_states, image_hidden_states = hidden_states[:, :text_length], hidden_states[:, text_length:]
text_mlp_output, image_mlp_output = mlp_output[:, :text_length], mlp_output[:, text_length:]
text_hidden_states = text_hidden_states + gate_mlp_txt * text_mlp_output
image_hidden_states = image_hidden_states + gate_mlp_img * image_mlp_output
hidden_states = torch.cat((text_hidden_states, image_hidden_states), dim=1)
return hidden_states
def attention_forward(self, hidden_states, mask, layer_id, **kwargs):
attention = self.transformer.layers[layer_id].attention
attention_fn = attention_fn_default
if "attention_fn" in attention.hooks:
attention_fn = attention.hooks["attention_fn"]
qkv = attention.query_key_value(hidden_states)
mixed_query_layer, mixed_key_layer, mixed_value_layer = split_tensor_along_last_dim(qkv, 3)
dropout_fn = attention.attention_dropout if self.training else None
query_layer = attention._transpose_for_scores(mixed_query_layer)
key_layer = attention._transpose_for_scores(mixed_key_layer)
value_layer = attention._transpose_for_scores(mixed_value_layer)
if self.qk_ln:
query_layernorm = self.query_layernorms[layer_id]
key_layernorm = self.key_layernorms[layer_id]
query_layer = query_layernorm(query_layer)
key_layer = key_layernorm(key_layer)
context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kwargs)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (attention.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = attention.dense(context_layer)
if self.training:
output = attention.output_dropout(output)
return output
def mlp_forward(self, hidden_states, layer_id, **kwargs):
mlp = self.transformer.layers[layer_id].mlp
intermediate_parallel = mlp.dense_h_to_4h(hidden_states)
intermediate_parallel = mlp.activation_func(intermediate_parallel)
output = mlp.dense_4h_to_h(intermediate_parallel)
if self.training:
output = mlp.dropout(output)
return output
def reinit(self, parent_model=None):
for layer in self.adaln_modules:
nn.init.constant_(layer[-1].weight, 0)
nn.init.constant_(layer[-1].bias, 0)
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
class DiffusionTransformer(BaseModel):
def __init__(
self,
in_channels,
out_channels,
hidden_size,
patch_size,
num_layers,
num_attention_heads,
text_length,
time_embed_dim=None,
num_classes=None,
adm_in_channels=None,
modules={},
dtype="fp32",
layernorm_order="pre",
elementwise_affine=False,
parallel_output=True,
block_size=16,
**kwargs,
):
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
self.text_length = text_length
self.block_size = block_size
self.dtype = str_to_dtype[dtype]
hidden_size_head = hidden_size // num_attention_heads # 40
approx_gelu = nn.GELU(approximate="tanh")
activation_func = approx_gelu
transformer_args = {
"vocab_size": 1,
"max_sequence_length": 64,
"skip_init": False,
"model_parallel_size": 1,
"is_decoder": False,
"layernorm_order": layernorm_order,
"num_layers": num_layers,
"hidden_size": hidden_size,
"num_attention_heads": num_attention_heads,
"parallel_output": parallel_output,
}
transformer_args = DictConfig(transformer_args)
super().__init__(
args=transformer_args,
transformer=None,
layernorm=partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6),
activation_func=activation_func,
**kwargs,
)
model_channels = hidden_size
time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
assert self.adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(self.adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
pos_embed_config = modules.get(
"pos_embed_config",
{
"target": "sgm.modules.diffusionmodules.dit.PositionEmbeddingMixin",
},
)
pos_embed_extra_kwargs = {
"hidden_size": hidden_size,
"text_length": text_length,
}
pos_embed_mixin = instantiate_from_config(pos_embed_config, **pos_embed_extra_kwargs)
self.add_mixin("pos_embed", pos_embed_mixin, reinit=True)
patch_embed_config = modules.get(
"patch_embed_config",
{
"target": "sgm.modules.diffusionmodules.dit.ImagePatchEmbeddingMixin",
},
)
patch_embed_extra_kwargs = {
"in_channels": in_channels,
"hidden_size": hidden_size,
"patch_size": patch_size,
}
patch_embed_mixin = instantiate_from_config(patch_embed_config, **patch_embed_extra_kwargs)
self.add_mixin("patch_embed", patch_embed_mixin, reinit=True)
attention_config = modules.get(
"attention_config",
{
"target": "sgm.modules.diffusionmodules.dit.AdalnAttentionMixin",
},
)
attention_extra_kwargs = {
"hidden_size": hidden_size,
"hidden_size_head": hidden_size_head,
"num_layers": num_layers,
"time_embed_dim": time_embed_dim,
"elementwise_affine": elementwise_affine,
}
attention_mixin = instantiate_from_config(attention_config, **attention_extra_kwargs)
self.add_mixin("adaln", attention_mixin, reinit=True)
final_layer_config = modules.get(
"final_layer_config",
{
"target": "sgm.modules.diffusionmodules.dit.FinalLayerMixin",
},
)
final_layer_extra_kwargs = {
"hidden_size": hidden_size,
"time_embed_dim": time_embed_dim,
"patch_size": patch_size,
"block_size": block_size,
"out_channels": out_channels,
"elementwise_affine": elementwise_affine,
}
final_layer_mixin = instantiate_from_config(final_layer_config, **final_layer_extra_kwargs)
self.add_mixin("final_layer", final_layer_mixin, reinit=True)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
x = x.to(self.dtype)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
emb = self.time_embed(t_emb)
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
input_ids = position_ids = attention_mask = torch.ones((1, 1)).to(x.dtype)
output = super().forward(
images=x,
emb=emb,
encoder_outputs=context,
text_length=self.text_length,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
**kwargs,
)[0]
return output
import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union
from functools import partial
import math
import torch
from einops import rearrange, repeat
from ...util import append_dims, default, instantiate_from_config
class Guider(ABC):
@abstractmethod
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
pass
def prepare_inputs(
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
) -> Tuple[torch.Tensor, float, Dict]:
pass
class VanillaCFG:
"""
implements parallelized CFG
"""
def __init__(self, scale, dyn_thresh_config=None):
scale_schedule = lambda scale, sigma: scale # independent of step
self.scale_schedule = partial(scale_schedule, scale)
self.dyn_thresh = instantiate_from_config(
default(
dyn_thresh_config,
{
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
},
)
)
def __call__(self, x, sigma, step = None, num_steps = None, **kwargs):
x_u, x_c = x.chunk(2)
scale_value = self.scale_schedule(sigma)
x_pred = self.dyn_thresh(x_u, x_c, scale_value, step=step, num_steps=num_steps)
return x_pred
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
if k in ["vector", "crossattn", "concat"]:
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
class IdentityGuider:
def __call__(self, x, sigma, **kwargs):
return x
def prepare_inputs(self, x, s, c, uc):
c_out = dict()
for k in c:
c_out[k] = c[k]
return x, s, c_out
class LinearPredictionGuider(Guider):
def __init__(
self,
max_scale: float,
num_frames: int,
min_scale: float = 1.0,
additional_cond_keys: Optional[Union[List[str], str]] = None,
):
self.min_scale = min_scale
self.max_scale = max_scale
self.num_frames = num_frames
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
additional_cond_keys = default(additional_cond_keys, [])
if isinstance(additional_cond_keys, str):
additional_cond_keys = [additional_cond_keys]
self.additional_cond_keys = additional_cond_keys
def __call__(self, x: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor:
x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
scale = append_dims(scale, x_u.ndim).to(x_u.device)
scale = scale.to(x_u.dtype)
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
def prepare_inputs(
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
c_out = dict()
for k in c:
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
c_out[k] = torch.cat((uc[k], c[k]), 0)
else:
assert c[k] == uc[k]
c_out[k] = c[k]
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
import os
import copy
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import ListConfig
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ...util import get_obj_from_str, default
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps, sub_generate_roughly_equally_spaced_steps
class StandardDiffusionLoss(nn.Module):
def __init__(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
):
super().__init__()
assert type in ["l2", "l1", "lpips"]
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
self.type = type
self.offset_noise_level = offset_noise_level
if type == "lpips":
self.lpips = LPIPS().eval()
if not batch2model_keys:
batch2model_keys = []
if isinstance(batch2model_keys, str):
batch2model_keys = [batch2model_keys]
self.batch2model_keys = set(batch2model_keys)
def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch)
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
noise = torch.randn_like(input)
if self.offset_noise_level > 0.0:
noise = noise + append_dims(
torch.randn(input.shape[0]).to(input.device), input.ndim
) * self.offset_noise_level
noise = noise.to(input.dtype)
noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
model_output = denoiser(
network, noised_input, sigmas, cond, **additional_model_inputs
)
w = append_dims(denoiser.w(sigmas), input.ndim)
return self.get_loss(model_output, input, w)
def get_loss(self, model_output, target, w):
if self.type == "l2":
return torch.mean(
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
)
elif self.type == "l1":
return torch.mean(
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
)
elif self.type == "lpips":
loss = self.lpips(model_output, target).reshape(-1)
return loss
class LinearRelayDiffusionLoss(StandardDiffusionLoss):
def __init__(
self,
sigma_sampler_config,
type="l2",
offset_noise_level=0.0,
partial_num_steps=500,
blurring_schedule='linear',
batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
):
super().__init__(
sigma_sampler_config,
type=type,
offset_noise_level=offset_noise_level,
batch2model_keys=batch2model_keys,
)
self.blurring_schedule = blurring_schedule
self.partial_num_steps = partial_num_steps
def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch)
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
lr_input = batch["lr_input"]
rand = torch.randint(0, self.partial_num_steps, (input.shape[0],))
sigmas = self.sigma_sampler(input.shape[0], rand).to(input.dtype).to(input.device)
noise = torch.randn_like(input)
if self.offset_noise_level > 0.0:
noise = noise + append_dims(
torch.randn(input.shape[0]).to(input.device), input.ndim
) * self.offset_noise_level
noise = noise.to(input.dtype)
rand = append_dims(rand, input.ndim).to(input.dtype).to(input.device)
if self.blurring_schedule == 'linear':
blurred_input = input * (1 - rand / self.partial_num_steps) + lr_input * (rand / self.partial_num_steps)
elif self.blurring_schedule == 'sigma':
max_sigmas = self.sigma_sampler(input.shape[0], torch.ones(input.shape[0])*self.partial_num_steps).to(input.dtype).to(input.device)
blurred_input = input * (1 - sigmas / max_sigmas) + lr_input * (sigmas / max_sigmas)
elif self.blurring_schedule == 'exp':
rand_blurring = (1 - torch.exp(-(torch.sin((rand+1) / self.partial_num_steps * torch.pi / 2)**4))) / (1 - torch.exp(-torch.ones_like(rand)))
blurred_input = input * (1 - rand_blurring) + lr_input * rand_blurring
else:
raise NotImplementedError
noised_input = blurred_input + noise * append_dims(sigmas, input.ndim)
model_output = denoiser(
network, noised_input, sigmas, cond, **additional_model_inputs
)
w = append_dims(denoiser.w(sigmas), input.ndim)
return self.get_loss(model_output, input, w)
class ZeroSNRDiffusionLoss(StandardDiffusionLoss):
def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch)
additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
idx = idx.to(input.dtype).to(input.device)
additional_model_inputs['idx'] = idx
noise = torch.randn_like(input)
if self.offset_noise_level > 0.0:
noise = noise + append_dims(
torch.randn(input.shape[0]).to(input.device), input.ndim
) * self.offset_noise_level
noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims((1-alphas_cumprod_sqrt**2)**0.5, input.ndim)
model_output = denoiser(
network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
)
w = append_dims(1/(1-alphas_cumprod_sqrt**2), input.ndim) # v-pred
return self.get_loss(model_output, input, w)
def get_loss(self, model_output, target, w):
if self.type == "l2":
return torch.mean(
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
)
elif self.type == "l1":
return torch.mean(
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
)
elif self.type == "lpips":
loss = self.lpips(model_output, target).reshape(-1)
return loss
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Callable, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from packaging import version
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def attention(self, h_: torch.Tensor) -> torch.Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(
lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
)
h_ = torch.nn.functional.scaled_dot_product_attention(
q, k, v
) # scale is dim ** -0.5 per default
# compute attention
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x, **kwargs):
h_ = x
h_ = self.attention(h_)
h_ = self.proj_out(h_)
return x + h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.attention_op: Optional[Any] = None
def attention(self, h_: torch.Tensor) -> torch.Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(B, t.shape[1], 1, C)
.permute(0, 2, 1, 3)
.reshape(B * 1, t.shape[1], C)
.contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
def forward(self, x, **kwargs):
h_ = x
h_ = self.attention(h_)
h_ = self.proj_out(h_)
return x + h_
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None, **unused_kwargs):
b, c, h, w = x.shape
x = rearrange(x, "b c h w -> b (h w) c")
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
return x + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in [
"vanilla",
"vanilla-xformers",
"memory-efficient-cross-attn",
"linear",
"none",
], f"attn_type {attn_type} unknown"
if (
version.parse(torch.__version__) < version.parse("2.0.0")
and attn_type != "none"
):
assert XFORMERS_IS_AVAILABLE, (
f"We do not support vanilla attention in {torch.__version__} anymore, "
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
# print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class Model(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type="vanilla",
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList(
[
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
]
)
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type="vanilla",
mid_attn=True,
**ignore_kwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.attn_resolutions = attn_resolutions
self.mid_attn = mid_attn
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
if mid_attn:
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
if self.mid_attn:
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type="vanilla",
mid_attn=True,
**ignorekwargs,
):
super().__init__()
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.attn_resolutions = attn_resolutions
self.mid_attn = mid_attn
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# print(
# "Working with z of shape {} = {} dimensions.".format(
# self.z_shape, np.prod(self.z_shape)
# )
# )
make_attn_cls = self._make_attn()
make_resblock_cls = self._make_resblock()
make_conv_cls = self._make_conv()
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
if mid_attn:
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
self.mid.block_2 = make_resblock_cls(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
make_resblock_cls(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn_cls(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = make_conv_cls(
block_in, out_ch, kernel_size=3, stride=1, padding=1
)
def _make_attn(self) -> Callable:
return make_attn
def _make_resblock(self) -> Callable:
return ResnetBlock
def _make_conv(self) -> Callable:
return torch.nn.Conv2d
def get_last_layer(self, **kwargs):
return self.conv_out.weight
def forward(self, z, **kwargs):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, **kwargs)
if self.mid_attn:
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h
import os
import math
from abc import abstractmethod
from functools import partial
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ...modules.attention import SpatialTransformer
from ...modules.diffusionmodules.util import (
avg_pool_nd,
checkpoint,
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
)
from ...util import default, exists
# dummy replace
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
context: Optional[th.Tensor] = None,
):
for layer in self:
module = layer
if isinstance(module, TimestepBlock):
x = layer(x, emb)
elif isinstance(module, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.third_up = third_up
if use_conv:
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding
)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
t_factor = 1 if not self.third_up else 2
x = F.interpolate(
x,
(t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode="nearest",
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
"Learned 2x upsampling without padding"
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2
)
def forward(self, x):
return self.up(x)
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(
self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
if use_conv:
# print(f"Building a Downsample layer with {dims} dims.")
# print(
# f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
# f"kernel-size: 3, stride: {stride}, padding: {padding}"
# )
# if dims == 3:
# print(f" --> Downsampling third axis (time): {third_down}")
self.op = conv_nd(
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding,
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, Iterable):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.skip_t_emb = skip_t_emb
self.emb_out_channels = (
2 * self.out_channels if use_scale_shift_norm else self.out_channels
)
if self.skip_t_emb:
print(f"Skipping timestep embedding in {self.__class__.__name__}")
assert not self.use_scale_shift_norm
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
self.emb_out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(
dims,
self.out_channels,
self.out_channels,
kernel_size,
padding=padding,
)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, kernel_size, padding=padding
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if self.skip_t_emb:
emb_out = th.zeros_like(h)
else:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x, **kwargs):
# TODO add crossframe attention and use mixed checkpoint
return checkpoint(
self._forward, (x,), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
return timestep_embedding(t, self.dim)
str_to_dtype = {
"fp32": th.float32,
"fp16": th.float16,
"bf16": th.bfloat16
}
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
spatial_transformer_attn_type="softmax",
adm_in_channels=None,
use_fairscale_checkpoint=False,
offload_to_cpu=False,
transformer_depth_middle=None,
cfg_cond_embed_dim=None,
dtype="fp32",
):
super().__init__()
from omegaconf.listconfig import ListConfig
self.dtype = str_to_dtype[dtype]
if use_spatial_transformer:
assert (
context_dim is not None
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None:
assert (
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert (
num_head_channels != -1
), "Either num_heads or num_head_channels has to be set"
if num_head_channels == -1:
assert (
num_heads != -1
), "Either num_heads or num_head_channels has to be set"
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
elif isinstance(transformer_depth, ListConfig):
transformer_depth = list(transformer_depth)
transformer_depth_middle = default(
transformer_depth_middle, transformer_depth[-1]
)
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
self.num_res_blocks = num_res_blocks
# self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(
map(
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
range(len(num_attention_blocks)),
)
)
print(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set."
) # todo: convert to warning
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
# if use_fp16:
# print("WARNING: use_fp16 was dropped and has no effect anymore.")
# self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
assert use_fairscale_checkpoint != use_checkpoint or not (
use_checkpoint or use_fairscale_checkpoint
)
self.use_fairscale_checkpoint = False
checkpoint_wrapper_fn = (
partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
if self.use_fairscale_checkpoint
else lambda x: x
)
time_embed_dim = model_channels * 4
self.time_embed = checkpoint_wrapper_fn(
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = checkpoint_wrapper_fn(
nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
)
elif self.num_classes == "sequential":
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if (
not exists(num_attention_blocks)
or nr < num_attention_blocks[level]
):
layers.append(
checkpoint_wrapper_fn(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
)
if not use_spatial_transformer
else checkpoint_wrapper_fn(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
attn_type=spatial_transformer_attn_type,
use_checkpoint=use_checkpoint,
)
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if len(attention_resolutions) == 0:
self.middle_block = TimestepEmbedSequential(
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
),
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
)
)
else:
self.middle_block = TimestepEmbedSequential(
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
),
checkpoint_wrapper_fn(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
)
if not use_spatial_transformer
else checkpoint_wrapper_fn(
SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth_middle,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
attn_type=spatial_transformer_attn_type,
use_checkpoint=use_checkpoint,
)
),
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
checkpoint_wrapper_fn(
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = (
ch // num_heads
if use_spatial_transformer
else num_head_channels
)
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if (
not exists(num_attention_blocks)
or i < num_attention_blocks[level]
):
layers.append(
checkpoint_wrapper_fn(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
)
if not use_spatial_transformer
else checkpoint_wrapper_fn(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth[level],
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
attn_type=spatial_transformer_attn_type,
use_checkpoint=use_checkpoint,
)
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
checkpoint_wrapper_fn(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = checkpoint_wrapper_fn(
nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
)
if self.predict_codebook_ids:
self.id_predictor = checkpoint_wrapper_fn(
nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
)
if cfg_cond_embed_dim is not None:
self.w_proj = zero_module(nn.Linear(cfg_cond_embed_dim, self.model_channels, bias=False))
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, scale_emb=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
if x.dtype != self.dtype:
x = x.to(self.dtype)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
if scale_emb is not None:
assert hasattr(self, "w_proj"), "w_proj not found in the model"
t_emb = t_emb + self.w_proj(scale_emb.to(self.dtype))
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
# h = x.type(self.dtype)
h = x
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
assert False, "not supported anymore. what the f*** are you doing?"
else:
return self.out(h)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment