Unverified Commit 67d07074 authored by dg845's avatar dg845 Committed by GitHub
Browse files

Add Custom Timesteps Support to LCMScheduler and Supported Pipelines (#5874)

* Add custom timesteps support to LCMScheduler.

* Add custom timesteps support to StableDiffusionPipeline.

* Add custom timesteps support to StableDiffusionXLPipeline.

* Add custom timesteps support to remaining Stable Diffusion pipelines which support LCMScheduler (img2img, inpaint).

* Add custom timesteps support to remaining Stable Diffusion XL pipelines which support LCMScheduler (img2img, inpaint).

* Add custom timesteps support to StableDiffusionControlNetPipeline.

* Add custom timesteps support to T21 Stable Diffusion (XL) Adapters.

* Clean up Stable Diffusion inpaint tests.

* Manually add support for custom timesteps to AltDiffusion pipelines since make fix-copies doesn't appear to work correctly (it deletes the whole pipeline).

* make style

* Refactor pipeline timestep handling into the retrieve_timesteps function.
parent 9c357bda
...@@ -226,6 +226,25 @@ class StableDiffusionInpaintPipelineFastTests( ...@@ -226,6 +226,25 @@ class StableDiffusionInpaintPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4931, 0.5988, 0.4569, 0.5556, 0.6650, 0.5087, 0.5966, 0.5358, 0.5269])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_image_tensor(self): def test_stable_diffusion_inpaint_image_tensor(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
...@@ -420,6 +439,25 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli ...@@ -420,6 +439,25 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6240, 0.5355, 0.5649, 0.5378, 0.5374, 0.6242, 0.5132, 0.5347, 0.5396])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint_2_images(self): def test_stable_diffusion_inpaint_2_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
......
...@@ -183,6 +183,25 @@ class StableDiffusionXLPipelineFastTests( ...@@ -183,6 +183,25 @@ class StableDiffusionXLPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionXLPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_xl_prompt_embeds(self): def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components) sd_pipe = StableDiffusionXLPipeline(**components)
......
...@@ -389,6 +389,28 @@ class StableDiffusionXLAdapterPipelineFastTests( ...@@ -389,6 +389,28 @@ class StableDiffusionXLAdapterPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_adapter_sdxl_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionXLAdapterPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
output = sd_pipe(**inputs)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5425, 0.5385, 0.4964, 0.5045, 0.6149, 0.4974, 0.5469, 0.5332, 0.5426])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
class StableDiffusionXLMultiAdapterPipelineFastTests( class StableDiffusionXLMultiAdapterPipelineFastTests(
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
...@@ -614,6 +636,31 @@ class StableDiffusionXLMultiAdapterPipelineFastTests( ...@@ -614,6 +636,31 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_adapter_sdxl_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionXLAdapterPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
output = sd_pipe(**inputs)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
print(",".join(debug))
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -227,6 +227,26 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -227,6 +227,26 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_xl_img2img_euler_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.5604, 0.4352, 0.4717, 0.5844, 0.5101, 0.6704, 0.6290, 0.5460, 0.5286])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_attention_slicing_forward_pass(self): def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
......
...@@ -264,6 +264,26 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -264,6 +264,26 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_xl_inpaint_euler_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(time_cond_proj_dim=256)
sd_pipe = StableDiffusionXLInpaintPipeline(**components)
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_attention_slicing_forward_pass(self): def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
......
...@@ -242,3 +242,59 @@ class LCMSchedulerTest(SchedulerCommonTest): ...@@ -242,3 +242,59 @@ class LCMSchedulerTest(SchedulerCommonTest):
# TODO: get expected sum and mean # TODO: get expected sum and mean
assert abs(result_sum.item() - 197.7616) < 1e-3 assert abs(result_sum.item() - 197.7616) < 1e-3
assert abs(result_mean.item() - 0.2575) < 1e-3 assert abs(result_mean.item() - 0.2575) < 1e-3
def test_custom_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
scheduler.set_timesteps(timesteps=timesteps)
scheduler_timesteps = scheduler.timesteps
for i, timestep in enumerate(scheduler_timesteps):
if i == len(timesteps) - 1:
expected_prev_t = -1
else:
expected_prev_t = timesteps[i + 1]
prev_t = scheduler.previous_timestep(timestep)
prev_t = prev_t.item()
self.assertEqual(prev_t, expected_prev_t)
def test_custom_timesteps_increasing_order(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 51, 0]
with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
scheduler.set_timesteps(timesteps=timesteps)
def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [100, 87, 50, 1, 0]
num_inference_steps = len(timesteps)
with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
def test_custom_timesteps_too_large(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [scheduler.config.num_train_timesteps]
with self.assertRaises(
ValueError,
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment