Unverified Commit 2b23ec82 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add callbacks to denoising step (#5427)



* draft1

* update

* style

* move to the end of loop

* update

* update callbak_on_step_end_inputs

* Revert "update"

This reverts commit 5f9b153183d0cde3b850f14024d2e37ae8c19576.

* Revert "update callbak_on_step_end_inputs"

This reverts commit 44889f4dabad95b7ebb330faa5f1955b5d008c88.

* update

* update test required_optional_params

* remove self.lora_scale

* img2img

* inpaint

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix

* apply feedbacks on img2img + inpaint: keep only important pipeline attributes

* depth

* pix2pix

* make _callback_tensor_inputs an class variable so that we can use it for testing

* add a basic tst for callback

* add a read-only tensor input timesteps + fix tests

* add second test for callback cfg

* sdxl

* sdxl img2img

* sdxl inpaint

* kandinsky prior

* kandinsky decoder

* kandinsky img2img + combined

* kandinsky inpaint

* fix copies

* fix

* consistent default inputs

* fix copies

* wuerstchen_prior prior

* test_wuerstchen_decoder + fix test for prior

* wuerstchen_combined pipeline + skip tests

* skip test for kandinsky combined

* lcm

* remove timesteps etc

* add doc string

* copies

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* make style and improve tests

* up

* up

* fix more

* fix cfg test

* tests for callbacks

* fix for real

* update

* lcm img2img

* add doc

* add doc page to index

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 080081bd
......@@ -43,6 +43,7 @@ class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
......
......@@ -44,6 +44,7 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
......@@ -183,3 +184,38 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@unittest.skip(reason="flaky for now")
def test_float16_inference(self):
super().test_float16_inference()
# override because we need to make sure latent_mean and latent_std to be 0
def test_callback_inputs(self):
components = self.get_dummy_components()
components["latent_mean"] = 0
components["latent_std"] = 0
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_test(pipe, i, t, callback_kwargs):
missing_callback_inputs = set()
for v in pipe._callback_tensor_inputs:
if v not in callback_kwargs:
missing_callback_inputs.add(v)
self.assertTrue(
len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
)
last_i = pipe.num_timesteps - 1
if i == last_i:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_test
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment