"comfy/vscode:/vscode.git/clone" did not exist on "89507f8adff4aff4507b6f35a67717badaecd4ac"
Commit a268a574 authored by comfyanonymous's avatar comfyanonymous
Browse files

Remove a bunch of useless code.

DDIM is the same as euler with a small difference in the inpaint code.
DDIM uses randn_like but I set a fixed seed instead.

I'm keeping it in because I'm sure if I remove it people are going to
complain.
parent 1777b54d
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.parameterization = kwargs.get("parameterization", "eps")
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.float().to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
self.ddim_timesteps = torch.tensor(ddim_timesteps)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample_custom(self,
ddim_timesteps,
conditioning=None,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
denoise_function=None,
extra_args=None,
to_zero=True,
end_step=None,
disable_pbar=False,
**kwargs
):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=denoise_function,
extra_args=extra_args,
to_zero=to_zero,
end_step=end_step,
disable_pbar=disable_pbar
)
return samples, intermediates
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=None,
extra_args=None
)
return samples, intermediates
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.alphas_cumprod.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
if to_zero:
img = pred_x0
else:
if ddim_use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
img /= sqrt_alphas_cumprod[index - 1]
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None, denoise_function=None, extra_args=None):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.parameterization == "v":
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
if score_corrector is not None:
assert self.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
if max_denoise:
noise_multiplier = 1.0
else:
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
\ No newline at end of file
from .sampler import DPMSolverSampler
\ No newline at end of file
import torch
import torch.nn.functional as F
import math
from tqdm import tqdm
class NoiseScheduleVP:
def __init__(
self,
schedule='discrete',
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
***
The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
**Important**: Please pay special attention for the args for `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
schedule are the default settings in DDPM and improved-DDPM:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ['discrete', 'linear', 'cosine']:
raise ValueError(
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
schedule))
self.schedule = schedule
if schedule == 'discrete':
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas)
self.T = 1.
self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
self.log_alpha_array = log_alphas.reshape((1, -1,))
else:
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
self.cosine_s = 0.008
self.cosine_beta_max = 999.
self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
1. + self.cosine_s) / math.pi - self.cosine_s
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
self.schedule = schedule
if schedule == 'cosine':
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
self.T = 0.9946
else:
self.T = 1.
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == 'discrete':
return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
self.log_alpha_array.to(t.device)).reshape((-1))
elif self.schedule == 'linear':
return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == 'cosine':
log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == 'linear':
tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
Delta = self.beta_0 ** 2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == 'discrete':
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
torch.flip(self.t_array.to(lamb.device), [1]))
return t.reshape((-1,))
else:
log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
1. + self.cosine_s) / math.pi - self.cosine_s
t = t_fn(log_alpha)
return t
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follows a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == 'discrete':
return (t_continuous - 1. / noise_schedule.total_N) * 1000.
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return -expand_dims(sigma_t, dims) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
The noise predicition model function that is used for DPM-Solver.
"""
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
elif guidance_type == "classifier-free":
if guidance_scale == 1. or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=condition)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict)
c_in = dict()
for k in condition:
if isinstance(condition[k], list):
c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
else:
c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
else:
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
class DPM_Solver:
def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
"""Construct a DPM-Solver.
We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
Args:
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
``
def model_fn(x, t_continuous):
return noise
``
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
self.model = model_fn
self.noise_schedule = noise_schedule
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with thresholding).
"""
noise = self.noise_prediction_fn(x, t)
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
Args:
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
N: A `int`. The total number of the spacing of the time steps.
device: A torch device.
Returns:
A pytorch tensor of the time steps, with the shape (N + 1,).
"""
if skip_type == 'logSNR':
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == 'time_uniform':
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == 'time_quadratic':
t_order = 2
t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
return t
else:
raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
- If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
- If order == 2:
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If order == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
============================================
Args:
order: A `int`. The max order for the solver (2 or 3).
steps: A `int`. The total number of function evaluations (NFE).
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
device: A torch device.
Returns:
orders: A list of the solver order of each step.
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3, ] * (K - 2) + [2, 1]
elif steps % 3 == 1:
orders = [3, ] * (K - 1) + [1]
else:
orders = [3, ] * (K - 1) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [2, ] * K
else:
K = steps // 2 + 1
orders = [2, ] * (K - 1) + [1]
elif order == 1:
K = 1
orders = [1, ] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == 'logSNR':
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
"""
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
if self.predict_x0:
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
)
if return_intermediate:
return x_t, {'model_s': model_s}
else:
return x_t
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
)
if return_intermediate:
return x_t, {'model_s': model_s}
else:
return x_t
def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
solver_type='dpm_solver'):
"""
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
r1: A `float`. The hyperparameter of the second-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpm_solver', 'taylor']:
raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = ns.inverse_lambda(lambda_s1)
log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
s1), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
if self.predict_x0:
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
expand_dims(sigma_s1 / sigma_s, dims) * x
- expand_dims(alpha_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
- (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
)
elif solver_type == 'taylor':
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
- expand_dims(sigma_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
)
elif solver_type == 'taylor':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
)
if return_intermediate:
return x_t, {'model_s': model_s, 'model_s1': model_s1}
else:
return x_t
def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
return_intermediate=False, solver_type='dpm_solver'):
"""
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
r1: A `float`. The hyperparameter of the third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpm_solver', 'taylor']:
raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
if r1 is None:
r1 = 1. / 3.
if r2 is None:
r2 = 2. / 3.
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = ns.inverse_lambda(lambda_s1)
s2 = ns.inverse_lambda(lambda_s2)
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
s2), ns.marginal_std(t)
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
if self.predict_x0:
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
phi_2 = phi_1 / h + 1.
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
expand_dims(sigma_s1 / sigma_s, dims) * x
- expand_dims(alpha_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
expand_dims(sigma_s2 / sigma_s, dims) * x
- expand_dims(alpha_s2 * phi_12, dims) * model_s
+ r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
)
elif solver_type == 'taylor':
D1_0 = (1. / r1) * (model_s1 - model_s)
D1_1 = (1. / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ expand_dims(alpha_t * phi_2, dims) * D1
- expand_dims(alpha_t * phi_3, dims) * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
phi_2 = phi_1 / h - 1.
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
- expand_dims(sigma_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
- expand_dims(sigma_s2 * phi_12, dims) * model_s
- r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == 'dpm_solver':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
)
elif solver_type == 'taylor':
D1_0 = (1. / r1) * (model_s1 - model_s)
D1_1 = (1. / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- expand_dims(sigma_t * phi_2, dims) * D1
- expand_dims(sigma_t * phi_3, dims) * D2
)
if return_intermediate:
return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
else:
return x_t
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ['dpm_solver', 'taylor']:
raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
ns = self.noise_schedule
dims = x.dim()
model_prev_1, model_prev_0 = model_prev_list
t_prev_1, t_prev_0 = t_prev_list
lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
t_prev_0), ns.marginal_lambda(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
if self.predict_x0:
if solver_type == 'dpm_solver':
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
- 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
)
elif solver_type == 'taylor':
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
)
else:
if solver_type == 'dpm_solver':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
- 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
)
elif solver_type == 'taylor':
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
)
return x_t
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
if self.predict_x0:
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
- expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
)
else:
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
- expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
)
return x_t
def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
r2=None):
"""
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
r1: A `float`. The hyperparameter of the second-order or third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
elif order == 2:
return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
solver_type=solver_type, r1=r1)
elif order == 3:
return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
solver_type=solver_type, r1=r1, r2=r2)
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
solver_type='dpm_solver'):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
x: A pytorch tensor. The initial value at time `t_T`.
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
h_init: A `float`. The initial step size (for logSNR).
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
Returns:
x_0: A pytorch tensor. The approximated solution at time `t_0`.
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
"""
ns = self.noise_schedule
s = t_T * torch.ones((x.shape[0],)).to(x)
lambda_s = ns.marginal_lambda(s)
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
h = h_init * torch.ones_like(s).to(x)
x_prev = x
nfe = 0
if order == 2:
r1 = 0.5
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
solver_type=solver_type,
**kwargs)
elif order == 3:
r1, r2 = 1. / 3., 2. / 3.
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
return_intermediate=True,
solver_type=solver_type)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
solver_type=solver_type,
**kwargs)
else:
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.):
x = x_higher
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
nfe += order
print('adaptive solver nfe', nfe)
return x
def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
=====================================================
We support the following algorithms for both noise prediction model and data prediction model:
- 'singlestep':
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
The total number of function evaluations (NFE) == `steps`.
Given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
- 'multistep':
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
We initialize the first `order` values by lower order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
Denote K = steps.
- If `order` == 1:
- We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
- If `order` == 3:
- We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
- 'singlestep_fixed':
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
- 'adaptive':
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
(NFE) and the sample quality.
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
=====================================================
Some advices for choosing the algorithm:
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
- For **guided sampling with large guidance scale** by DPMs:
Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
skip_type='time_uniform', method='multistep')
We support three types of `skip_type`:
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
- 'time_quadratic': quadratic time for the time steps.
=====================================================
Args:
x: A pytorch tensor. The initial value at time `t_start`
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
steps: A `int`. The total number of function evaluations (NFE).
t_start: A `float`. The starting time of the sampling.
If `T` is None, we use self.noise_schedule.T (default is 1.0).
t_end: A `float`. The ending time of the sampling.
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
e.g. if total_N == 1000, we have `t_end` == 1e-3.
For discrete-time DPMs:
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
For continuous-time DPMs:
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
order: A `int`. The order of DPM-Solver.
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
for diffusion models sampling by diffusion SDEs for low-resolutional images
(such as CIFAR-10). However, we observed that such trick does not matter for
high-resolutional images. As it needs an additional NFE, we do not recommend
it for high-resolutional images.
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
Only valid for `method=multistep` and `steps < 15`. We empirically find that
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
(especially for steps <= 10). So we recommend to set it to be `True`.
solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
Returns:
x_end: A pytorch tensor. The approximated solution at time `t_end`.
"""
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == 'adaptive':
with torch.no_grad():
x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
solver_type=solver_type)
elif method == 'multistep':
assert steps >= order
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in tqdm(range(1, order), desc="DPM init order"):
vec_t = timesteps[init_order].expand(x.shape[0])
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
solver_type=solver_type)
model_prev_list.append(self.model_fn(x, vec_t))
t_prev_list.append(vec_t)
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final and steps < 15:
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
solver_type=solver_type)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = self.model_fn(x, vec_t)
elif method in ['singlestep', 'singlestep_fixed']:
if method == 'singlestep':
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
skip_type=skip_type,
t_T=t_T, t_0=t_0,
device=device)
elif method == 'singlestep_fixed':
K = steps // order
orders = [order, ] * K
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
for i, order in enumerate(orders):
t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
N=order, device=device)
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dim`: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
}
class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
if isinstance(ctmp, torch.Tensor):
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if isinstance(conditioning, torch.Tensor):
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
lower_order_final=True)
return x.to(device), None
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)
\ No newline at end of file
......@@ -4,8 +4,6 @@ from .extra_samplers import uni_pc
import torch
import enum
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from comfy import model_base
import comfy.utils
......@@ -511,41 +509,6 @@ class Sampler:
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class DDIM(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
max_denoise = self.max_denoise(model_wrap, sigmas)
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
return samples
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
......@@ -558,13 +521,17 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral"
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
def ksampler(sampler_name, extra_options={}):
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
model_k.noise = noise
if inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
......@@ -656,7 +623,7 @@ def sampler_class(name):
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
elif name == "ddim":
sampler = DDIM
sampler = ksampler("euler", inpaint_options={"random": True})
else:
sampler = ksampler(name)
return sampler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment