Commit f0a2b81c authored by comfyanonymous's avatar comfyanonymous
Browse files

Cleanup: Remove a bunch of useless files.

parent 74297f5f
......@@ -14,8 +14,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ..ldm.models.diffusion.ddpm import LatentDiffusion
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
from ..ldm.util import exists
class ControlledUnetModel(UNetModel):
......
......@@ -3,7 +3,6 @@ import os
import yaml
import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
import os.path as osp
import re
......
import torch
from ldm.modules.midas.api import load_midas_transform
class AddMiDaS(object):
def __init__(self, model_type):
super().__init__()
self.transform = load_midas_transform(model_type)
def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
return x
def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
return x
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
\ No newline at end of file
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import numpy as np
# import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
# from pytorch_lightning.utilities.distributed import rank_zero_only
from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ..autoencoder import IdentityFirstStage, AutoencoderKL
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from .ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
'adm': 'y'}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
# class DDPM(pl.LightningModule):
class DDPM(torch.nn.Module):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
make_it_fit=False,
ucg_training=None,
reset_ema=False,
reset_num_ema_updates=False,
):
super().__init__()
assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
self.make_it_fit = make_it_fit
if reset_ema: assert exists(ckpt_path)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
if reset_ema:
assert self.use_ema
print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.ucg_training = ucg_training or dict()
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
else:
raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
if self.make_it_fit:
n_params = len([name for name, _ in
itertools.chain(self.named_parameters(),
self.named_buffers())])
for name, param in tqdm(
itertools.chain(self.named_parameters(),
self.named_buffers()),
desc="Fitting old weights to new weights",
total=n_params
):
if not name in sd:
continue
old_shape = sd[name].shape
new_shape = param.shape
assert len(old_shape) == len(new_shape)
if len(new_shape) > 2:
# we only modify first two axes
assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim
if not new_shape == old_shape:
new_param = param.clone()
old_param = sd[name]
if len(new_shape) == 1:
for i in range(new_param.shape[0]):
new_param[i] = old_param[i % old_shape[0]]
elif len(new_shape) >= 2:
for i in range(new_param.shape[0]):
for j in range(new_param.shape[1]):
new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
n_used_old = torch.ones(old_shape[1])
for j in range(new_param.shape[1]):
n_used_old[j % old_shape[1]] += 1
n_used_new = torch.zeros(new_shape[1])
for j in range(new_param.shape[1]):
n_used_new[j] = n_used_old[j % old_shape[1]]
n_used_new = n_used_new[None, :]
while len(n_used_new.shape) < len(new_shape):
n_used_new = n_used_new.unsqueeze(-1)
new_param /= n_used_new
sd[name] = new_param
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys:\n {missing}")
if len(unexpected) > 0:
print(f"\nUnexpected Keys:\n {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_start_from_z_and_v(self, x_t, t, v):
# self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
# self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_eps_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
for k in self.ucg_training:
p = self.ucg_training[k]["p"]
val = self.ucg_training[k]["val"]
if val is None:
val = ""
for i in range(len(batch[k])):
if self.ucg_prng.choice(2, p=[1 - p, p]):
batch[k][i] = val
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config={},
cond_stage_config={},
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
force_null_conditioning=False,
*args, **kwargs):
self.force_null_conditioning = force_null_conditioning
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
reset_ema = kwargs.pop("reset_ema", False)
reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
# self.instantiate_first_stage(first_stage_config)
# self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
if reset_ema:
assert self.use_ema
print(
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
# @rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
def delta_border(self, h, w):
"""
:param h: height
:param w: width
:return: normalized distance to image border,
wtith min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
return edge_dist
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"], )
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"])
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
dilation=1, padding=0,
stride=(stride[0] * uf, stride[1] * uf))
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
dilation=1, padding=0,
stride=(stride[0] // df, stride[1] // df))
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None, return_x=False):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None and not self.force_null_conditioning:
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox', "txt"]:
xc = batch[cond_key]
elif cond_key in ['class_label', 'cls']:
xc = batch
else:
xc = super().get_input(batch, cond_key).to(self.device)
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
ckey = __conditioning_keys__[self.model.conditioning_key]
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
else:
c = None
xc = None
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
c = {'pos_x': pos_x, 'pos_y': pos_y}
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z)
out.extend([x, xrec])
if return_x:
out.extend([x])
if return_original_cond:
out.append(xc)
return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is expected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_codebook_ids:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size if batch_size is not None else shape[0]
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if type(temperature) == float:
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if timesteps is None:
timesteps = self.num_timesteps
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
if mask is not None:
assert x0 is not None
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None, **kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
return_intermediates=return_intermediates, x_T=x_T,
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
shape, cond, verbose=False, **kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True, **kwargs)
return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
# if isinstance(xc, ListConfig):
# xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
if self.cond_stage_key in ["class_label", "cls"]:
xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
return self.get_learned_conditioning(xc)
else:
raise NotImplementedError("todo")
if isinstance(c, list): # in case the encoder gives us a list
for i in range(len(c)):
c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
else:
c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', "cls"]:
try:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
except KeyError:
# probably no "human_label" in batch
pass
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if unconditional_guidance_scale > 1.0:
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
if self.model.conditioning_key == "crossattn-adm":
uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
mask = 1. - mask
with ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [opt], scheduler
return opt
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
# class DiffusionWrapper(pl.LightningModule):
class DiffusionWrapper(torch.nn.Module):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
if self.conditioning_key is None:
out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1)
else:
cc = c_crossattn
if hasattr(self, "scripted_diffusion_model"):
# TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
else:
out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
else:
raise NotImplementedError()
return out
class LatentUpscaleDiffusion(LatentDiffusion):
def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
super().__init__(*args, **kwargs)
# assumes that neither the cond_stage nor the low_scale_model contain trainable params
assert not self.cond_stage_trainable
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
self.noise_level_key = noise_level_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
if not log_mode:
z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
else:
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
x_low = batch[self.low_scale_key][:bs]
x_low = rearrange(x_low, 'b h w c -> b c h w')
x_low = x_low.to(memory_format=torch.contiguous_format).float()
zx, noise_level = self.low_scale_model(x_low)
if self.noise_level_key is not None:
# get noise level from batch instead, e.g. when extracting a custom noise level for bsr
raise NotImplementedError('TODO')
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
if log_mode:
# TODO: maybe disable if too expensive
x_low_rec = self.low_scale_model.decode(zx)
return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
return z, all_conds
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
log_mode=True)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
log["x_lr"] = x_low
log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', 'cls']:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# TODO explore better "unconditional" choices for the other keys
# maybe guide away from empty text label and highest noise level and maximally degraded zx?
uc = dict()
for k in c:
if k == "c_crossattn":
assert isinstance(c[k], list) and len(c[k]) == 1
uc[k] = [uc_tmp]
elif k == "c_adm": # todo: only run with text-based guidance?
assert isinstance(c[k], torch.Tensor)
#uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
uc[k] = c[k]
elif isinstance(c[k], list):
uc[k] = [c[k][i] for i in range(len(c[k]))]
else:
uc[k] = c[k]
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
return log
class LatentFinetuneDiffusion(LatentDiffusion):
"""
Basis for different finetunas, such as inpainting or depth2image
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
concat_keys: tuple,
finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
"model_ema.diffusion_modelinput_blocks00weight"
),
keep_finetune_dims=4,
# if model was trained without concat mode before and we would like to keep these channels
c_concat_log_start=None, # to log reconstruction of c_concat codes
c_concat_log_end=None,
*args, **kwargs
):
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", list())
super().__init__(*args, **kwargs)
self.finetune_keys = finetune_keys
self.concat_keys = concat_keys
self.keep_dims = keep_finetune_dims
self.c_concat_log_start = c_concat_log_start
self.c_concat_log_end = c_concat_log_end
if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
if exists(ckpt_path):
self.init_from_ckpt(ckpt_path, ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
# make it explicit, finetune by including extra input channels
if exists(self.finetune_keys) and k in self.finetune_keys:
new_entry = None
for name, param in self.named_parameters():
if name in self.finetune_keys:
print(
f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
new_entry = torch.zeros_like(param) # zero init
assert exists(new_entry), 'did not find matching parameter to modify'
new_entry[:, :self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', 'cls']:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
uc_cat = c_cat
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log
class LatentInpaintDiffusion(LatentFinetuneDiffusion):
"""
can either run as pure inpainting model (only concat mode) or with mixed conditionings,
e.g. mask as concat and text via cross-attn.
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
*args, **kwargs
):
super().__init__(concat_keys, *args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
log["masked_image"] = rearrange(args[0]["masked_image"],
'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
return log
class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
"""
condition on monocular depth estimation
"""
def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
super().__init__(concat_keys=concat_keys, *args, **kwargs)
self.depth_model = instantiate_from_config(depth_stage_config)
self.depth_stage_key = concat_keys[0]
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
c_cat = list()
for ck in self.concat_keys:
cc = batch[ck]
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
cc = self.depth_model(cc)
cc = torch.nn.functional.interpolate(
cc,
size=z.shape[2:],
mode="bicubic",
align_corners=False,
)
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
keepdim=True)
cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super().log_images(*args, **kwargs)
depth = self.depth_model(args[0][self.depth_stage_key])
depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
torch.amax(depth, dim=[1, 2, 3], keepdim=True)
log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
return log
class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
"""
condition on low-res image (and optionally on some spatial noise augmentation)
"""
def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
low_scale_config=None, low_scale_key=None, *args, **kwargs):
super().__init__(concat_keys=concat_keys, *args, **kwargs)
self.reshuffle_patch_size = reshuffle_patch_size
self.low_scale_model = None
if low_scale_config is not None:
print("Initializing a low-scale model")
assert exists(low_scale_key)
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
# optionally make spatial noise_level here
c_cat = list()
noise_level = None
for ck in self.concat_keys:
cc = batch[ck]
cc = rearrange(cc, 'b h w c -> b c h w')
if exists(self.reshuffle_patch_size):
assert isinstance(self.reshuffle_patch_size, int)
cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
if exists(self.low_scale_model) and ck == self.low_scale_key:
cc, noise_level = self.low_scale_model(cc)
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
if exists(noise_level):
all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
else:
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super().log_images(*args, **kwargs)
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
return log
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.embed_key = embedding_key
self.embedding_dropout = embedding_dropout
# self._init_embedder(embedder_config, freeze_embedder)
self._init_noise_aug(noise_aug_config)
def _init_embedder(self, config, freeze=True):
embedder = instantiate_from_config(config)
if freeze:
self.embedder = embedder.eval()
self.embedder.train = disabled_train
for param in self.embedder.parameters():
param.requires_grad = False
def _init_noise_aug(self, config):
if config is not None:
# use the KARLO schedule for noise augmentation on CLIP image embeddings
noise_augmentor = instantiate_from_config(config)
assert isinstance(noise_augmentor, nn.Module)
noise_augmentor = noise_augmentor.eval()
noise_augmentor.train = disabled_train
self.noise_augmentor = noise_augmentor
else:
self.noise_augmentor = None
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
z, c = outputs[0], outputs[1]
img = batch[self.embed_key][:bs]
img = rearrange(img, 'b h w c -> b c h w')
c_adm = self.embedder(img)
if self.noise_augmentor is not None:
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
# assume this gives embeddings of noise levels
c_adm = torch.cat((c_adm, noise_level_emb), 1)
if self.training:
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
device=c_adm.device)[:, None]) * c_adm
all_conds = {"c_crossattn": [c], "c_adm": c_adm}
noutputs = [z, all_conds]
noutputs.extend(outputs[2:])
return noutputs
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, **kwargs):
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
return_original_cond=True)
log["inputs"] = x
log["reconstruction"] = xrec
assert self.model.conditioning_key is not None
assert self.cond_stage_key in ["caption", "txt"]
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
with ema_scope(f"Sampling"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_, )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log
from typing import List, Tuple, Union
import torch
import torch.nn as nn
#from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py
def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
r"""Normalize an image/video tensor with mean and standard deviation.
.. math::
\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
Args:
data: Image tensor of size :math:`(B, C, *)`.
mean: Mean for each channel.
std: Standard deviations for each channel.
Return:
Normalised tensor with same size as input :math:`(B, C, *)`.
Examples:
>>> x = torch.rand(1, 4, 3, 3)
>>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3)
>>> mean = torch.zeros(4)
>>> std = 255. * torch.ones(4)
>>> out = normalize(x, mean, std)
>>> out.shape
torch.Size([1, 4, 3, 3])
"""
shape = data.shape
if len(mean.shape) == 0 or mean.shape[0] == 1:
mean = mean.expand(shape[1])
if len(std.shape) == 0 or std.shape[0] == 1:
std = std.expand(shape[1])
# Allow broadcast on channel dimension
if mean.shape and mean.shape[0] != 1:
if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
# Allow broadcast on channel dimension
if std.shape and std.shape[0] != 1:
if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
if mean.shape:
mean = mean[..., :, None]
if std.shape:
std = std[..., :, None]
out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std
return out.view(shape)
import torch
import torch.nn as nn
from . import kornia_functions
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
import open_clip
from ldm.util import default, count_params
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs,), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class ClipImageEmbedder(nn.Module):
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=True,
ucg_rate=0.
):
super().__init__()
from clip import load as load_clip
self.model, _ = load_clip(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
# x = kornia_functions.geometry_resize(x, (224, 224),
# interpolation='bicubic', align_corners=True,
# antialias=self.antialias)
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
x = (x + 1.) / 2.
# re-normalize according to clip
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
return x
def forward(self, x, no_dropout=False):
# x is assumed to be in range [-1,1]
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0. and not no_dropout:
out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
return out
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="last"):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
pretrained=version, )
del model.transformer
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
# x = kornia.geometry.resize(x, (224, 224),
# interpolation='bicubic', align_corners=True,
# antialias=self.antialias)
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
return z
def encode_with_vision_transformer(self, img):
img = self.preprocess(img)
x = self.model.visual(img)
return x
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
clip_max_length=77, t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "",
"midas_v21_small": "",
}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def load_midas_transform(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load transform only
if model_type == "dpt_large": # DPT-Large
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
elif model_type == "midas_v21_small":
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
else:
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return transform
def load_model(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
model_path = ISL_PATHS[model_type]
if model_type == "dpt_large": # DPT-Large
model = DPTDepthModel(
path=model_path,
backbone="vitl16_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
model = DPTDepthModel(
path=model_path,
backbone="vitb_rn50_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
model = MidasNet(model_path, non_negative=True)
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
elif model_type == "midas_v21_small":
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
non_negative=True, blocks={'expand': True})
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
else:
print(f"model_type '{model_type}' not implemented, use: --model_type large")
assert False
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return model.eval(), transform
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = [
"DPT_Large",
"DPT_Hybrid",
"MiDaS_small"
]
MODEL_TYPES_ISL = [
"dpt_large",
"dpt_hybrid",
"midas_v21",
"midas_v21_small",
]
def __init__(self, model_type):
super().__init__()
assert (model_type in self.MODEL_TYPES_ISL)
model, _ = load_model(model_type)
self.model = model
self.model.train = disabled_train
def forward(self, x):
# x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
# NOTE: we expect that the correct transform has been called during dataloading.
with torch.no_grad():
prediction = self.model(x)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=x.shape[2:],
mode="bicubic",
align_corners=False,
)
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
return prediction
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)
import torch
import torch.nn as nn
from .vit import (
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
_make_pretrained_vitb16_384,
forward_vit,
)
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-H/16 - 85.0% Top1 (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_model import BaseModel
from .blocks import (
FeatureFusionBlock,
FeatureFusionBlock_custom,
Interpolate,
_make_encoder,
forward_vit,
)
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
False, # Set to true of you want to train from scratch, uses ImageNet weights
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
if self.channels_last == True:
x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print("Loading weights: ", path)
super(MidasNet, self).__init__()
use_pretrained = False if path is None else True
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
class MidasNet_small(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
blocks={'expand': True}):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
"""
print("Loading weights: ", path)
super(MidasNet_small, self).__init__()
use_pretrained = False if path else True
self.channels_last = channels_last
self.blocks = blocks
self.backbone = backbone
self.groups = 1
features1=features
features2=features
features3=features
features4=features
self.expand = False
if "expand" in self.blocks and self.blocks['expand'] == True:
self.expand = True
features1=features
features2=features*2
features3=features*4
features4=features*8
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
self.scratch.activation = nn.ReLU(False)
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
self.scratch.activation,
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
if self.channels_last==True:
print("self.channels_last = ", self.channels_last)
x.contiguous(memory_format=torch.channels_last)
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
def fuse_model(m):
prev_previous_type = nn.Identity()
prev_previous_name = ''
previous_type = nn.Identity()
previous_name = ''
for name, module in m.named_modules():
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
# print("FUSED ", prev_previous_name, previous_name, name)
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
# print("FUSED ", prev_previous_name, previous_name)
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
# print("FUSED ", previous_name, name)
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
prev_previous_type = previous_type
prev_previous_name = previous_name
previous_type = type(module)
previous_name = name
\ No newline at end of file
import numpy as np
import cv2
import math
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
)
sample["disparity"] = cv2.resize(
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, min_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, min_val=self.__width
)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, max_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, max_val=self.__width
)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(
sample["image"].shape[1], sample["image"].shape[0]
)
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F
class Slice(nn.Module):
def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
glob = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only == True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only == True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks == None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
"""Utils for monoDepth."""
import sys
import re
import numpy as np
import cv2
import torch
def read_pfm(path):
"""Read pfm file.
Args:
path (str): path to file
Returns:
tuple: (data, scale)
"""
with open(path, "rb") as file:
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header.decode("ascii") == "PF":
color = True
elif header.decode("ascii") == "Pf":
color = False
else:
raise Exception("Not a PFM file: " + path)
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
if dim_match:
width, height = list(map(int, dim_match.groups()))
else:
raise Exception("Malformed PFM header.")
scale = float(file.readline().decode("ascii").rstrip())
if scale < 0:
# little-endian
endian = "<"
scale = -scale
else:
# big-endian
endian = ">"
data = np.fromfile(file, endian + "f")
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data, scale
def write_pfm(path, image, scale=1):
"""Write pfm file.
Args:
path (str): pathto file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with open(path, "wb") as file:
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif (
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
): # greyscale
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
file.write("PF\n" if color else "Pf\n".encode())
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
file.write("%f\n".encode() % scale)
image.tofile(file)
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def resize_image(img):
"""Resize image and make it fit for network.
Args:
img (array): image
Returns:
tensor: data ready for network
"""
height_orig = img.shape[0]
width_orig = img.shape[1]
if width_orig > height_orig:
scale = width_orig / 384
else:
scale = height_orig / 384
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
img_resized = (
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
)
img_resized = img_resized.unsqueeze(0)
return img_resized
def resize_depth(depth, width, height):
"""Resize depth map and bring to CPU (numpy).
Args:
depth (tensor): depth
width (int): image width
height (int): image height
Returns:
array: processed depth
"""
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
depth_resized = cv2.resize(
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
)
return depth_resized
def write_depth(path, depth, bits=1):
"""Write depth map to pfm and png file.
Args:
path (str): filepath without extension
depth (array): depth
"""
write_pfm(path + ".pfm", depth.astype(np.float32))
depth_min = depth.min()
depth_max = depth.max()
max_val = (2**(8*bits))-1
if depth_max - depth_min > np.finfo("float").eps:
out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
out = np.zeros(depth.shape, dtype=depth.type)
if bits == 1:
cv2.imwrite(path + ".png", out.astype("uint8"))
elif bits == 2:
cv2.imwrite(path + ".png", out.astype("uint16"))
return
......@@ -1111,7 +1111,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1]
sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
unclip_model = False
inpaint_model = False
......@@ -1121,11 +1120,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
sd_config["embedding_dropout"] = 0.25
sd_config["conditioning_key"] = 'crossattn-adm'
unclip_model = True
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
elif unet_config["in_channels"] > 4: #inpainting model
sd_config["conditioning_key"] = "hybrid"
sd_config["finetune_keys"] = None
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
inpaint_model = True
else:
sd_config["conditioning_key"] = "crossattn"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment