"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0b9df9d798430f1fc440c6fa8a8dca2a1350d8be"
Unverified Commit 9c03a7da authored by Seunghyeon Kim's avatar Seunghyeon Kim Committed by GitHub
Browse files

Fix DDIMInverseScheduler (#5145)



* fix ddim inverse scheduler

* update test of ddim inverse scheduler

* update test of pix2pix_zero

* update test of diffedit

* fix typo

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 1d3120fb
...@@ -288,9 +288,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -288,9 +288,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
) )
# Roll timesteps array by one to reflect reversed origin and destination semantics for each step
timesteps = np.roll(timesteps, 1)
timesteps[0] = int(timesteps[1] - step_ratio)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
def step( def step(
...@@ -335,7 +332,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -335,7 +332,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
""" """
# 1. get previous step value (=t+1) # 1. get previous step value (=t+1)
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps prev_timestep = timestep
timestep = min(timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps-1)
# 2. compute alphas, betas # 2. compute alphas, betas
# change original implementation to exactly match noise levels for analogous forward process # change original implementation to exactly match noise levels for analogous forward process
......
...@@ -229,7 +229,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip ...@@ -229,7 +229,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
image = sd_pipe.invert(**inputs).images image = sd_pipe.invert(**inputs).images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4823, 0.4783, 0.5638, 0.5201, 0.5247, 0.5644, 0.5029, 0.5404, 0.5062]) expected_slice = np.array([0.4732, 0.4630, 0.5722, 0.5103, 0.5140, 0.5622, 0.5104, 0.5390, 0.5020])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
...@@ -244,7 +244,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip ...@@ -244,7 +244,7 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, Pip
image = sd_pipe.invert(**inputs).images image = sd_pipe.invert(**inputs).images
image_slice = image[1, -3:, -3:, -1] image_slice = image[1, -3:, -3:, -1]
assert image.shape == (2, 32, 32, 3) assert image.shape == (2, 32, 32, 3)
expected_slice = np.array([0.6446, 0.5232, 0.4914, 0.4441, 0.4654, 0.5546, 0.4650, 0.4938, 0.5044]) expected_slice = np.array([0.6046, 0.5400, 0.4902, 0.4448, 0.4694, 0.5498, 0.4857, 0.5073, 0.5089])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
......
...@@ -257,7 +257,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -257,7 +257,7 @@ class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, Pipeli
self.assertEqual(image.shape, (2, 32, 32, 3)) self.assertEqual(image.shape, (2, 32, 32, 3))
expected_slice = np.array( expected_slice = np.array(
[0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.5105, 0.5015, 0.4407, 0.4799], [0.5160, 0.5115, 0.5060, 0.5456, 0.4704, 0.5060, 0.5019, 0.4405, 0.4726],
) )
max_diff = np.abs(image_slice.flatten() - expected_slice).max() max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3) self.assertLessEqual(max_diff, 1e-3)
......
...@@ -51,7 +51,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -51,7 +51,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config(steps_offset=1) scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(5) scheduler.set_timesteps(5)
assert torch.equal(scheduler.timesteps, torch.LongTensor([-199, 1, 201, 401, 601])) assert torch.equal(scheduler.timesteps, torch.LongTensor([ 1, 201, 401, 601, 801]))
def test_betas(self): def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
...@@ -104,8 +104,8 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -104,8 +104,8 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 509.1079) < 1e-2 assert abs(result_sum.item() - 671.6816) < 1e-2
assert abs(result_mean.item() - 0.6629) < 1e-3 assert abs(result_mean.item() - 0.8746) < 1e-3
def test_full_loop_with_v_prediction(self): def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction") sample = self.full_loop(prediction_type="v_prediction")
...@@ -113,8 +113,8 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -113,8 +113,8 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 1029.129) < 1e-2 assert abs(result_sum.item() - 1394.2185) < 1e-2
assert abs(result_mean.item() - 1.3400) < 1e-3 assert abs(result_mean.item() - 1.8154) < 1e-3
def test_full_loop_with_set_alpha_to_one(self): def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99 # We specify different beta, so that the first alpha is 0.99
...@@ -122,8 +122,8 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -122,8 +122,8 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 259.8116) < 1e-2 assert abs(result_sum.item() - 539.9622) < 1e-2
assert abs(result_mean.item() - 0.3383) < 1e-3 assert abs(result_mean.item() - 0.7031) < 1e-3
def test_full_loop_with_no_set_alpha_to_one(self): def test_full_loop_with_no_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99 # We specify different beta, so that the first alpha is 0.99
...@@ -131,5 +131,5 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -131,5 +131,5 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 239.055) < 1e-2 assert abs(result_sum.item() - 542.6722) < 1e-2
assert abs(result_mean.item() - 0.3113) < 1e-3 assert abs(result_mean.item() - 0.7066) < 1e-3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment