"vscode:/vscode.git/clone" did not exist on "fbf61f465b7756bd3d01a272ea994741c3cfcf8c"
Commit bf13b76a authored by anton-l's avatar anton-l
Browse files

Fix merge

parent 9c530191
...@@ -10,9 +10,5 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel ...@@ -10,9 +10,5 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
from .schedulers import SchedulerMixin, DDIMScheduler, DDPMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .schedulers.ddim import DDIMScheduler
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
...@@ -16,12 +16,7 @@ ...@@ -16,12 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .scheduling_ddim import DDIMScheduler from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm import DDPMScheduler
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .ddim import DDIMScheduler
from .gaussian_ddpm import GaussianDDPMScheduler
from .glide_ddim import GlideDDIMScheduler
from .schedulers_utils import SchedulerMixin
...@@ -15,6 +15,7 @@ import math ...@@ -15,6 +15,7 @@ import math
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
...@@ -28,11 +29,11 @@ def noise_like(shape, device, repeat=False): ...@@ -28,11 +29,11 @@ def noise_like(shape, device, repeat=False):
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
if ddim_discr_method == 'uniform': if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad': elif ddim_discr_method == "quad":
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
else: else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
...@@ -40,7 +41,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep ...@@ -40,7 +41,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# add one to get the final alpha values right (the ones from first scale to data during sampling) # add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1 steps_out = ddim_timesteps + 1
if verbose: if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}') print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out return steps_out
...@@ -52,9 +53,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): ...@@ -52,9 +53,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# according the the formula provided in https://arxiv.org/abs/2010.02502 # according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose: if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
print(f'For the chosen value of eta, which is {eta}, ' print(
f'this results in the following sigma_t schedule for ddim sampler {sigmas}') f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev return sigmas, alphas, alphas_prev
...@@ -71,41 +74,48 @@ class PLMSSampler(object): ...@@ -71,41 +74,48 @@ class PLMSSampler(object):
attr = attr.to(torch.device("cuda")) attr = attr.to(torch.device("cuda"))
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True):
if ddim_eta != 0: if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS') raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, self.ddim_timesteps = make_ddim_timesteps(
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters # ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
ddim_timesteps=self.ddim_timesteps, alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose
eta=ddim_eta,verbose=verbose) )
self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( (1 - self.alphas_cumprod_prev)
1 - self.alphas_cumprod / self.alphas_cumprod_prev)) / (1 - self.alphas_cumprod)
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(
self,
S, S,
batch_size, batch_size,
shape, shape,
...@@ -114,20 +124,20 @@ class PLMSSampler(object): ...@@ -114,20 +124,20 @@ class PLMSSampler(object):
normals_sequence=None, normals_sequence=None,
img_callback=None, img_callback=None,
quantize_x0=False, quantize_x0=False,
eta=0., eta=0.0,
mask=None, mask=None,
x0=None, x0=None,
temperature=1., temperature=1.0,
noise_dropout=0., noise_dropout=0.0,
score_corrector=None, score_corrector=None,
corrector_kwargs=None, corrector_kwargs=None,
verbose=True, verbose=True,
x_T=None, x_T=None,
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1., unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs **kwargs,
): ):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
...@@ -142,13 +152,16 @@ class PLMSSampler(object): ...@@ -142,13 +152,16 @@ class PLMSSampler(object):
# sampling # sampling
C, H, W = shape C, H, W = shape
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}') print(f"Data shape for PLMS sampling is {size}")
samples, intermediates = self.plms_sampling(conditioning, size, samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback, callback=callback,
img_callback=img_callback, img_callback=img_callback,
quantize_denoised=quantize_x0, quantize_denoised=quantize_x0,
mask=mask, x0=x0, mask=mask,
x0=x0,
ddim_use_original_steps=False, ddim_use_original_steps=False,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
temperature=temperature, temperature=temperature,
...@@ -162,12 +175,26 @@ class PLMSSampler(object): ...@@ -162,12 +175,26 @@ class PLMSSampler(object):
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def plms_sampling(self, cond, shape, def plms_sampling(
x_T=None, ddim_use_original_steps=False, self,
callback=None, timesteps=None, quantize_denoised=False, cond,
mask=None, x0=None, img_callback=None, log_every_t=100, shape,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
...@@ -181,12 +208,12 @@ class PLMSSampler(object): ...@@ -181,12 +208,12 @@ class PLMSSampler(object):
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end] timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]} intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps") print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
old_eps = [] old_eps = []
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
...@@ -197,36 +224,62 @@ class PLMSSampler(object): ...@@ -197,36 +224,62 @@ class PLMSSampler(object):
if mask is not None: if mask is not None:
assert x0 is not None assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, outs = self.p_sample_plms(
quantize_denoised=quantize_denoised, temperature=temperature, img,
noise_dropout=noise_dropout, score_corrector=score_corrector, cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next) old_eps=old_eps,
t_next=ts_next,
)
img, pred_x0, e_t = outs img, pred_x0, e_t = outs
old_eps.append(e_t) old_eps.append(e_t)
if len(old_eps) >= 4: if len(old_eps) >= 4:
old_eps.pop(0) old_eps.pop(0)
if callback: callback(i) if callback:
if img_callback: img_callback(pred_x0, i) callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1: if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img) intermediates["x_inter"].append(img)
intermediates['pred_x0'].append(pred_x0) intermediates["pred_x0"].append(pred_x0)
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, self,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
def get_model_output(x, t): def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
e_t = self.model.apply_model(x, t, c) e_t = self.model.apply_model(x, t, c)
else: else:
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
...@@ -243,7 +296,9 @@ class PLMSSampler(object): ...@@ -243,7 +296,9 @@ class PLMSSampler(object):
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
)
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
...@@ -251,16 +306,16 @@ class PLMSSampler(object): ...@@ -251,16 +306,16 @@ class PLMSSampler(object):
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.: if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0 return x_prev, pred_x0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment