Commit dcb23b2d authored by Patrick von Platen

rename image to sample in schedulers

parent 13a78b3c
@@ -28,7 +28,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule="linear",
         trained_betas=None,
         timestep_values=None,
-        clip_predicted_image=True,
+        clip_predicted_sample=True,
         tensor_format="np",
     ):
         super().__init__()
@@ -40,7 +40,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         )
         self.timesteps = int(timesteps)
         self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_image = clip_predicted_image
+        self.clip_sample = clip_predicted_sample

         if beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
@@ -111,17 +111,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance

-    def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
+    def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False):
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

         # Notation (<variable name> -> <name in paper>
         # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
         # - std_dev_t -> sigma_t
         # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
+        # - pred_sample_direction -> "direction pointingc to x_t"
+        # - pred_prev_sample -> "x_t-1"

         # 1. get actual t and t-1
         orig_t = self.get_orig_t(t, num_inference_steps)
@@ -132,13 +132,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
         beta_prod_t = 1 - alpha_prod_t

-        # 3. compute predicted original image from predicted noise also called
+        # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 4. Clip "predicted x_0"
-        if self.clip_image:
-            pred_original_image = self.clip(pred_original_image, -1, 1)
+        if self.clip_sample:
+            pred_original_sample = self.clip(pred_original_sample, -1, 1)

         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -147,15 +147,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         if use_clipped_residual:
             # the residual is always re-derived from the clipped x_0 in GLIDE
-            residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
+            residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
+        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual

         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction
+        pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

-        return pred_prev_image
+        return pred_prev_sample

     def __len__(self):
         return self.timesteps
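Context note, not part of the commit: steps 3, 6 and 7 of DDIMScheduler.step above implement formula (12) of the DDIM paper, with std_dev_t given by formula (16) scaled by eta. Transcribed with the renamed variables:

    x_{t-1} = \sqrt{\bar\alpha_{t-1}} \cdot \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}} + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t) + \sigma_t\,\epsilon_t
    \sigma_t(\eta) = \eta \sqrt{\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}} \sqrt{1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}}

Here the first term is alpha_prod_t_prev ** 0.5 * pred_original_sample, the second is pred_sample_direction, and sigma_t is std_dev_t; per comment 7, step returns only the sum of the first two terms, i.e. the update without the sigma_t * epsilon_t noise. A hypothetical usage sketch of the renamed signature with eta=0 (deterministic DDIM); the model, the sample shape and the loop setup are placeholders, not from this repo:

    import numpy as np

    # assumed: DDIMScheduler imported from this repo, `model` is some trained noise predictor
    scheduler = DDIMScheduler(clip_predicted_sample=True, tensor_format="np")
    num_inference_steps = 50
    sample = np.random.randn(1, 3, 64, 64)  # x_T ~ N(0, I); shape is a placeholder

    for t in reversed(range(num_inference_steps)):
        residual = model(sample, t)  # eps_theta(x_t, t); `model` is a placeholder
        # with eta=0, sigma_t = 0 and step() returns the full deterministic x_{t-1}
        sample = scheduler.step(residual, sample, t, num_inference_steps, eta=0.0)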
@@ -29,7 +29,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas=None,
         timestep_values=None,
         variance_type="fixed_small",
-        clip_predicted_image=True,
+        clip_predicted_sample=True,
         tensor_format="np",
     ):
         super().__init__()
@@ -41,11 +41,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             trained_betas=trained_betas,
             timestep_values=timestep_values,
             variance_type=variance_type,
-            clip_predicted_image=clip_predicted_image,
+            clip_predicted_sample=clip_predicted_sample,
         )
         self.timesteps = int(timesteps)
         self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_image = clip_predicted_image
+        self.clip_sample = clip_predicted_sample
         self.variance_type = variance_type

         if trained_betas is not None:
@@ -100,8 +100,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)

         # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
-        # and sample from it to get previous image
-        # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
+        # and sample from it to get previous sample
+        # x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)

         # hacks - were probs added for training stability
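Context note, not part of the commit: the variance computed in this hunk is the posterior variance of q(x_{t-1} | x_t, x_0), i.e. the beta-tilde of formula (7) in the DDPM paper:

    \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t

with alpha_prod_t = \bar\alpha_t, alpha_prod_t_prev = \bar\alpha_{t-1} and get_beta(t) = \beta_t.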
@@ -112,37 +112,37 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return variance

-    def step(self, residual, image, t):
+    def step(self, residual, sample, t):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev

-        # 2. compute predicted original image from predicted noise also called
+        # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 3. Clip "predicted x_0"
-        if self.clip_predicted_image:
-            pred_original_image = self.clip(pred_original_image, -1, 1)
+        if self.clip_predicted_sample:
+            pred_original_sample = self.clip(pred_original_sample, -1, 1)

-        # 4. Compute coefficients for pred_original_image x_0 and current image x_t
+        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
-        current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
+        current_sample_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t

-        # 5. Compute predicted previous image µ_t
+        # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
+        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

-        return pred_prev_image
+        return pred_prev_sample

-    def forward_step(self, original_image, noise, t):
+    def forward_step(self, original_sample, noise, t):
         sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
         sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5

-        noisy_image = sqrt_alpha_prod * original_image + sqrt_one_minus_alpha_prod * noise
-        return noisy_image
+        noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
+        return noisy_sample

     def __len__(self):
         return self.timesteps
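Context note, not part of the commit: the coefficients in steps 4-5 of DDPMScheduler.step are those of the posterior mean mu-tilde from formula (7) of the DDPM paper, and forward_step is the closed-form forward process q(x_t | x_0). With the renamed variables:

    \tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\, x_t
    x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)

where the x_0 coefficient is pred_original_sample_coeff, the x_t coefficient is current_sample_coeff, and in forward_step x_0 = original_sample, epsilon = noise and x_t = noisy_sample.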
@@ -62,7 +62,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

         # running values
         self.cur_residual = 0
-        self.cur_image = None
+        self.cur_sample = None
         self.ets = []
         self.warmup_time_steps = {}
         self.time_steps = {}
@@ -100,7 +100,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         return self.time_steps[num_inference_steps]

-    def step_prk(self, residual, image, t, num_inference_steps):
+    def step_prk(self, residual, sample, t, num_inference_steps):
         # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
         warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
@@ -110,7 +110,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         if t % 4 == 0:
             self.cur_residual += 1 / 6 * residual
             self.ets.append(residual)
-            self.cur_image = image
+            self.cur_sample = sample
         elif (t - 1) % 4 == 0:
             self.cur_residual += 1 / 3 * residual
         elif (t - 2) % 4 == 0:
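Context note, not part of the commit: the 1/6, 1/3, 1/3, 1/6 weights accumulated in cur_residual over the four sub-steps of step_prk are the classical fourth-order Runge-Kutta weights,

    e' = \tfrac{1}{6} e_1 + \tfrac{1}{3} e_2 + \tfrac{1}{3} e_3 + \tfrac{1}{6} e_4,

and cur_sample stores the sample at the start of each group of four sub-steps so that transfer can apply the combined residual from that point.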
@@ -119,9 +119,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = self.cur_residual + 1 / 6 * residual
             self.cur_residual = 0

-        return self.transfer(self.cur_image, t_prev, t_next, residual)
+        return self.transfer(self.cur_sample, t_prev, t_next, residual)

-    def step_plms(self, residual, image, t, num_inference_steps):
+    def step_plms(self, residual, sample, t, num_inference_steps):
         timesteps = self.get_time_steps(num_inference_steps)

         t_prev = timesteps[t]
@@ -130,7 +130,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

-        return self.transfer(image, t_prev, t_next, residual)
+        return self.transfer(sample, t_prev, t_next, residual)

     def transfer(self, x, t, t_next, et):
         # TODO(Patrick): clean up to be compatible with numpy and give better names
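Context note, not part of the commit: the (55, -59, 37, -9)/24 weights in step_plms combine the last four stored residuals using the coefficients of the fourth-order Adams-Bashforth linear multistep method,

    e' = \tfrac{1}{24}\left(55\, e_t - 59\, e_{t-1} + 37\, e_{t-2} - 9\, e_{t-3}\right),

and transfer (whose body is collapsed in this diff view) then advances the renamed sample from t_prev to t_next using this combined residual.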