"src/vscode:/vscode.git/clone" did not exist on "f5edaa789414517815ba2e66905778027c28aa79"
Unverified Commit 2b23ec82 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add callbacks to denoising step (#5427)



* draft1

* update

* style

* move to the end of loop

* update

* update callbak_on_step_end_inputs

* Revert "update"

This reverts commit 5f9b153183d0cde3b850f14024d2e37ae8c19576.

* Revert "update callbak_on_step_end_inputs"

This reverts commit 44889f4dabad95b7ebb330faa5f1955b5d008c88.

* update

* update test required_optional_params

* remove self.lora_scale

* img2img

* inpaint

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix

* apply feedbacks on img2img + inpaint: keep only important pipeline attributes

* depth

* pix2pix

* make _callback_tensor_inputs an class variable so that we can use it for testing

* add a basic tst for callback

* add a read-only tensor input timesteps + fix tests

* add second test for callback cfg

* sdxl

* sdxl img2img

* sdxl inpaint

* kandinsky prior

* kandinsky decoder

* kandinsky img2img + combined

* kandinsky inpaint

* fix copies

* fix

* consistent default inputs

* fix copies

* wuerstchen_prior prior

* test_wuerstchen_decoder + fix test for prior

* wuerstchen_combined pipeline + skip tests

* skip test for kandinsky combined

* lcm

* remove timesteps etc

* add doc string

* copies

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* make style and improve tests

* up

* up

* fix more

* fix cfg test

* tests for callbacks

* fix for real

* update

* lcm img2img

* add doc

* add doc page to index

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 080081bd
...@@ -172,6 +172,7 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -172,6 +172,7 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"output_type", "output_type",
"return_dict", "return_dict",
] ]
callback_cfg_params = ["image_embds"]
test_xformers_attention = False test_xformers_attention = False
def get_dummy_inputs(self, device, seed=0): def get_dummy_inputs(self, device, seed=0):
......
...@@ -192,6 +192,7 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas ...@@ -192,6 +192,7 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas
"return_dict", "return_dict",
] ]
test_xformers_attention = False test_xformers_attention = False
callback_cfg_params = ["image_embeds"]
def get_dummy_components(self): def get_dummy_components(self):
dummies = Dummies() dummies = Dummies()
......
...@@ -123,3 +123,5 @@ TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) ...@@ -123,3 +123,5 @@ TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"]) TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"]) TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
...@@ -52,7 +52,12 @@ from diffusers.utils.testing_utils import ( ...@@ -52,7 +52,12 @@ from diffusers.utils.testing_utils import (
torch_device, torch_device,
) )
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
...@@ -100,6 +105,7 @@ class StableDiffusionPipelineFastTests( ...@@ -100,6 +105,7 @@ class StableDiffusionPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
This diff is collapsed.
...@@ -232,3 +232,9 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -232,3 +232,9 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
@unittest.skip(reason="flakey and float16 requires CUDA") @unittest.skip(reason="flakey and float16 requires CUDA")
def test_float16_inference(self): def test_float16_inference(self):
super().test_float16_inference() super().test_float16_inference()
def test_callback_inputs(self):
pass
def test_callback_cfg(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment