"src/vscode:/vscode.git/clone" did not exist on "c81efdf2156c56d8f87f00a366a26c5fcb14eadb"
Unverified Commit 0fc25715 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix a bug in 2nd order schedulers when using in ensemble of experts config (#5511)



* fix

* fix copies

* remove heun from tests

* add back heun and fix the tests to include 2nd order

* fix the other test too

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* make style

* add more comments

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent de71fa59
...@@ -896,8 +896,20 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -896,8 +896,20 @@ class StableDiffusionXLControlNetInpaintPipeline(
- (denoising_start * self.scheduler.config.num_train_timesteps) - (denoising_start * self.scheduler.config.num_train_timesteps)
) )
) )
timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
return torch.tensor(timesteps), len(timesteps) num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2:
# if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
# because `num_inference_steps` will always be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -553,8 +553,20 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -553,8 +553,20 @@ class StableDiffusionXLImg2ImgPipeline(
- (denoising_start * self.scheduler.config.num_train_timesteps) - (denoising_start * self.scheduler.config.num_train_timesteps)
) )
) )
timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
return torch.tensor(timesteps), len(timesteps) num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2:
# if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
# because `num_inference_steps` will always be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -838,8 +838,20 @@ class StableDiffusionXLInpaintPipeline( ...@@ -838,8 +838,20 @@ class StableDiffusionXLInpaintPipeline(
- (denoising_start * self.scheduler.config.num_train_timesteps) - (denoising_start * self.scheduler.config.num_train_timesteps)
) )
) )
timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
return torch.tensor(timesteps), len(timesteps) num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2:
# if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
# because `num_inference_steps` will always be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -328,8 +328,13 @@ class StableDiffusionXLPipelineFastTests( ...@@ -328,8 +328,13 @@ class StableDiffusionXLPipelineFastTests(
pipe_1.scheduler.set_timesteps(num_steps) pipe_1.scheduler.set_timesteps(num_steps)
expected_steps = pipe_1.scheduler.timesteps.tolist() expected_steps = pipe_1.scheduler.timesteps.tolist()
expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) if pipe_1.scheduler.order == 2:
expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss)) expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
expected_steps = expected_steps_1 + expected_steps_2
else:
expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
# now we monkey patch step `done_steps` # now we monkey patch step `done_steps`
# list into the step function for testing # list into the step function for testing
...@@ -611,13 +616,18 @@ class StableDiffusionXLPipelineFastTests( ...@@ -611,13 +616,18 @@ class StableDiffusionXLPipelineFastTests(
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1)) split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2)) split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
expected_steps_1 = expected_steps[:split_1_ts]
expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
expected_steps_3 = expected_steps[split_2_ts:]
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) if pipe_1.scheduler.order == 2:
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)) expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps)) expected_steps_2 = expected_steps_1[-1:] + list(
filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
)
expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
else:
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps` # now we monkey patch step `done_steps`
# list into the step function for testing # list into the step function for testing
......
...@@ -318,11 +318,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -318,11 +318,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
expected_steps = pipe_1.scheduler.timesteps.tolist() expected_steps = pipe_1.scheduler.timesteps.tolist()
split_ts = num_train_timesteps - int(round(num_train_timesteps * split)) split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
expected_steps_1 = expected_steps[:split_ts]
expected_steps_2 = expected_steps[split_ts:]
expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps)) if pipe_1.scheduler.order == 2:
expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps)) expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
expected_steps = expected_steps_1 + expected_steps_2
else:
expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
# now we monkey patch step `done_steps` # now we monkey patch step `done_steps`
# list into the step function for testing # list into the step function for testing
...@@ -389,13 +392,18 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -389,13 +392,18 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1)) split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2)) split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
expected_steps_1 = expected_steps[:split_1_ts]
expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
expected_steps_3 = expected_steps[split_2_ts:]
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) if pipe_1.scheduler.order == 2:
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)) expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps)) expected_steps_2 = expected_steps_1[-1:] + list(
filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
)
expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
else:
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps` # now we monkey patch step `done_steps`
# list into the step function for testing # list into the step function for testing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment