Unverified Commit b2ca39c8, authored by Sayak Paul and committed by GitHub

[tests] test `encode_prompt()` in isolation (#10438)

* poc encode_prompt() tests

* fix

* updates.

* fixes

* fixes

* updates

* updates

* updates

* revert

* updates

* updates

* updates

* updates

* remove SDXLOptionalComponentsTesterMixin.

* remove tests that directly leveraged encode_prompt() in some way or the other.

* fix imports.

* remove _save_load

* fixes

* fixes

* fixes

* fixes
parent 53217126
@@ -268,7 +268,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
else:
batch_size = prompt_embeds.shape[0]
self.tokenizer.padding_side = "right"
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
# See Section 3.1. of the paper.
max_length = max_sequence_length
......
@@ -312,7 +312,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
else:
batch_size = prompt_embeds.shape[0]
self.tokenizer.padding_side = "right"
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
# See Section 3.1. of the paper.
max_length = max_sequence_length
......
import ast
import importlib
import inspect
import textwrap
class ReturnNameVisitor(ast.NodeVisitor):
"""Thanks to ChatGPT for pairing."""
def __init__(self):
self.return_names = []
def visit_Return(self, node):
# Check if the return value is a tuple.
if isinstance(node.value, ast.Tuple):
for elt in node.value.elts:
if isinstance(elt, ast.Name):
self.return_names.append(elt.id)
else:
try:
self.return_names.append(ast.unparse(elt))
except Exception:
self.return_names.append(str(elt))
else:
if isinstance(node.value, ast.Name):
self.return_names.append(node.value.id)
else:
try:
self.return_names.append(ast.unparse(node.value))
except Exception:
self.return_names.append(str(node.value))
self.generic_visit(node)
def _determine_parent_module(self, cls):
from diffusers import DiffusionPipeline
from diffusers.models.modeling_utils import ModelMixin
if issubclass(cls, DiffusionPipeline):
return "pipelines"
elif issubclass(cls, ModelMixin):
return "models"
else:
raise NotImplementedError
def get_ast_tree(self, cls, attribute_name="encode_prompt"):
parent_module_name = self._determine_parent_module(cls)
main_module = importlib.import_module(f"diffusers.{parent_module_name}")
current_cls_module = getattr(main_module, cls.__name__)
source_code = inspect.getsource(getattr(current_cls_module, attribute_name))
source_code = textwrap.dedent(source_code)
tree = ast.parse(source_code)
return tree
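A brief, hypothetical usage sketch of the helper above (`StableDiffusionPipeline` is only an example target): it parses the source of a pipeline's `encode_prompt()` and records the returned names, which can then be mapped to the pipeline's embedding kwargs.

from diffusers import StableDiffusionPipeline

# Hypothetical usage of ReturnNameVisitor; the pipeline class is just an example.
visitor = ReturnNameVisitor()
tree = visitor.get_ast_tree(StableDiffusionPipeline)  # attribute_name defaults to "encode_prompt"
visitor.visit(tree)
print(visitor.return_names)  # e.g. ["prompt_embeds", "negative_prompt_embeds"]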
@@ -548,6 +548,14 @@ class AnimateDiffPipelineFastTests(
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"num_images_per_prompt": 1,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@slow
@require_torch_accelerator
......
@@ -517,3 +517,11 @@ class AnimateDiffControlNetPipelineFastTests(
output_2 = pipe(**inputs)
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"num_images_per_prompt": 1,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -21,7 +21,6 @@ from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
SDXLOptionalComponentsTesterMixin,
)
@@ -36,7 +35,6 @@ class AnimateDiffPipelineSDXLFastTests(
IPAdapterTesterMixin,
SDFunctionTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = AnimateDiffSDXLPipeline
@@ -250,33 +248,6 @@ class AnimateDiffPipelineSDXLFastTests(
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_prompt_embeds(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt)
pipe(
**inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -305,3 +276,11 @@ class AnimateDiffPipelineSDXLFastTests(
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
@unittest.skip("Test currently not supported.")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip("Functionality is tested elsewhere.")
def test_save_load_optional_components(self):
pass
@@ -484,3 +484,11 @@ class AnimateDiffSparseControlNetPipelineFastTests(
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"num_images_per_prompt": 1,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -544,3 +544,11 @@ class AnimateDiffVideoToVideoPipelineFastTests(
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"num_images_per_prompt": 1,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -533,3 +533,11 @@ class AnimateDiffVideoToVideoControlNetPipelineFastTests(
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"num_images_per_prompt": 1,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -508,9 +508,14 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
@unittest.skip("Test not supported.")
def test_sequential_cpu_offload_forward_pass(self):
pass
@unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
def test_encode_prompt_works_in_isolation(self):
pass
@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
......
@@ -5,9 +5,6 @@ import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -90,37 +87,6 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
}
return inputs
def test_aura_flow_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
do_classifier_free_guidance = inputs["guidance_scale"] > 1
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(
prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
device=torch_device,
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_attention_slicing_forward_pass(self):
# Attention slicing needs to implemented differently for this because how single DiT and MMDiT
# blocks interfere with each other.
......
@@ -198,3 +198,7 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self):
pass
@@ -232,6 +232,9 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"Attention slicing should not affect the inference results",
)
def test_encode_prompt_works_in_isolation(self):
return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
@slow
@require_torch_accelerator
......
@@ -288,6 +288,13 @@ class ControlNetPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
class StableDiffusionMultiControlNetPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -522,6 +529,13 @@ class StableDiffusionMultiControlNetPipelineFastTests(
assert image.shape == (4, 64, 64, 3)
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -707,6 +721,13 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
except NotImplementedError:
pass
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@slow
@require_torch_accelerator
......
@@ -222,3 +222,7 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
@unittest.skip("Test not supported because of complexities in deriving query_embeds.")
def test_encode_prompt_works_in_isolation(self):
pass
@@ -189,6 +189,13 @@ class ControlNetImg2ImgPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
class StableDiffusionMultiControlNetPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -391,6 +398,13 @@ class StableDiffusionMultiControlNetPipelineFastTests(
except NotImplementedError:
pass
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@slow
@require_torch_accelerator
......
@@ -176,6 +176,13 @@ class ControlNetInpaintPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests):
pipeline_class = StableDiffusionControlNetInpaintPipeline
@@ -443,6 +450,13 @@ class MultiControlNetInpaintPipelineFastTests(
except NotImplementedError:
pass
def test_encode_prompt_works_in_isolation(self):
extra_required_param_value_dict = {
"device": torch.device(torch_device).type,
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@slow
@require_torch_accelerator
......
@@ -55,7 +55,6 @@ from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
)
@@ -67,7 +66,6 @@ class StableDiffusionXLControlNetPipelineFastTests(
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLControlNetPipeline
@@ -212,8 +210,9 @@ class StableDiffusionXLControlNetPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
@unittest.skip("We test this functionality elsewhere already.")
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
pass
@require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
@@ -297,45 +296,6 @@ class StableDiffusionXLControlNetPipelineFastTests(
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# Copied from test_stable_diffusion_xl.py
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
# forward without prompt embeds
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 2 * [inputs["prompt"]]
inputs["num_images_per_prompt"] = 2
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with prompt embeds
inputs = self.get_dummy_inputs(torch_device)
prompt = 2 * [inputs.pop("prompt")]
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = sd_pipe.encode_prompt(prompt)
output = sd_pipe(
**inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
image_slice_2 = output.images[0, -3:, -3:, -1]
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_controlnet_sdxl_guess(self):
device = "cpu"
@@ -483,7 +443,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
class StableDiffusionXLMultiControlNetPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
@@ -685,12 +645,13 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
@unittest.skip("We test this functionality elsewhere already.")
def test_save_load_optional_components(self):
return self._test_save_load_optional_components()
pass
class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
@@ -862,6 +823,10 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
@unittest.skip("We test this functionality elsewhere already.")
def test_save_load_optional_components(self):
pass
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -872,9 +837,6 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
def test_negative_conditions(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
......
@@ -327,42 +327,3 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# Copied from test_stable_diffusion_xl.py
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
# forward without prompt embeds
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 2 * [inputs["prompt"]]
inputs["num_images_per_prompt"] = 2
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with prompt embeds
inputs = self.get_dummy_inputs(torch_device)
prompt = 2 * [inputs.pop("prompt")]
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = sd_pipe.encode_prompt(prompt)
output = sd_pipe(
**inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
)
image_slice_2 = output.images[0, -3:, -3:, -1]
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
@@ -178,6 +178,12 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
# TODO(YiYi) need to fix later
pass
@unittest.skip(
"Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
)
def test_encode_prompt_works_in_isolation(self):
pass
@slow
@require_torch_accelerator
......