Commit 8c1f5197 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make clip name shorter

parent dcb23b2d
......@@ -28,7 +28,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_predicted_sample=True,
clip_sample=True,
tensor_format="np",
):
super().__init__()
......@@ -40,7 +40,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_sample = clip_predicted_sample
self.clip_sample = clip_sample
if beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
......
......@@ -29,7 +29,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas=None,
timestep_values=None,
variance_type="fixed_small",
clip_predicted_sample=True,
clip_sample=True,
tensor_format="np",
):
super().__init__()
......@@ -41,11 +41,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas=trained_betas,
timestep_values=timestep_values,
variance_type=variance_type,
clip_predicted_sample=clip_predicted_sample,
clip_sample=clip_sample,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_sample = clip_predicted_sample
self.clip_sample = clip_sample
self.variance_type = variance_type
if trained_betas is not None:
......@@ -124,7 +124,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
# 3. Clip "predicted x_0"
if self.clip_predicted_sample:
if self.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
......
......@@ -172,7 +172,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02,
"beta_schedule": "linear",
"variance_type": "fixed_small",
"clip_predicted_sample": True,
"clip_sample": True,
}
config.update(**kwargs)
......@@ -195,8 +195,8 @@ class DDPMSchedulerTest(SchedulerCommonTest):
self.check_over_configs(variance_type=variance)
def test_clip_image(self):
for clip_predicted_sample in [True, False]:
self.check_over_configs(clip_predicted_sample=clip_predicted_sample)
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_time_indices(self):
for t in [0, 500, 999]:
......@@ -251,7 +251,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"clip_predicted_sample": True,
"clip_sample": True,
}
config.update(**kwargs)
......@@ -270,8 +270,8 @@ class DDIMSchedulerTest(SchedulerCommonTest):
self.check_over_configs(beta_schedule=schedule)
def test_clip_image(self):
for clip_predicted_sample in [True, False]:
self.check_over_configs(clip_predicted_sample=clip_predicted_sample)
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_time_indices(self):
for t in [1, 10, 49]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment