# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive
class SupRes256to1kProgressive(SupRes64to256Progressive):
pass # no difference currently
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import copy
import torch
from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel
from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
"""
    The ImprovedSR model fine-tunes the pretrained DDPM-based SR model using adversarial and perceptual losses.
    Specifically, the low-resolution sample is iteratively recovered in 6 steps with the frozen pretrained SR model.
    In one additional final step, a separate fine-tuned model recovers high-frequency details.
    This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps.
"""
def __init__(self, config):
super().__init__()
self._config = config
self._diffusion_kwargs = dict(
steps=config.diffusion.steps,
learn_sigma=config.diffusion.learn_sigma,
sigma_small=config.diffusion.sigma_small,
noise_schedule=config.diffusion.noise_schedule,
use_kl=config.diffusion.use_kl,
predict_xstart=config.diffusion.predict_xstart,
rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
)
self.model_first_steps = SuperResUNetModel(
in_channels=3, # auto-changed to 6 inside the model
model_channels=config.model.hparams.channels,
out_channels=3,
num_res_blocks=config.model.hparams.depth,
attention_resolutions=(), # no attention
dropout=config.model.hparams.dropout,
channel_mult=config.model.hparams.channels_multiple,
resblock_updown=True,
use_middle_attention=False,
)
self.model_last_step = SuperResUNetModel(
in_channels=3, # auto-changed to 6 inside the model
model_channels=config.model.hparams.channels,
out_channels=3,
num_res_blocks=config.model.hparams.depth,
attention_resolutions=(), # no attention
dropout=config.model.hparams.dropout,
channel_mult=config.model.hparams.channels_multiple,
resblock_updown=True,
use_middle_attention=False,
)
@classmethod
def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
model = cls(config)
model.load_state_dict(ckpt, strict=strict)
return model
def get_sample_fn(self, timestep_respacing):
diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
diffusion_kwargs.update(timestep_respacing=timestep_respacing)
diffusion = create_gaussian_diffusion(**diffusion_kwargs)
return diffusion.p_sample_loop_progressive_for_improved_sr
def forward(self, low_res, timestep_respacing="7", **kwargs):
assert (
timestep_respacing == "7"
), "different respacing method may work, but no guaranteed"
sample_fn = self.get_sample_fn(timestep_respacing)
sample_outputs = sample_fn(
self.model_first_steps,
self.model_last_step,
shape=low_res.shape,
clip_denoised=True,
model_kwargs=dict(low_res=low_res),
**kwargs,
)
for x in sample_outputs:
sample = x["sample"]
yield sample
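# Hedged usage sketch (not part of the original file): how the progressive
# generator above is typically consumed. `config` and `ckpt_path` are assumed
# to point at a valid Karlo SR config and checkpoint; `low_res` is assumed to
# be pre-upsampled to the 256px target resolution, since `shape=low_res.shape`
# drives the sample shape in forward() above.
def _example_improved_sr_usage(config, ckpt_path, low_res):
    model = ImprovedSupRes64to256ProgressiveModel.load_from_checkpoint(
        config, ckpt_path
    ).eval()
    sample = None
    with torch.no_grad():
        for sample in model(low_res, timestep_respacing="7"):
            pass  # keep only the final (7th-step) sample
    return sample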
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
from .diffusion import gaussian_diffusion as gd
from .diffusion.respace import (
SpacedDiffusion,
space_timesteps,
)
def create_gaussian_diffusion(
steps,
learn_sigma,
sigma_small,
noise_schedule,
use_kl,
predict_xstart,
rescale_learned_sigmas,
timestep_respacing,
):
betas = gd.get_named_beta_schedule(noise_schedule, steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if not timestep_respacing:
timestep_respacing = [steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
)
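# Hedged sketch (not part of the original file): a typical call with
# illustrative argument values, building a 1000-step cosine diffusion respaced
# to 25 sampling steps with a learned variance range.
def _example_create_diffusion():
    return create_gaussian_diffusion(
        steps=1000,
        learn_sigma=True,
        sigma_small=False,
        noise_schedule="squaredcos_cap_v2",
        use_kl=False,
        predict_xstart=False,
        rescale_learned_sigmas=True,
        timestep_respacing="25",
    )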
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import enum
import math
import numpy as np
import torch as th
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
warmup_time = int(num_diffusion_timesteps * warmup_frac)
betas[:warmup_time] = np.linspace(
beta_start, beta_end, warmup_time, dtype=np.float64
)
return betas
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
"""
This is the deprecated API for creating beta schedules.
See get_named_beta_schedule() for the new library of schedules.
"""
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start**0.5,
beta_end**0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif beta_schedule == "warmup10":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
elif beta_schedule == "warmup50":
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
)
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
return betas
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
return get_beta_schedule(
"linear",
beta_start=scale * 0.0001,
beta_end=scale * 0.02,
num_diffusion_timesteps=num_diffusion_timesteps,
)
elif schedule_name == "squaredcos_cap_v2":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
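# Hedged example (not part of the original file): the "squaredcos_cap_v2"
# schedule is betas_for_alpha_bar applied to the shifted cosine of Nichol &
# Dhariwal; the betas grow with t and are capped at max_beta = 0.999.
def _example_cosine_betas(num_steps=1000):
    betas = get_named_beta_schedule("squaredcos_cap_v2", num_steps)
    assert betas.shape == (num_steps,)
    assert betas.max() <= 0.999
    return betas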
class ModelMeanType(enum.Enum):
"""
Which type of output the model predicts.
"""
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
START_X = enum.auto() # the model predicts x_0
EPSILON = enum.auto() # the model predicts epsilon
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = enum.auto()
FIXED_SMALL = enum.auto()
FIXED_LARGE = enum.auto()
LEARNED_RANGE = enum.auto()
class LossType(enum.Enum):
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = (
enum.auto()
) # use raw MSE loss (with RESCALED_KL when learning variances)
KL = enum.auto() # use the variational lower-bound
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
class GaussianDiffusion(th.nn.Module):
"""
Utilities for training and sampling diffusion models.
Original ported from this codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
"""
def __init__(
self,
*,
betas,
model_mean_type,
model_var_type,
loss_type,
):
super(GaussianDiffusion, self).__init__()
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
self.loss_type = loss_type
# Use float64 for accuracy.
betas = np.array(betas, dtype=np.float64)
assert len(betas.shape) == 1, "betas must be 1-D"
assert (betas > 0).all() and (betas <= 1).all()
self.num_timesteps = int(betas.shape[0])
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
assert alphas_cumprod_prev.shape == (self.num_timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
posterior_log_variance_clipped = np.log(
np.append(posterior_variance[1], posterior_variance[1:])
)
posterior_mean_coef1 = (
betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
posterior_mean_coef2 = (
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
)
self.register_buffer("betas", th.from_numpy(betas), persistent=False)
self.register_buffer(
"alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
)
self.register_buffer(
"alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
)
self.register_buffer(
"alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
)
self.register_buffer(
"sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
th.from_numpy(sqrt_one_minus_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"log_one_minus_alphas_cumprod",
th.from_numpy(log_one_minus_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"sqrt_recip_alphas_cumprod",
th.from_numpy(sqrt_recip_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
th.from_numpy(sqrt_recipm1_alphas_cumprod),
persistent=False,
)
self.register_buffer(
"posterior_variance", th.from_numpy(posterior_variance), persistent=False
)
self.register_buffer(
"posterior_log_variance_clipped",
th.from_numpy(posterior_log_variance_clipped),
persistent=False,
)
self.register_buffer(
"posterior_mean_coef1",
th.from_numpy(posterior_mean_coef1),
persistent=False,
)
self.register_buffer(
"posterior_mean_coef2",
th.from_numpy(posterior_mean_coef2),
persistent=False,
)
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
== posterior_log_variance_clipped.shape[0]
== x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
model_kwargs=None,
**ignore_kwargs,
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample. Applies before
clip_denoised.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if model_kwargs is None:
model_kwargs = {}
B, C = x.shape[:2]
assert t.shape == (B,)
model_output = model(x, t, **model_kwargs)
if isinstance(model_output, tuple):
model_output, extra = model_output
else:
extra = None
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
assert model_output.shape == (B, C * 2, *x.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
if self.model_var_type == ModelVarType.LEARNED:
model_log_variance = model_var_values
model_variance = th.exp(model_log_variance)
else:
min_log = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x.shape
)
max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
model_variance = th.exp(model_log_variance)
else:
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so
# to get a better decoder log likelihood.
ModelVarType.FIXED_LARGE: (
th.cat([self.posterior_variance[1][None], self.betas[1:]]),
th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
),
ModelVarType.FIXED_SMALL: (
self.posterior_variance,
self.posterior_log_variance_clipped,
),
}[self.model_var_type]
model_variance = _extract_into_tensor(model_variance, t, x.shape)
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
def process_xstart(x):
if denoised_fn is not None:
x = denoised_fn(x)
if clip_denoised:
return x.clamp(-1, 1)
return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
else:
raise NotImplementedError(self.model_mean_type)
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
"mean": model_mean,
"variance": model_variance,
"log_variance": model_log_variance,
"pred_xstart": pred_xstart,
}
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute the mean for the previous step, given a function cond_fn that
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, t, **model_kwargs)
new_mean = (
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
)
return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(
x_start=out["pred_xstart"], x_t=x, t=t
)
return out
def p_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
        :param x: the current tensor at x_t.
:param t: the value of t, starting at 0 for the first diffusion step.
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- 'sample': a random sample from the model.
- 'pred_xstart': a prediction of x_0.
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(
cond_fn, out, x, t, model_kwargs=model_kwargs
)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def p_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Should be of the same shape as `shape`.
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
:param denoised_fn: if not None, a function which applies to the
x_start prediction before it is used to sample.
:param cond_fn: if not None, this is a gradient function that acts
similarly to the model.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:param device: if specified, the device to create the samples on.
If not specified, use a model parameter's device.
:param progress: if True, show a tqdm progress bar.
:return: a non-differentiable batch of samples.
"""
final = None
for sample in self.p_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
):
final = sample
return final["sample"]
def p_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for idx, i in enumerate(indices):
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def p_sample_loop_progressive_for_improved_sr(
self,
model,
model_aux,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
):
"""
        Modified version of p_sample_loop_progressive for sampling from the improved SR model.
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for idx, i in enumerate(indices):
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.p_sample(
model_aux if len(indices) - 1 == idx else model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
)
yield out
img = out["sample"]
def ddim_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
def ddim_reverse_sample(
self,
model,
x,
t,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
"""
assert eta == 0.0, "Reverse ODE only for deterministic path"
out = self.p_mean_variance(
model,
x,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
if cond_fn is not None:
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
def ddim_sample_loop(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
for sample in self.ddim_sample_loop_progressive(
model,
shape,
noise=noise,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
device=device,
progress=progress,
eta=eta,
):
final = sample
return final["sample"]
def ddim_sample_loop_progressive(
self,
model,
shape,
noise=None,
clip_denoised=True,
denoised_fn=None,
cond_fn=None,
model_kwargs=None,
device=None,
progress=False,
eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
else:
img = th.randn(*shape, device=device)
indices = list(range(self.num_timesteps))[::-1]
if progress:
# Lazy import so that we don't depend on tqdm.
from tqdm.auto import tqdm
indices = tqdm(indices)
for i in indices:
t = th.tensor([i] * shape[0], device=device)
with th.no_grad():
out = self.ddim_sample(
model,
img,
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
model_kwargs=model_kwargs,
eta=eta,
)
yield out
img = out["sample"]
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
    Extract values from a 1-D tensor for a batch of indices.
    :param arr: the 1-D tensor (e.g. a schedule buffer) to index into.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = arr.to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + th.zeros(broadcast_shape, device=timesteps.device)
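# Hedged example (not part of the original file): _extract_into_tensor gathers
# one scalar per batch element from a schedule buffer and broadcasts it over
# the remaining dimensions, e.g. per-sample sqrt(alpha_bar_t) for an image batch.
def _example_extract(diffusion, t, x):
    # t: [N] tensor of timestep indices; x: [N, C, H, W] image batch.
    coef = _extract_into_tensor(diffusion.sqrt_alphas_cumprod, t, x.shape)
    return coef * x  # coef broadcasts from shape [N, 1, 1, 1]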
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import torch as th
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
    For example, if there are 300 timesteps and the section counts are
    [10, 15, 20], then the first 100 timesteps are strided down to 10
    timesteps, the second 100 to 15 timesteps, and the final 100 to 20.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
elif section_counts == "fast27":
steps = space_timesteps(num_timesteps, "10,10,3,2,2")
# Help reduce DDIM artifacts from noisiest timesteps.
steps.remove(num_timesteps - 1)
steps.add(num_timesteps - 3)
return steps
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
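# Hedged check (not part of the original file): the docstring's own example.
# With 300 original steps and section counts "10,15,20", each 100-step third
# keeps 10, 15, and 20 steps respectively, 45 retained timesteps in total.
def _example_space_timesteps():
    steps = space_timesteps(300, "10,15,20")
    assert len(steps) == 45
    return steps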
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
timestep_map = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
timestep_map.append(i)
kwargs["betas"] = th.tensor(new_betas).numpy()
super().__init__(**kwargs)
self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
def p_mean_variance(self, model, *args, **kwargs):
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
def wrapped(x, ts, **kwargs):
ts_cpu = ts.detach().to("cpu")
return model(
x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
)
return wrapped
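# Hedged sketch (not part of the original file): SpacedDiffusion renumbers the
# retained timesteps as 0..K-1, and `timestep_map` records the original indices
# that _wrap_model restores before calling the model. Argument values below are
# illustrative only.
def _example_spaced_diffusion():
    import numpy as np
    from .gaussian_diffusion import LossType, ModelMeanType, ModelVarType
    spaced = SpacedDiffusion(
        use_timesteps=space_timesteps(100, "10"),
        betas=np.linspace(1e-4, 2e-2, 100),
        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.FIXED_SMALL,
        loss_type=LossType.MSE,
    )
    return spaced.timestep_map  # 10 indices into the original 100 steps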
# ------------------------------------------------------------------------------------
# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, swish, eps=1e-5):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
self.swish = swish
def forward(self, x):
y = super().forward(x.float()).to(x.dtype)
if self.swish == 1.0:
y = F.silu(y)
elif self.swish:
            y = y * th.sigmoid(y * float(self.swish))
return y
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def normalization(channels, swish=0.0):
"""
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period)
* th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
/ half
)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
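# Hedged example (not part of the original file): sinusoidal embeddings for a
# batch of four timesteps at width 128, the raw features that UNetModel's
# time_embed MLP consumes downstream.
def _example_timestep_embedding():
    t = th.tensor([0, 10, 100, 999])
    emb = timestep_embedding(t, dim=128)
    assert emb.shape == (4, 128)
    return emb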
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
# ------------------------------------------------------------------------------------
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
from abc import abstractmethod
import torch as th
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(th.nn.Module):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
        :param device: the torch device to place the results on.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / th.sum(w)
        indices = p.multinomial(batch_size, replacement=True)
        weights = 1 / (len(p) * p[indices])
        return indices.to(device), weights.to(device)
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
super(UniformSampler, self).__init__()
self.diffusion = diffusion
self.register_buffer(
"_weights", th.ones([diffusion.num_timesteps]), persistent=False
)
def weights(self):
return self._weights
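# Hedged example (not part of the original file): drawing timesteps for a
# training batch. With UniformSampler every timestep has equal probability, so
# the returned importance weights are all ones.
def _example_sample_timesteps(diffusion, batch_size=8):
    sampler = create_named_schedule_sampler("uniform", diffusion)
    t, weights = sampler.sample(batch_size, device=th.device("cpu"))
    return t, weights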
# ------------------------------------------------------------------------------------
# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
# ------------------------------------------------------------------------------------
import math
from abc import abstractmethod
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .nn import (
avg_pool_nd,
conv_nd,
linear,
normalization,
timestep_embedding,
zero_module,
)
from .xf import LayerNorm
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, encoder_out=None, mask=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, AttentionBlock):
x = layer(x, encoder_out, mask=mask)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=1
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(
self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class ResBlockNoTimeEmbedding(nn.Module):
"""
A residual block without time embedding
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
**kwargs,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=1.0),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb=None):
"""
Apply the block to a Tensor, NOT conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
assert emb is None
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
encoder_channels=None,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels, swish=0.0)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention(self.num_heads)
if encoder_channels is not None:
self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x, encoder_out=None, mask=None):
b, c, *spatial = x.shape
qkv = self.qkv(self.norm(x).view(b, c, -1))
if encoder_out is not None:
encoder_out = self.encoder_kv(encoder_out)
h = self.attention(qkv, encoder_out, mask=mask)
else:
h = self.attention(qkv)
h = self.proj_out(h)
return x + h.reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, encoder_kv=None, mask=None):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_kv is not None:
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = th.cat([ek, k], dim=-1)
v = th.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
if mask is not None:
mask = F.pad(mask, (0, length), value=0.0)
mask = (
mask.unsqueeze(1)
.expand(-1, self.n_heads, -1)
.reshape(bs * self.n_heads, 1, -1)
)
weight = weight + mask
weight = th.softmax(weight, dim=-1)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
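# Hedged shape check (not part of the original file): QKVAttention consumes a
# fused [N, H*3*C, T] projection and returns [N, H*C, T], matching how
# AttentionBlock wires qkv/proj_out around it.
def _example_qkv_attention():
    attn = QKVAttention(n_heads=4)
    qkv = th.randn(2, 4 * 3 * 16, 64)  # batch 2, 4 heads of 16 channels, 64 tokens
    out = attn(qkv)
    assert out.shape == (2, 4 * 16, 64)
    return out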
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param clip_dim: dimension of clip feature.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
    :param encoder_channels: used to match the dimensions of the query and kv in AttentionBlock.
    :param use_time_embedding: whether to use a timestep embedding for conditioning.
"""
def __init__(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
clip_dim=None,
use_checkpoint=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
use_middle_attention=True,
resblock_updown=False,
encoder_channels=None,
use_time_embedding=True,
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.clip_dim = clip_dim
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_middle_attention = use_middle_attention
self.use_time_embedding = use_time_embedding
if self.use_time_embedding:
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.clip_dim is not None:
self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
else:
time_embed_dim = None
CustomResidualBlock = (
ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
)
ch = input_ch = int(channel_mult[0] * model_channels)
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
)
self._feature_size = ch
input_block_chans = [ch]
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
out_channels=int(mult * model_channels),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = int(mult * model_channels)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
*(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
),
)
if self.use_middle_attention
else tuple(), # add AttentionBlock or not
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
CustomResidualBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=int(model_channels * mult),
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = int(model_channels * mult)
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=num_head_channels,
encoder_channels=encoder_channels,
)
)
if level and i == num_res_blocks:
out_ch = ch
layers.append(
CustomResidualBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch, swish=1.0),
nn.Identity(),
zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
)
def forward(self, x, timesteps, y=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.clip_dim is not None
), "must specify y if and only if the model is clip-rep-conditional"
hs = []
if self.use_time_embedding:
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.clip_dim is not None:
emb = emb + self.clip_emb(y)
else:
emb = None
h = x
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb)
return self.out(h)
class SuperResUNetModel(UNetModel):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
    Assumes that the low-resolution image and the input have the same shape.
"""
def __init__(self, *args, **kwargs):
if "in_channels" in kwargs:
kwargs = dict(kwargs)
kwargs["in_channels"] = kwargs["in_channels"] * 2
else:
# Curse you, Python. Or really, just curse positional arguments :|.
args = list(args)
args[1] = args[1] * 2
super().__init__(*args, **kwargs)
def forward(self, x, timesteps, low_res=None, **kwargs):
_, _, new_height, new_width = x.shape
assert new_height == low_res.shape[2] and new_width == low_res.shape[3]
x = th.cat([x, low_res], dim=1)
return super().forward(x, timesteps, **kwargs)
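# Hedged sketch (not part of the original file): SuperResUNetModel doubles its
# in_channels internally and concatenates the pre-upsampled low-res image with
# x along the channel axis. The hyperparameters below are illustrative only.
def _example_super_res_forward():
    model = SuperResUNetModel(
        in_channels=3,  # becomes 6 inside the model
        model_channels=32,
        out_channels=3,
        num_res_blocks=1,
        attention_resolutions=(),
        use_middle_attention=False,
    )
    x = th.randn(1, 3, 64, 64)
    low_res = th.randn(1, 3, 64, 64)  # must match x's spatial size
    return model(x, th.tensor([5]), low_res=low_res)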
class PLMImUNet(UNetModel):
"""
A UNetModel that conditions on text with a pretrained text encoder in CLIP.
:param text_ctx: number of text tokens to expect.
:param xf_width: width of the transformer.
    :param clip_emb_mult: number of extra tokens obtained by projecting the CLIP feature.
    :param clip_emb_type: type of condition (here, we fix the CLIP image feature).
    :param clip_emb_drop: dropout ratio of the CLIP image feature for classifier-free guidance (cfg).
"""
def __init__(
self,
text_ctx,
xf_width,
*args,
clip_emb_mult=None,
clip_emb_type="image",
clip_emb_drop=0.0,
**kwargs,
):
self.text_ctx = text_ctx
self.xf_width = xf_width
self.clip_emb_mult = clip_emb_mult
self.clip_emb_type = clip_emb_type
self.clip_emb_drop = clip_emb_drop
if not xf_width:
super().__init__(*args, **kwargs, encoder_channels=None)
else:
super().__init__(*args, **kwargs, encoder_channels=xf_width)
            # Project the text feature sequence from the pretrained CLIP text encoder
self.text_seq_proj = nn.Sequential(
nn.Linear(self.clip_dim, xf_width),
LayerNorm(xf_width),
)
# Project CLIP text feat
self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)
assert clip_emb_mult is not None
assert clip_emb_type == "image"
assert self.clip_dim is not None, "CLIP representation dim should be specified"
self.clip_tok_proj = nn.Linear(
self.clip_dim, self.xf_width * self.clip_emb_mult
)
if self.clip_emb_drop > 0:
self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))
def proc_clip_emb_drop(self, feat):
if self.clip_emb_drop > 0:
bsz, feat_dim = feat.shape
assert (
feat_dim == self.clip_dim
), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}"
drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop
feat = th.where(
drop_idx[..., None], self.cf_param[None].type_as(feat), feat
)
return feat
def forward(
self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None
):
bsz = x.shape[0]
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb = emb + self.clip_emb(y)
xf_out = self.text_seq_proj(txt_feat_seq)
xf_out = xf_out.permute(0, 2, 1)
emb = emb + self.text_feat_proj(txt_feat)
xf_out = th.cat(
[
self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
xf_out,
],
dim=2,
)
mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
mask = th.where(mask, 0.0, float("-inf"))
h = x
for module in self.input_blocks:
h = module(h, emb, xf_out, mask=mask)
hs.append(h)
h = self.middle_block(h, emb, xf_out, mask=mask)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, xf_out, mask=mask)
h = self.out(h)
return h
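# Editor's sketch (not from the original sources): proc_clip_emb_drop above
# implements classifier-free guidance training by replacing a random subset of
# CLIP features with a learned "null" vector (cf_param). The same masking
# pattern in isolation, with a zero vector standing in for cf_param:
def _demo_clip_emb_drop():
    import torch

    bsz, clip_dim, drop_p = 4, 8, 0.5
    feat = torch.randn(bsz, clip_dim)
    null_feat = torch.zeros(clip_dim)    # placeholder for the learned cf_param
    drop_idx = torch.rand(bsz) < drop_p  # per-sample drop decision
    feat = torch.where(drop_idx[:, None], null_feat[None], feat)
    assert feat.shape == (bsz, clip_dim)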
# ------------------------------------------------------------------------------------
# Adapted from the repos below:
# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
# (b) CLIP ViT (https://github.com/openai/CLIP/)
# ------------------------------------------------------------------------------------
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from .nn import timestep_embedding
def convert_module_to_f16(param):
"""
Convert primitive modules to float16.
"""
if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
param.weight.data = param.weight.data.half()
if param.bias is not None:
param.bias.data = param.bias.data.half()
class LayerNorm(nn.LayerNorm):
"""
Implementation that supports fp16 inputs but fp32 gains/biases.
"""
def forward(self, x: th.Tensor):
return super().forward(x.float()).to(x.dtype)
class MultiheadAttention(nn.Module):
def __init__(self, n_ctx, width, heads):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.heads = heads
self.c_qkv = nn.Linear(width, width * 3)
self.c_proj = nn.Linear(width, width)
self.attention = QKVMultiheadAttention(heads, n_ctx)
def forward(self, x, mask=None):
x = self.c_qkv(x)
x = self.attention(x, mask=mask)
x = self.c_proj(x)
return x
class MLP(nn.Module):
def __init__(self, width):
super().__init__()
self.width = width
self.c_fc = nn.Linear(width, width * 4)
self.c_proj = nn.Linear(width * 4, width)
self.gelu = nn.GELU()
def forward(self, x):
return self.c_proj(self.gelu(self.c_fc(x)))
class QKVMultiheadAttention(nn.Module):
def __init__(self, n_heads: int, n_ctx: int):
super().__init__()
self.n_heads = n_heads
self.n_ctx = n_ctx
def forward(self, qkv, mask=None):
bs, n_ctx, width = qkv.shape
attn_ch = width // self.n_heads // 3
scale = 1 / math.sqrt(math.sqrt(attn_ch))
qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
q, k, v = th.split(qkv, attn_ch, dim=-1)
weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
wdtype = weight.dtype
if mask is not None:
weight = weight + mask[:, None, ...]
weight = th.softmax(weight, dim=-1).type(wdtype)
return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
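# Editor's sketch (not from the original sources): the fused qkv tensor packs
# q, k, v for all heads into the last dimension, and the 1/sqrt(sqrt(d)) scale
# is applied to both q and k so their product carries the usual 1/sqrt(d)
# factor with better fp16 dynamic range. Shape walk-through:
def _demo_qkv_split():
    import math
    import torch

    bs, n_ctx, n_heads, attn_ch = 2, 5, 4, 8
    qkv = torch.randn(bs, n_ctx, n_heads * attn_ch * 3)
    qkv = qkv.view(bs, n_ctx, n_heads, -1)
    q, k, v = torch.split(qkv, attn_ch, dim=-1)
    scale = 1 / math.sqrt(math.sqrt(attn_ch))
    weight = torch.einsum("bthc,bshc->bhts", q * scale, k * scale)
    assert weight.shape == (bs, n_heads, n_ctx, n_ctx)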
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
n_ctx: int,
width: int,
heads: int,
):
super().__init__()
self.attn = MultiheadAttention(
n_ctx,
width,
heads,
)
self.ln_1 = LayerNorm(width)
self.mlp = MLP(width)
self.ln_2 = LayerNorm(width)
def forward(self, x, mask=None):
x = x + self.attn(self.ln_1(x), mask=mask)
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(
self,
n_ctx: int,
width: int,
layers: int,
heads: int,
):
super().__init__()
self.n_ctx = n_ctx
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(
n_ctx,
width,
heads,
)
for _ in range(layers)
]
)
def forward(self, x, mask=None):
for block in self.resblocks:
x = block(x, mask=mask)
return x
class PriorTransformer(nn.Module):
"""
A causal Transformer that conditions on the CLIP text embedding and text tokens.
:param text_ctx: number of text tokens to expect.
:param xf_width: width of the transformer.
:param xf_layers: depth of the transformer.
:param xf_heads: heads in the transformer.
:param xf_final_ln: use a LayerNorm after the output layer.
:param clip_dim: dimension of clip feature.
"""
def __init__(
self,
text_ctx,
xf_width,
xf_layers,
xf_heads,
xf_final_ln,
clip_dim,
):
super().__init__()
self.text_ctx = text_ctx
self.xf_width = xf_width
self.xf_layers = xf_layers
self.xf_heads = xf_heads
self.clip_dim = clip_dim
self.ext_len = 4
self.time_embed = nn.Sequential(
nn.Linear(xf_width, xf_width),
nn.SiLU(),
nn.Linear(xf_width, xf_width),
)
self.text_enc_proj = nn.Linear(clip_dim, xf_width)
self.text_emb_proj = nn.Linear(clip_dim, xf_width)
self.clip_img_proj = nn.Linear(clip_dim, xf_width)
self.out_proj = nn.Linear(xf_width, clip_dim)
self.transformer = Transformer(
text_ctx + self.ext_len,
xf_width,
xf_layers,
xf_heads,
)
if xf_final_ln:
self.final_ln = LayerNorm(xf_width)
else:
self.final_ln = None
self.positional_embedding = nn.Parameter(
th.empty(1, text_ctx + self.ext_len, xf_width)
)
self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
nn.init.normal_(self.prd_emb, std=0.01)
nn.init.normal_(self.positional_embedding, std=0.01)
def forward(
self,
x,
timesteps,
text_emb=None,
text_enc=None,
mask=None,
causal_mask=None,
):
bsz = x.shape[0]
mask = F.pad(mask, (0, self.ext_len), value=True)
t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
text_enc = self.text_enc_proj(text_enc)
text_emb = self.text_emb_proj(text_emb)
x = self.clip_img_proj(x)
input_seq = [
text_enc,
text_emb[:, None, :],
t_emb[:, None, :],
x[:, None, :],
self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
]
input = th.cat(input_seq, dim=1)
input = input + self.positional_embedding.to(input.dtype)
mask = th.where(mask, 0.0, float("-inf"))
mask = (mask[:, None, :] + causal_mask).to(input.dtype)
out = self.transformer(input, mask=mask)
if self.final_ln is not None:
out = self.final_ln(out)
out = self.out_proj(out[:, -1])
return out
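# Editor's sketch (not from the original sources): the prior's input sequence
# is [text_enc tokens | text_emb | t_emb | x | prd_emb], i.e. text_ctx + 4
# positions, and the prediction is read from the final (prd) token. A plausible
# additive causal mask for that layout, assuming standard upper-triangular
# masking, could be built like this:
def _demo_prior_causal_mask():
    import torch

    text_ctx, ext_len = 77, 4
    seq_len = text_ctx + ext_len
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float("-inf")), diagonal=1
    )
    # Position 0 cannot attend forward, but every position attends backward.
    assert causal_mask[0, 1] == float("-inf") and causal_mask[1, 0] == 0.0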
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
# ------------------------------------------------------------------------------------
from typing import Iterator
import torch
import torchvision.transforms.functional as TVF
from torchvision.transforms import InterpolationMode
from .template import BaseSampler, CKPT_PATH
class T2ISampler(BaseSampler):
"""
A sampler for text-to-image generation.
:param root_dir: directory for model checkpoints.
:param sampling_type: ["default", "fast"]
"""
def __init__(
self,
root_dir: str,
sampling_type: str = "default",
):
super().__init__(root_dir, sampling_type)
@classmethod
def from_pretrained(
cls,
root_dir: str,
clip_model_path: str,
clip_stat_path: str,
sampling_type: str = "default",
):
model = cls(
root_dir=root_dir,
sampling_type=sampling_type,
)
model.load_clip(clip_model_path)
model.load_prior(
f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path,
prior_config="configs/karlo/prior_1B_vit_l.yaml"
)
model.load_decoder(CKPT_PATH["decoder"], decoder_config="configs/karlo/decoder_900M_vit_l.yaml")
model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml")
return model
def preprocess(
self,
prompt: str,
bsz: int,
):
"""Setup prompts & cfg scales"""
prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
""" Get CLIP text feature """
clip_model = self._clip
tokenizer = self._tokenizer
max_txt_length = self._prior.model.text_ctx
tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
if cf_token.shape != tok.shape:
cf_token = cf_token.expand(tok.shape[0], -1)
cf_mask = cf_mask.expand(tok.shape[0], -1)
tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return (
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
)
def __call__(
self,
prompt: str,
bsz: int,
progressive_mode=None,
) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast():
(
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
) = self.preprocess(
prompt,
bsz,
)
""" Transform CLIP text feature into image feature """
img_feat = self._prior(
txt_feat,
txt_feat_seq,
mask,
prior_cf_scales_batch,
timestep_respacing=self._prior_sm,
)
""" Generate 64x64px images """
images_64_outputs = self._decoder(
txt_feat,
txt_feat_seq,
tok,
mask,
img_feat,
cf_guidance_scales=decoder_cf_scales_batch,
timestep_respacing=self._decoder_sm,
)
images_64 = None
for k, out in enumerate(images_64_outputs):
images_64 = out
if progressive_mode == "loop":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
if progressive_mode == "stage":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
images_64 = torch.clamp(images_64, -1, 1)
""" Upsample 64x64 to 256x256 """
images_256 = TVF.resize(
images_64,
[256, 256],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
)
images_256_outputs = self._sr_64_256(
images_256, timestep_respacing=self._sr_sm
)
for k, out in enumerate(images_256_outputs):
images_256 = out
if progressive_mode == "loop":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
if progressive_mode == "stage":
yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
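# Editor's usage sketch (not from the original sources; requires a CUDA device
# and real checkpoints, and the paths below are illustrative placeholders):
# T2ISampler yields [0, 1]-ranged image tensors, one per progressive step, or
# a single final 256x256 batch in "final" mode.
def _demo_t2i_sampler():
    sampler = T2ISampler.from_pretrained(
        root_dir="/path/to/karlo_ckpts",     # placeholder checkpoint dir
        clip_model_path="ViT-L-14.pt",       # placeholder CLIP checkpoint
        clip_stat_path="ViT-L-14_stats.th",  # placeholder CLIP mean/std file
        sampling_type="fast",
    )
    for images in sampler("a photo of a corgi", bsz=1, progressive_mode="final"):
        print(images.shape)  # expected: torch.Size([1, 3, 256, 256])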
class PriorSampler(BaseSampler):
"""
A sampler for text-to-image generation, but only the prior.
:param root_dir: directory for model checkpoints.
:param sampling_type: ["default", "fast"]
"""
def __init__(
self,
root_dir: str,
sampling_type: str = "default",
):
super().__init__(root_dir, sampling_type)
@classmethod
def from_pretrained(
cls,
root_dir: str,
clip_model_path: str,
clip_stat_path: str,
sampling_type: str = "default",
):
model = cls(
root_dir=root_dir,
sampling_type=sampling_type,
)
model.load_clip(clip_model_path)
model.load_prior(
f"{CKPT_PATH['prior']}",
clip_stat_path=clip_stat_path,
prior_config="configs/karlo/prior_1B_vit_l.yaml"
)
return model
def preprocess(
self,
prompt: str,
bsz: int,
):
"""Setup prompts & cfg scales"""
prompts_batch = [prompt for _ in range(bsz)]
prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
""" Get CLIP text feature """
clip_model = self._clip
tokenizer = self._tokenizer
max_txt_length = self._prior.model.text_ctx
tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
if cf_token.shape != tok.shape:
cf_token = cf_token.expand(tok.shape[0], -1)
cf_mask = cf_mask.expand(tok.shape[0], -1)
tok = torch.cat([tok, cf_token], dim=0)
mask = torch.cat([mask, cf_mask], dim=0)
tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
txt_feat, txt_feat_seq = clip_model.encode_text(tok)
return (
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
)
def __call__(
self,
prompt: str,
bsz: int,
progressive_mode=None,
) -> Iterator[torch.Tensor]:
assert progressive_mode in ("loop", "stage", "final")
with torch.no_grad(), torch.cuda.amp.autocast():
(
prompts_batch,
prior_cf_scales_batch,
decoder_cf_scales_batch,
txt_feat,
txt_feat_seq,
tok,
mask,
) = self.preprocess(
prompt,
bsz,
)
""" Transform CLIP text feature into image feature """
img_feat = self._prior(
txt_feat,
txt_feat_seq,
mask,
prior_cf_scales_batch,
timestep_respacing=self._prior_sm,
)
yield img_feat
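# Editor's usage sketch (not from the original sources; requires a CUDA device
# and real checkpoints, placeholder paths below): PriorSampler stops after the
# prior and yields the predicted CLIP image feature instead of pixels.
def _demo_prior_sampler():
    sampler = PriorSampler.from_pretrained(
        root_dir="/path/to/karlo_ckpts",     # placeholder checkpoint dir
        clip_model_path="ViT-L-14.pt",       # placeholder CLIP checkpoint
        clip_stat_path="ViT-L-14_stats.th",  # placeholder CLIP mean/std file
    )
    img_feat = next(sampler("a photo of a corgi", bsz=1, progressive_mode="final"))
    print(img_feat.shape)  # a batch of CLIP-dimensional image features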
# ------------------------------------------------------------------------------------
# Karlo-v1.0.alpha
# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
# ------------------------------------------------------------------------------------
import os
import logging
import torch
from omegaconf import OmegaConf
from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
SAMPLING_CONF = {
"default": {
"prior_sm": "25",
"prior_n_samples": 1,
"prior_cf_scale": 4.0,
"decoder_sm": "50",
"decoder_cf_scale": 8.0,
"sr_sm": "7",
},
"fast": {
"prior_sm": "25",
"prior_n_samples": 1,
"prior_cf_scale": 4.0,
"decoder_sm": "25",
"decoder_cf_scale": 8.0,
"sr_sm": "7",
},
}
CKPT_PATH = {
"prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
"decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
"sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
}
class BaseSampler:
_PRIOR_CLASS = PriorDiffusionModel
_DECODER_CLASS = Text2ImProgressiveModel
_SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
def __init__(
self,
root_dir: str,
sampling_type: str = "fast",
):
self._root_dir = root_dir
sampling_type = SAMPLING_CONF[sampling_type]
self._prior_sm = sampling_type["prior_sm"]
self._prior_n_samples = sampling_type["prior_n_samples"]
self._prior_cf_scale = sampling_type["prior_cf_scale"]
assert self._prior_n_samples == 1
self._decoder_sm = sampling_type["decoder_sm"]
self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
self._sr_sm = sampling_type["sr_sm"]
def __repr__(self):
line = ""
line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
line += f"SR(64->256), sampling method: {self._sr_sm}"
return line
def load_clip(self, clip_path: str):
clip = CustomizedCLIP.load_from_checkpoint(
os.path.join(self._root_dir, clip_path)
)
clip = torch.jit.script(clip)
clip.cuda()
clip.eval()
self._clip = clip
self._tokenizer = CustomizedTokenizer()
def load_prior(
self,
ckpt_path: str,
clip_stat_path: str,
prior_config: str = "configs/prior_1B_vit_l.yaml"
):
logging.info(f"Loading prior: {ckpt_path}")
config = OmegaConf.load(prior_config)
clip_mean, clip_std = torch.load(
os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
)
prior = self._PRIOR_CLASS.load_from_checkpoint(
config,
self._tokenizer,
clip_mean,
clip_std,
os.path.join(self._root_dir, ckpt_path),
strict=True,
)
prior.cuda()
prior.eval()
logging.info("done.")
self._prior = prior
def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
logging.info(f"Loading decoder: {ckpt_path}")
config = OmegaConf.load(decoder_config)
decoder = self._DECODER_CLASS.load_from_checkpoint(
config,
self._tokenizer,
os.path.join(self._root_dir, ckpt_path),
strict=True,
)
decoder.cuda()
decoder.eval()
logging.info("done.")
self._decoder = decoder
def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
logging.info(f"Loading SR(64->256): {ckpt_path}")
config = OmegaConf.load(sr_config)
sr = self._SR256_CLASS.load_from_checkpoint(
config, os.path.join(self._root_dir, ckpt_path), strict=True
)
sr.cuda()
sr.eval()
logging.info("done.")
self._sr_64_256 = sr
# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
"midas_v21": "",
"midas_v21_small": "",
}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def load_midas_transform(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load transform only
if model_type == "dpt_large": # DPT-Large
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
elif model_type == "midas_v21_small":
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
else:
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return transform
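# Editor's usage sketch (not from the original sources): the transform operates
# on a dict with an "image" entry holding an HWC float image in [0, 1], as in
# MiDaS's run.py. With "minimal" resizing a 480x640 input comes out as a CHW
# array whose sides are multiples of 32, e.g. (3, 384, 512).
def _demo_midas_transform():
    import numpy as np

    transform = load_midas_transform("dpt_hybrid")
    img = np.random.rand(480, 640, 3).astype(np.float32)  # placeholder image
    sample = transform({"image": img})
    print(sample["image"].shape)  # e.g. (3, 384, 512)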
def load_model(model_type):
# https://github.com/isl-org/MiDaS/blob/master/run.py
# load network
model_path = ISL_PATHS[model_type]
if model_type == "dpt_large": # DPT-Large
model = DPTDepthModel(
path=model_path,
backbone="vitl16_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "dpt_hybrid": # DPT-Hybrid
model = DPTDepthModel(
path=model_path,
backbone="vitb_rn50_384",
non_negative=True,
)
net_w, net_h = 384, 384
resize_mode = "minimal"
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
elif model_type == "midas_v21":
model = MidasNet(model_path, non_negative=True)
net_w, net_h = 384, 384
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
elif model_type == "midas_v21_small":
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
non_negative=True, blocks={'expand': True})
net_w, net_h = 256, 256
resize_mode = "upper_bound"
normalization = NormalizeImage(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
else:
print(f"model_type '{model_type}' not implemented, use: --model_type large")
assert False
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
normalization,
PrepareForNet(),
]
)
return model.eval(), transform
class MiDaSInference(nn.Module):
MODEL_TYPES_TORCH_HUB = [
"DPT_Large",
"DPT_Hybrid",
"MiDaS_small"
]
MODEL_TYPES_ISL = [
"dpt_large",
"dpt_hybrid",
"midas_v21",
"midas_v21_small",
]
def __init__(self, model_type):
super().__init__()
assert model_type in self.MODEL_TYPES_ISL
model, _ = load_model(model_type)
self.model = model
self.model.train = disabled_train
def forward(self, x):
# x in [0, 1] as produced by the matching MiDaS transform on a float numpy array
# NOTE: the correct transform is expected to have been applied during dataloading.
with torch.no_grad():
prediction = self.model(x)
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=x.shape[2:],
mode="bicubic",
align_corners=False,
)
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
return prediction
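# Editor's usage sketch (not from the original sources; assumes the dpt_hybrid
# checkpoint from ISL_PATHS is present on disk): MiDaSInference expects an
# already-normalized NCHW batch and returns an (N, 1, H, W) prediction.
def _demo_midas_inference():
    import torch

    model = MiDaSInference(model_type="dpt_hybrid")
    x = torch.randn(1, 3, 384, 384)  # placeholder pre-normalized input
    depth = model(x)                 # forward already runs under no_grad
    assert depth.shape == (1, 1, 384, 384)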
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)
import torch
import torch.nn as nn
from .vit import (
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
_make_pretrained_vitb16_384,
forward_vit,
)
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
) # ViT-B/16 + ResNet-50 hybrid (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # ResNeXt-101 WSL (backbone)
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
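# Editor's sketch (not from the original sources): with expand=True the
# reassembly convolutions widen by powers of two, matching the features,
# 2*features, 4*features, 8*features fusion blocks used by MidasNet_small
# later in this dump. A quick shape check:
def _demo_scratch_expand():
    import torch

    scratch = _make_scratch([32, 48, 136, 384], 64, expand=True)
    x = torch.randn(1, 136, 16, 16)  # stand-in for a layer-3 feature map
    assert scratch.layer3_rn(x).shape == (1, 256, 16, 16)  # 64 * 4 channels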
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_model import BaseModel
from .blocks import (
FeatureFusionBlock,
FeatureFusionBlock_custom,
Interpolate,
_make_encoder,
forward_vit,
)
def _make_fusion_block(features, use_bn):
return FeatureFusionBlock_custom(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
)
class DPT(BaseModel):
def __init__(
self,
head,
features=256,
backbone="vitb_rn50_384",
readout="project",
channels_last=False,
use_bn=False,
):
super(DPT, self).__init__()
self.channels_last = channels_last
hooks = {
"vitb_rn50_384": [0, 1, 8, 11],
"vitb16_384": [2, 5, 8, 11],
"vitl16_384": [5, 11, 17, 23],
}
# Instantiate backbone and reassemble blocks
self.pretrained, self.scratch = _make_encoder(
backbone,
features,
False, # use_pretrained: set to True to initialize the backbone with ImageNet weights
groups=1,
expand=False,
exportable=False,
hooks=hooks[backbone],
use_readout=readout,
)
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
self.scratch.output_conv = head
def forward(self, x):
if self.channels_last:
x = x.contiguous(memory_format=torch.channels_last)
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return out
class DPTDepthModel(DPT):
def __init__(self, path=None, non_negative=True, **kwargs):
features = kwargs["features"] if "features" in kwargs else 256
head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
super().__init__(head, **kwargs)
if path is not None:
self.load(path)
def forward(self, x):
return super().forward(x).squeeze(dim=1)
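# Editor's usage sketch (not from the original sources; assumes timm is
# available for the ViT backbone): DPTDepthModel squeezes the channel
# dimension, so an RGB batch comes back as (N, H, W). With path=None the
# weights are randomly initialized (no checkpoint is loaded).
def _demo_dpt_depth():
    import torch

    model = DPTDepthModel(path=None, backbone="vitl16_384", non_negative=True)
    x = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        depth = model(x)
    assert depth.shape == (1, 384, 384)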
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
"""
print("Loading weights: ", path)
super(MidasNet, self).__init__()
use_pretrained = path is not None
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
class MidasNet_small(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
blocks={'expand': True}):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 64.
backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
"""
print("Loading weights: ", path)
super(MidasNet_small, self).__init__()
use_pretrained = not path
self.channels_last = channels_last
self.blocks = blocks
self.backbone = backbone
self.groups = 1
features1=features
features2=features
features3=features
features4=features
self.expand = False
if "expand" in self.blocks and self.blocks['expand'] == True:
self.expand = True
features1=features
features2=features*2
features3=features*4
features4=features*8
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
self.scratch.activation = nn.ReLU(False)
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
self.scratch.activation,
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
if self.channels_last:
print("self.channels_last = ", self.channels_last)
x = x.contiguous(memory_format=torch.channels_last)
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
def fuse_model(m):
prev_previous_type = nn.Identity()
prev_previous_name = ''
previous_type = nn.Identity()
previous_name = ''
for name, module in m.named_modules():
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
# print("FUSED ", prev_previous_name, previous_name, name)
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
# print("FUSED ", prev_previous_name, previous_name)
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
# print("FUSED ", previous_name, name)
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
prev_previous_type = previous_type
prev_previous_name = previous_name
previous_type = type(module)
previous_name = name
import numpy as np
import cv2
import math
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
)
sample["disparity"] = cv2.resize(
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, min_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, min_val=self.__width
)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, max_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, max_val=self.__width
)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(
sample["image"].shape[1], sample["image"].shape[0]
)
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
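# Editor's sketch (not from the original sources): constrain_to_multiple_of
# rounds to the nearest multiple, then falls back to floor/ceil so that
# max_val/min_val stay respected. For the common ensure_multiple_of=32 case:
def _demo_constrain_to_multiple_of():
    resize = Resize(384, 384, ensure_multiple_of=32)
    assert resize.constrain_to_multiple_of(500.0) == 512               # nearest
    assert resize.constrain_to_multiple_of(500.0, max_val=500) == 480  # floor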
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
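# Editor's usage sketch (not from the original sources): after PrepareForNet
# the sample image is contiguous CHW float32, ready for torch.from_numpy.
def _demo_prepare_for_net():
    import numpy as np
    import torch

    sample = {"image": np.random.rand(384, 512, 3)}
    sample = PrepareForNet()(sample)
    batch = torch.from_numpy(sample["image"]).unsqueeze(0)
    assert batch.dtype == torch.float32 and batch.shape == (1, 3, 384, 512)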