Commit 1997b908 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

image->sample in schedule tests

parent b2274ece
...@@ -31,37 +31,37 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -31,37 +31,37 @@ class SchedulerCommonTest(unittest.TestCase):
forward_default_kwargs = () forward_default_kwargs = ()
@property @property
def dummy_image(self): def dummy_sample(self):
batch_size = 4 batch_size = 4
num_channels = 3 num_channels = 3
height = 8 height = 8
width = 8 width = 8
image = np.random.rand(batch_size, num_channels, height, width) sample = np.random.rand(batch_size, num_channels, height, width)
return image return sample
@property @property
def dummy_image_deter(self): def dummy_sample_deter(self):
batch_size = 4 batch_size = 4
num_channels = 3 num_channels = 3
height = 8 height = 8
width = 8 width = 8
num_elems = batch_size * num_channels * height * width num_elems = batch_size * num_channels * height * width
image = np.arange(num_elems) sample = np.arange(num_elems)
image = image.reshape(num_channels, height, width, batch_size) sample = sample.reshape(num_channels, height, width, batch_size)
image = image / num_elems sample = sample / num_elems
image = image.transpose(3, 0, 1, 2) sample = sample.transpose(3, 0, 1, 2)
return image return sample
def get_scheduler_config(self): def get_scheduler_config(self):
raise NotImplementedError raise NotImplementedError
def dummy_model(self): def dummy_model(self):
def model(image, t, *args): def model(sample, t, *args):
return image * t / (t + 1) return sample * t / (t + 1)
return model return model
...@@ -70,8 +70,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -70,8 +70,8 @@ class SchedulerCommonTest(unittest.TestCase):
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
scheduler_config = self.get_scheduler_config(**config) scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
...@@ -80,8 +80,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -80,8 +80,8 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, image, time_step, **kwargs) output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -90,8 +90,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -90,8 +90,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs.update(forward_kwargs) kwargs.update(forward_kwargs)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
...@@ -101,8 +101,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -101,8 +101,8 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, image, time_step, **kwargs) output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -110,8 +110,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -110,8 +110,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
...@@ -120,8 +120,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -120,8 +120,8 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, image, 1, **kwargs) output = scheduler.step(residual, sample, 1, **kwargs)
new_output = new_scheduler.step(residual, image, 1, **kwargs) new_output = new_scheduler.step(residual, sample, 1, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -132,32 +132,32 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -132,32 +132,32 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
output_0 = scheduler.step(residual, image, 0, **kwargs) output_0 = scheduler.step(residual, sample, 0, **kwargs)
output_1 = scheduler.step(residual, image, 1, **kwargs) output_1 = scheduler.step(residual, sample, 1, **kwargs)
self.assertEqual(output_0.shape, image.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
def test_pytorch_equal_numpy(self): def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
image_pt = torch.tensor(image) sample_pt = torch.tensor(sample)
residual_pt = 0.1 * image_pt residual_pt = 0.1 * sample_pt
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
output = scheduler.step(residual, image, 1, **kwargs) output = scheduler.step(residual, sample, 1, **kwargs)
output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs) output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs)
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
...@@ -194,7 +194,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -194,7 +194,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for variance in ["fixed_small", "fixed_large", "other"]: for variance in ["fixed_small", "fixed_large", "other"]:
self.check_over_configs(variance_type=variance) self.check_over_configs(variance_type=variance)
def test_clip_image(self): def test_clip_sample(self):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
...@@ -219,23 +219,23 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -219,23 +219,23 @@ class DDPMSchedulerTest(SchedulerCommonTest):
num_trained_timesteps = len(scheduler) num_trained_timesteps = len(scheduler)
model = self.dummy_model() model = self.dummy_model()
image = self.dummy_image_deter sample = self.dummy_sample_deter
for t in reversed(range(num_trained_timesteps)): for t in reversed(range(num_trained_timesteps)):
# 1. predict noise residual # 1. predict noise residual
residual = model(image, t) residual = model(sample, t)
# 2. predict previous mean of image x_t-1 # 2. predict previous mean of sample x_t-1
pred_prev_image = scheduler.step(residual, image, t) pred_prev_sample = scheduler.step(residual, sample, t)
if t > 0: if t > 0:
noise = self.dummy_image_deter noise = self.dummy_sample_deter
variance = scheduler.get_variance(t) ** (0.5) * noise variance = scheduler.get_variance(t) ** (0.5) * noise
image = pred_prev_image + variance sample = pred_prev_sample + variance
result_sum = np.sum(np.abs(image)) result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(image)) result_mean = np.mean(np.abs(sample))
assert abs(result_sum.item() - 732.9947) < 1e-2 assert abs(result_sum.item() - 732.9947) < 1e-2
assert abs(result_mean.item() - 0.9544) < 1e-3 assert abs(result_mean.item() - 0.9544) < 1e-3
...@@ -269,7 +269,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -269,7 +269,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "squaredcos_cap_v2"]: for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule) self.check_over_configs(beta_schedule=schedule)
def test_clip_image(self): def test_clip_sample(self):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
...@@ -308,22 +308,22 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -308,22 +308,22 @@ class DDIMSchedulerTest(SchedulerCommonTest):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
model = self.dummy_model() model = self.dummy_model()
image = self.dummy_image_deter sample = self.dummy_sample_deter
for t in reversed(range(num_inference_steps)): for t in reversed(range(num_inference_steps)):
residual = model(image, inference_step_times[t]) residual = model(sample, inference_step_times[t])
pred_prev_image = scheduler.step(residual, image, t, num_inference_steps, eta) pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
variance = 0 variance = 0
if eta > 0: if eta > 0:
noise = self.dummy_image_deter noise = self.dummy_sample_deter
variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
image = pred_prev_image + variance sample = pred_prev_sample + variance
result_sum = np.sum(np.abs(image)) result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(image)) result_mean = np.mean(np.abs(sample))
assert abs(result_sum.item() - 270.6214) < 1e-2 assert abs(result_sum.item() - 270.6214) < 1e-2
assert abs(result_mean.item() - 0.3524) < 1e-3 assert abs(result_mean.item() - 0.3524) < 1e-3
...@@ -346,8 +346,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -346,8 +346,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def check_over_configs_pmls(self, time_step=0, **config): def check_over_configs_pmls(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
...@@ -365,16 +365,16 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -365,16 +365,16 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode() new_scheduler.set_plms_mode()
output = scheduler.step(residual, image, time_step, **kwargs) output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward_pmls(self, time_step=0, **forward_kwargs): def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs) kwargs.update(forward_kwargs)
image = self.dummy_image sample = self.dummy_sample
residual = 0.1 * image residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
...@@ -392,8 +392,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -392,8 +392,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode() new_scheduler.set_plms_mode()
output = scheduler.step(residual, image, time_step, **kwargs) output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, image, time_step, **kwargs) new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -445,7 +445,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -445,7 +445,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.set_plms_mode() scheduler.set_plms_mode()
scheduler.step(self.dummy_image, self.dummy_image, 1, 50) scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50)
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -454,24 +454,24 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -454,24 +454,24 @@ class PNDMSchedulerTest(SchedulerCommonTest):
num_inference_steps = 10 num_inference_steps = 10
model = self.dummy_model() model = self.dummy_model()
image = self.dummy_image_deter sample = self.dummy_sample_deter
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps) prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
for t in range(len(prk_time_steps)): for t in range(len(prk_time_steps)):
t_orig = prk_time_steps[t] t_orig = prk_time_steps[t]
residual = model(image, t_orig) residual = model(sample, t_orig)
image = scheduler.step_prk(residual, image, t, num_inference_steps) sample = scheduler.step_prk(residual, sample, t, num_inference_steps)
timesteps = scheduler.get_time_steps(num_inference_steps) timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)): for t in range(len(timesteps)):
t_orig = timesteps[t] t_orig = timesteps[t]
residual = model(image, t_orig) residual = model(sample, t_orig)
image = scheduler.step_plms(residual, image, t, num_inference_steps) sample = scheduler.step_plms(residual, sample, t, num_inference_steps)
result_sum = np.sum(np.abs(image)) result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(image)) result_mean = np.mean(np.abs(sample))
assert abs(result_sum.item() - 199.1169) < 1e-2 assert abs(result_sum.item() - 199.1169) < 1e-2
assert abs(result_mean.item() - 0.2593) < 1e-3 assert abs(result_mean.item() - 0.2593) < 1e-3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment