Commit 2b8bc91c authored by Patrick von Platen

removed get alpha / get beta

parent 5b8ce1e7
@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

-        timestep_values = self.noise_scheduler.get_timestep_values()
+        timestep_values = self.noise_scheduler.config.timestep_values
         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
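The hunk above is representative of the whole commit: trivial `get_*` accessors on the schedulers are dropped in favor of reading stored attributes and config values directly. A minimal sketch of the resulting pattern, assuming only that constructor arguments end up on a `config` object (the real `ConfigMixin` registration machinery is elided here):

```python
# Hypothetical, stripped-down version of the accessor-free pattern:
# constructor arguments are stored on a `config` namespace and read
# directly, instead of going through get_*() wrapper methods.
from types import SimpleNamespace


class TinyScheduler:
    def __init__(self, timesteps=1000, timestep_values=None):
        self.config = SimpleNamespace(timesteps=timesteps, timestep_values=timestep_values)


scheduler = TinyScheduler(timesteps=1000)
timestep_values = scheduler.config.timestep_values  # was: scheduler.get_timestep_values()
```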
@@ -79,31 +79,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-    def get_timestep_values(self):
-        return self.config.timestep_values
-
-    def get_alpha(self, time_step):
-        return self.alphas[time_step]
-
-    def get_beta(self, time_step):
-        return self.betas[time_step]
-
-    def get_alpha_prod(self, time_step):
-        if time_step < 0:
-            return self.one
-        return self.alphas_cumprod[time_step]
-
-    def get_orig_t(self, t, num_inference_steps):
-        if t < 0:
-            return -1
-        return self.config.timesteps // num_inference_steps * t
-
     def get_variance(self, t, num_inference_steps):
-        orig_t = self.get_orig_t(t, num_inference_steps)
-        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
+        orig_t = self.config.timesteps // num_inference_steps * t
+        orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1

-        alpha_prod_t = self.get_alpha_prod(orig_t)
-        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
+        alpha_prod_t = self.alphas_cumprod[orig_t]
+        alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
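The alpha/beta products assembled here presumably feed the DDIM variance (the line itself sits in the collapsed context below the hunk); in the diff's notation, with alpha_prod_t = ᾱ_t, that is eq. (16) of https://arxiv.org/abs/2010.02502:

```latex
\sigma_t^2
  = \frac{1-\bar{\alpha}_{t_{\mathrm{prev}}}}{1-\bar{\alpha}_{t}}
    \left(1-\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t_{\mathrm{prev}}}}\right)
```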
@@ -124,12 +105,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_prev_sample -> "x_t-1"

         # 1. get actual t and t-1
-        orig_t = self.get_orig_t(t, num_inference_steps)
-        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
+        orig_t = self.config.timesteps // num_inference_steps * t
+        orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1

         # 2. compute alphas, betas
-        alpha_prod_t = self.get_alpha_prod(orig_t)
-        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
+        alpha_prod_t = self.alphas_cumprod[orig_t]
+        alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t

         # 3. compute predicted original sample from predicted noise also called
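The inlined index arithmetic in both DDIM hunks is equivalent to the removed `get_orig_t` helper. A standalone sketch (plain functions, no scheduler state) that checks the equivalence:

```python
# The removed helper, reproduced from the diff for comparison.
def get_orig_t(t, timesteps, num_inference_steps):
    if t < 0:
        return -1
    return timesteps // num_inference_steps * t


timesteps, num_inference_steps = 1000, 50
for t in range(num_inference_steps):
    # inlined expressions, exactly as they now appear in get_variance/step
    orig_t = timesteps // num_inference_steps * t
    orig_prev_t = timesteps // num_inference_steps * (t - 1) if t > 0 else -1
    assert orig_t == get_orig_t(t, timesteps, num_inference_steps)
    assert orig_prev_t == get_orig_t(t - 1, timesteps, num_inference_steps)
# e.g. with 1000 training steps and 50 inference steps, t=10 maps to orig_t=200
```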
@@ -83,44 +83,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-        # self.register_buffer("betas", betas.to(torch.float32))
-        # self.register_buffer("alphas", alphas.to(torch.float32))
-        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
-
-        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-
-        # TODO(PVP) - check how much of these is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # if variance_type == "fixed_small":
-        #     log_variance = torch.log(variance.clamp(min=1e-20))
-        # elif variance_type == "fixed_large":
-        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
-        #
-        #
-        # self.register_buffer("log_variance", log_variance.to(torch.float32))
-
-    def get_timestep_values(self):
-        return self.config.timestep_values
-
-    def get_alpha(self, time_step):
-        return self.alphas[time_step]
-
-    def get_beta(self, time_step):
-        return self.betas[time_step]
-
-    def get_alpha_prod(self, time_step):
-        if time_step < 0:
-            return self.one
-        return self.alphas_cumprod[time_step]
-
     def get_variance(self, t):
-        alpha_prod_t = self.get_alpha_prod(t)
-        alpha_prod_t_prev = self.get_alpha_prod(t - 1)
+        alpha_prod_t = self.alphas_cumprod[t]
+        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous sample
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

         # hacks - were probs added for training stability
         if self.config.variance_type == "fixed_small":
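Written out in the paper's notation (alpha_prod_t = ᾱ_t), the variance line above is the posterior variance β̃_t from eq. (7) of https://arxiv.org/pdf/2006.11239.pdf:

```latex
\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t
```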
@@ -129,14 +99,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         elif self.config.variance_type == "fixed_small_log":
             variance = self.log(self.clip(variance, min_value=1e-20))
         elif self.config.variance_type == "fixed_large":
-            variance = self.get_beta(t)
+            variance = self.betas[t]

         return variance

     def step(self, residual, sample, t, predict_epsilon=True):
         # 1. compute alphas, betas
-        alpha_prod_t = self.get_alpha_prod(t)
-        alpha_prod_t_prev = self.get_alpha_prod(t - 1)
+        alpha_prod_t = self.alphas_cumprod[t]
+        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -153,8 +123,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
-        current_sample_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

         # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
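The two coefficients are the weights of the posterior mean from eq. (7) of the same paper; the first fraction is pred_original_sample_coeff, the second is current_sample_coeff:

```latex
\tilde{\mu}_t(x_t, x_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t}\, x_t
```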
@@ -163,8 +133,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def forward_step(self, original_sample, noise, t):
-        sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
-        sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5
+        sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5

         noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
         return noisy_sample
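`forward_step` is the closed-form forward (noising) process q(x_t | x_0) = N(√ᾱ_t x_0, (1 − ᾱ_t) I), eq. (4) of https://arxiv.org/pdf/2006.11239.pdf, sampled via reparameterization:

```latex
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})
```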
@@ -86,17 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.time_steps = {}

         self.set_prk_mode()

-    def get_alpha(self, time_step):
-        return self.alphas[time_step]
-
-    def get_beta(self, time_step):
-        return self.betas[time_step]
-
-    def get_alpha_prod(self, time_step):
-        if time_step < 0:
-            return self.one
-        return self.alphas_cumprod[time_step]
-
     def get_prk_time_steps(self, num_inference_steps):
         if num_inference_steps in self.prk_time_steps:
             return self.prk_time_steps[num_inference_steps]
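Untouched by the removal, `get_prk_time_steps` shows the caching idiom the PNDM scheduler uses: one timestep schedule is computed per `num_inference_steps` and memoized in a dict. A generic sketch of that idiom (hypothetical class and schedule, not the actual PRK timesteps):

```python
# Hypothetical cache-per-argument pattern mirroring get_prk_time_steps:
# compute a schedule once per num_inference_steps, then reuse it.
class ScheduleCache:
    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        self._cache = {}

    def get_time_steps(self, num_inference_steps):
        if num_inference_steps in self._cache:
            return self._cache[num_inference_steps]
        schedule = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
        self._cache[num_inference_steps] = schedule
        return schedule
```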
@@ -188,8 +177,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         # sample -> x_t
         # residual -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.get_alpha_prod(t_orig + 1)
-        alpha_prod_t_prev = self.get_alpha_prod(t_orig_prev + 1)
+        alpha_prod_t = self.alphas_cumprod[t_orig + 1]
+        alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
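Note the PNDM hunk indexes `alphas_cumprod[t_orig + 1]` with no fallback, presumably because the `+ 1` offset keeps the index nonnegative; everywhere else the removed `get_alpha_prod` collapses to the same guarded lookup, sketched here:

```python
# Guarded lookup pattern that the commit inlines in DDIM and DDPM:
# alpha-bar at step t, falling back to one before the first step.
def alpha_prod(alphas_cumprod, t, one=1.0):
    return alphas_cumprod[t] if t >= 0 else one
```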