Unverified Commit 2b23ec82 authored by YiYi Xu, committed by GitHub

add callbacks to denoising step (#5427)
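
This adds a `callback_on_step_end` argument to the pipeline `__call__` methods (the testers below drop the older `callback`/`callback_steps` pair from the expected signature). The callback receives `(pipe, step_index, timestep, callback_kwargs)` and must return `callback_kwargs`, so any tensor listed in a pipeline's `_callback_tensor_inputs` and requested via `callback_on_step_end_tensor_inputs` can be read or modified between denoising steps. A minimal usage sketch (the checkpoint id and the 40% cutoff are illustrative, not part of this diff):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

def cutoff_cfg(pipe, step_index, timestep, callback_kwargs):
    # after 40% of the steps, drop the unconditional half of the prompt batch
    # and zero the guidance scale so the remaining steps run without cfg
    if step_index == int(pipe.num_timesteps * 0.4):
        callback_kwargs["prompt_embeds"] = callback_kwargs["prompt_embeds"].chunk(2)[-1]
        pipe._guidance_scale = 0.0
    return callback_kwargs

image = pipe(
    "a photo of an astronaut riding a horse",
    callback_on_step_end=cutoff_cfg,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]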



* draft1

* update

* style

* move to the end of loop

* update

* update callback_on_step_end_inputs

* Revert "update"

This reverts commit 5f9b153183d0cde3b850f14024d2e37ae8c19576.

* Revert "update callbak_on_step_end_inputs"

This reverts commit 44889f4dabad95b7ebb330faa5f1955b5d008c88.

* update

* update test required_optional_params

* remove self.lora_scale

* img2img

* inpaint

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix

* apply feedbacks on img2img + inpaint: keep only important pipeline attributes

* depth

* pix2pix

* make _callback_tensor_inputs a class variable so that we can use it for testing

* add a basic test for callback

* add a read-only tensor input timesteps + fix tests

* add second test for callback cfg

* sdxl

* sdxl img2img

* sdxl inpaint

* kandinsky prior

* kandinsky decoder

* kandinsky img2img + combined

* kandinsky inpaint

* fix copies

* fix

* consistent default inputs

* fix copies

* wuerstchen_prior prior

* test_wuerstchen_decoder + fix test for prior

* wuerstchen_combined pipeline + skip tests

* skip test for kandinsky combined

* lcm

* remove timesteps etc

* add doc string

* copies

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make style and improve tests

* up

* up

* fix more

* fix cfg test

* tests for callbacks

* fix for real

* update

* lcm img2img

* add doc

* add doc page to index

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 080081bd
......@@ -172,6 +172,7 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"output_type",
"return_dict",
]
callback_cfg_params = ["image_embds"]
test_xformers_attention = False
def get_dummy_inputs(self, device, seed=0):
......
......@@ -55,6 +55,7 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
"return_dict",
]
test_xformers_attention = True
callback_cfg_params = ["image_embds"]
def get_dummy_components(self):
dummy = Dummies()
......@@ -152,6 +153,12 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-3)
def test_callback_inputs(self):
pass
def test_callback_cfg(self):
pass
class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22Img2ImgCombinedPipeline
......@@ -172,6 +179,7 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["image_embds"]
def get_dummy_components(self):
dummy = Img2ImgDummies()
......@@ -267,6 +275,12 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
def save_load_local(self):
super().test_save_load_local(expected_max_difference=5e-3)
def test_callback_inputs(self):
pass
def test_callback_cfg(self):
pass
class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22InpaintCombinedPipeline
......@@ -384,3 +398,9 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
def test_callback_inputs(self):
pass
def test_callback_cfg(self):
pass
......@@ -192,6 +192,7 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["image_embeds"]
def get_dummy_components(self):
dummies = Dummies()
......
......@@ -194,6 +194,7 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["image_embeds", "masked_image", "mask_image"]
def get_dummy_components(self):
dummies = Dummies()
......@@ -252,6 +253,40 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
# override the default test because we need to zero out the mask too in order to make sure the final latent is all zeros
def test_callback_inputs(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_test(pipe, i, t, callback_kwargs):
missing_callback_inputs = set()
for v in pipe._callback_tensor_inputs:
if v not in callback_kwargs:
missing_callback_inputs.add(v)
self.assertTrue(
len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
)
last_i = pipe.num_timesteps - 1
if i == last_i:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
callback_kwargs["mask_image"] = torch.zeros_like(callback_kwargs["mask_image"])
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_test
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
@slow
@require_torch_gpu
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
......@@ -182,6 +183,7 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
"output_type",
"return_dict",
]
callback_cfg_params = ["prompt_embeds", "text_encoder_hidden_states", "text_mask"]
test_xformers_attention = False
def get_dummy_components(self):
......@@ -235,3 +237,42 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
# override the default test because this pipeline has no output_type "latent"; use "pt" instead
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
if not ("callback_on_step_end_tensor_inputs" in sig.parameters and "callback_on_step_end" in sig.parameters):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_test(pipe, i, t, callback_kwargs):
missing_callback_inputs = set()
for v in pipe._callback_tensor_inputs:
if v not in callback_kwargs:
missing_callback_inputs.add(v)
self.assertTrue(
len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
)
last_i = pipe.num_timesteps - 1
if i == last_i:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_test
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["num_inference_steps"] = 2
inputs["output_type"] = "pt"
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
import gc
import inspect
import unittest
import numpy as np
......@@ -142,6 +143,48 @@ class LatentConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, Pipelin
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-4)
# skip because the lcm pipeline applies cfg differently
def test_callback_cfg(self):
pass
# override default test because the final latent variable is "denoised" instead of "latents"
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
if not ("callback_on_step_end_tensor_inputs" in sig.parameters and "callback_on_step_end" in sig.parameters):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_test(pipe, i, t, callback_kwargs):
missing_callback_inputs = set()
for v in pipe._callback_tensor_inputs:
if v not in callback_kwargs:
missing_callback_inputs.add(v)
self.assertTrue(
len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
)
last_i = pipe.num_timesteps - 1
if i == last_i:
callback_kwargs["denoised"] = torch.zeros_like(callback_kwargs["denoised"])
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_test
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
@slow
@require_torch_gpu
......
import gc
import inspect
import random
import unittest
......@@ -155,6 +156,44 @@ class LatentConsistencyModelImg2ImgPipelineFastTests(
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-4)
# override default test because the final latent variable is "denoised" instead of "latents"
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
if not ("callback_on_step_end_tensor_inputs" in sig.parameters and "callback_on_step_end" in sig.parameters):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_test(pipe, i, t, callback_kwargs):
missing_callback_inputs = set()
for v in pipe._callback_tensor_inputs:
if v not in callback_kwargs:
missing_callback_inputs.add(v)
self.assertTrue(
len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
)
last_i = pipe.num_timesteps - 1
if i == last_i:
callback_kwargs["denoised"] = torch.zeros_like(callback_kwargs["denoised"])
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["callback_on_step_end"] = callback_inputs_test
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
assert output.abs().sum() == 0
@slow
@require_torch_gpu
......
......@@ -123,3 +123,5 @@ TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
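# `prompt_embeds` needs special handling when a callback disables classifier-free
# guidance: with cfg enabled, the negative and positive embeddings are concatenated
# along the batch dimension, so the tensor must be re-chunked to match.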
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
......@@ -52,7 +52,12 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -100,6 +105,7 @@ class StableDiffusionPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -52,6 +52,7 @@ from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -100,6 +101,7 @@ class StableDiffusionImg2ImgPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -50,7 +50,11 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -100,6 +104,7 @@ class StableDiffusionInpaintPipelineFastTests(
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -45,6 +45,7 @@ from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -60,6 +61,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"image_latents"}) - {"negative_prompt_embeds"}
def get_dummy_components(self):
torch.manual_seed(0)
......@@ -232,6 +234,34 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
# Override the default test_callback_cfg because pix2pix creates inputs for cfg differently
def test_callback_cfg(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def callback_no_cfg(pipe, i, t, callback_kwargs):
if i == 1:
for k in callback_kwargs:
if k in self.callback_cfg_params:
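# pix2pix batches each cfg tensor three ways (text-conditioned, image-conditioned,
# unconditional), so keep only the first chunk once guidance is turned off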
callback_kwargs[k] = callback_kwargs[k].chunk(3)[0]
pipe._guidance_scale = 1.0
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
inputs["guidance_scale"] = 1.0
inputs["num_inference_steps"] = 2
out_no_cfg = pipe(**inputs)[0]
inputs["guidance_scale"] = 7.5
inputs["callback_on_step_end"] = callback_no_cfg
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
out_callback_no_cfg = pipe(**inputs)[0]
assert out_no_cfg.shape == out_callback_no_cfg.shape
@slow
@require_torch_gpu
......
......@@ -43,7 +43,12 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -58,6 +63,7 @@ class StableDiffusion2PipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -56,6 +56,7 @@ from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -75,6 +76,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"})
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -33,7 +33,11 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -50,6 +54,7 @@ class StableDiffusion2InpaintPipelineFastTests(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -34,7 +34,12 @@ from diffusers import (
)
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
......@@ -49,6 +54,7 @@ class StableDiffusionXLPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -38,6 +38,7 @@ from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
......@@ -52,6 +53,9 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)
def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0)
......
......@@ -34,7 +34,11 @@ from diffusers import (
)
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -48,6 +52,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{
"add_text_embeds",
"add_time_ids",
"mask",
"masked_image_latents",
}
)
def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0)
......
......@@ -231,8 +231,6 @@ class PipelineTesterMixin:
"latents",
"output_type",
"return_dict",
"callback",
"callback_steps",
]
)
......@@ -294,6 +292,20 @@ class PipelineTesterMixin:
"See existing pipeline tests for reference."
)
@property
def callback_cfg_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `callback_cfg_params` in the child test class that requires to run test_callback_cfg. "
"`callback_cfg_params` are the parameters that needs to be passed to the pipeline's callback "
"function when dynamically adjusting `guidance_scale`. They are variables that require special"
"treatment when `do_classifier_free_guidance` is `True`. `pipeline_params.py` provides some common"
" sets of parameters such as `TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS`. If your pipeline's "
"set of cfg arguments has minor changes from one of the common sets of cfg arguments, "
"do not make modifications to the existing common sets of cfg arguments. I.e. for inpaint pipeine, you "
" need to adjust batch size of `mask` and `masked_image_latents` so should set the attribute as"
"`callback_cfg_params = TEXT_TO_IMAGE_CFG_PARAMS.union({'mask', 'masked_image_latents'})`"
)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
......@@ -861,6 +873,107 @@ class PipelineTesterMixin:
assert out_cfg.shape == out_no_cfg.shape
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
inputs["output_type"] = "latent"
output = pipe(**inputs)[0]
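# zeroing the latents on the last step must propagate to the returned latent
# output, proving the tensors handed back by the callback re-enter the loop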
assert output.abs().sum() == 0
def test_callback_cfg(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
if "guidance_scale" not in sig.parameters:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_increase_guidance(pipe, i, t, callback_kwargs):
pipe._guidance_scale += 1.0
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# use cfg guidance because some pipelines modify the shape of the latents
# outside of the denoising loop
inputs["guidance_scale"] = 2.0
inputs["callback_on_step_end"] = callback_increase_guidance
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
_ = pipe(**inputs)[0]
# we increase the guidance scale by 1.0 at every step
# check that the guidance scale is increased by the number of scheduler timesteps
# accounts for models that modify the number of inference steps based on strength
assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
......
......@@ -232,3 +232,9 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
@unittest.skip(reason="flakey and float16 requires CUDA")
def test_float16_inference(self):
super().test_float16_inference()
def test_callback_inputs(self):
pass
def test_callback_cfg(self):
pass