Commit b96c6ce1 authored by patil-suraj's avatar patil-suraj
Browse files

remove trained_betas from ddim and add in ddpm

parent 2d1f7de2
......@@ -26,8 +26,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_predicted_image=True,
tensor_format="np",
):
......@@ -39,12 +37,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule=beta_schedule,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
if beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
......
......@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
variance_type="fixed_small",
clip_predicted_image=True,
tensor_format="np",
......@@ -40,10 +42,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_predicted_image=clip_predicted_image,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
self.variance_type = variance_type
if beta_schedule == "linear":
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment