Unverified commit b2ca39c8 authored by Sayak Paul, committed by GitHub

[tests] test `encode_prompt()` in isolation (#10438)

* poc encode_prompt() tests

* fix

* updates.

* fixes

* fixes

* updates

* updates

* updates

* revert

* updates

* updates

* updates

* updates

* remove SDXLOptionalComponentsTesterMixin.

* remove tests that directly leveraged encode_prompt() in some way or the other.

* fix imports.

* remove _save_load

* fixes

* fixes

* fixes

* fixes
parent 53217126
@@ -268,7 +268,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
         else:
             batch_size = prompt_embeds.shape[0]

-        self.tokenizer.padding_side = "right"
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"

         # See Section 3.1. of the paper.
         max_length = max_sequence_length
...
@@ -312,7 +312,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         else:
             batch_size = prompt_embeds.shape[0]

-        self.tokenizer.padding_side = "right"
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"

         # See Section 3.1. of the paper.
         max_length = max_sequence_length
...
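Both Sana hunks apply the same guard: `encode_prompt()` must now tolerate a pipeline whose tokenizer has been set to `None`, which is exactly what the isolation test does when it feeds precomputed embeddings. A toy, runnable illustration of the failure mode the guard avoids (not the actual pipeline code):

class _TokenizerlessDemo:
    tokenizer = None  # mimics a pipeline built with tokenizer=None

    def encode(self):
        # Unguarded, `self.tokenizer.padding_side = "right"` would raise
        # AttributeError because NoneType has no attribute `padding_side`;
        # the guard simply skips tokenizer configuration when it is absent.
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"


_TokenizerlessDemo().encode()  # runs cleanly without a tokenizer

The PR also adds a small AST helper to the shared test utilities, shown next: it walks a pipeline's `encode_prompt()` source and records the names of the values it returns.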
import ast
import importlib
import inspect
import textwrap


class ReturnNameVisitor(ast.NodeVisitor):
    """Thanks to ChatGPT for pairing."""

    def __init__(self):
        self.return_names = []

    def visit_Return(self, node):
        # Check if the return value is a tuple.
        if isinstance(node.value, ast.Tuple):
            for elt in node.value.elts:
                if isinstance(elt, ast.Name):
                    self.return_names.append(elt.id)
                else:
                    try:
                        self.return_names.append(ast.unparse(elt))
                    except Exception:
                        self.return_names.append(str(elt))
        else:
            if isinstance(node.value, ast.Name):
                self.return_names.append(node.value.id)
            else:
                try:
                    self.return_names.append(ast.unparse(node.value))
                except Exception:
                    self.return_names.append(str(node.value))
        self.generic_visit(node)

    def _determine_parent_module(self, cls):
        from diffusers import DiffusionPipeline
        from diffusers.models.modeling_utils import ModelMixin

        if issubclass(cls, DiffusionPipeline):
            return "pipelines"
        elif issubclass(cls, ModelMixin):
            return "models"
        else:
            raise NotImplementedError

    def get_ast_tree(self, cls, attribute_name="encode_prompt"):
        parent_module_name = self._determine_parent_module(cls)
        main_module = importlib.import_module(f"diffusers.{parent_module_name}")
        current_cls_module = getattr(main_module, cls.__name__)
        source_code = inspect.getsource(getattr(current_cls_module, attribute_name))
        source_code = textwrap.dedent(source_code)
        tree = ast.parse(source_code)
        return tree
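A usage sketch for the helper above (the pipeline choice is illustrative; assumes `diffusers` is importable). This is presumably how the harness discovers which outputs of `encode_prompt()` map onto pipeline call arguments:

from diffusers import StableDiffusionPipeline

visitor = ReturnNameVisitor()
tree = visitor.get_ast_tree(StableDiffusionPipeline)
visitor.visit(tree)
# For StableDiffusionPipeline this is expected to yield
# ["prompt_embeds", "negative_prompt_embeds"].
print(visitor.return_names)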
@@ -548,6 +548,14 @@ class AnimateDiffPipelineFastTests(
     def test_vae_slicing(self):
         return super().test_vae_slicing(image_count=2)

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 @slow
 @require_torch_accelerator
...
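This override pattern repeats across the test files below because `encode_prompt()` signatures differ between pipelines: the shared test presumably supplies only the prompt, so each test class must provide values for the remaining required (default-less) parameters. A quick, illustrative way to see which parameters a given pipeline needs (assumes `diffusers` is installed):

import inspect

from diffusers import AnimateDiffPipeline

sig = inspect.signature(AnimateDiffPipeline.encode_prompt)
required = [
    name
    for name, param in sig.parameters.items()
    if name != "self" and param.default is inspect.Parameter.empty
]
# Expected to include "prompt", "device", "num_images_per_prompt",
# and "do_classifier_free_guidance" for AnimateDiff-style pipelines.
print(required)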
@@ -517,3 +517,11 @@ class AnimateDiffControlNetPipelineFastTests(
         output_2 = pipe(**inputs)

         assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -21,7 +21,6 @@ from ..test_pipelines_common import (
     IPAdapterTesterMixin,
     PipelineTesterMixin,
     SDFunctionTesterMixin,
-    SDXLOptionalComponentsTesterMixin,
 )
@@ -36,7 +35,6 @@ class AnimateDiffPipelineSDXLFastTests(
     IPAdapterTesterMixin,
     SDFunctionTesterMixin,
     PipelineTesterMixin,
-    SDXLOptionalComponentsTesterMixin,
     unittest.TestCase,
 ):
     pipeline_class = AnimateDiffSDXLPipeline
@@ -250,33 +248,6 @@ class AnimateDiffPipelineSDXLFastTests(
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

-    def test_prompt_embeds(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to(torch_device)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        prompt = inputs.pop("prompt")
-
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = pipe.encode_prompt(prompt)
-
-        pipe(
-            **inputs,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-        )
-
-    def test_save_load_optional_components(self):
-        self._test_save_load_optional_components()
-
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -305,3 +276,11 @@ class AnimateDiffPipelineSDXLFastTests(
         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
+    @unittest.skip("Test currently not supported.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
+
+    @unittest.skip("Functionality is tested elsewhere.")
+    def test_save_load_optional_components(self):
+        pass
@@ -484,3 +484,11 @@ class AnimateDiffSparseControlNetPipelineFastTests(
     def test_vae_slicing(self):
         return super().test_vae_slicing(image_count=2)
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -544,3 +544,11 @@ class AnimateDiffVideoToVideoPipelineFastTests(
         inputs["strength"] = 0.5
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
         pipe(**inputs).frames[0]
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -533,3 +533,11 @@ class AnimateDiffVideoToVideoControlNetPipelineFastTests(
         inputs["strength"] = 0.5
         inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
         pipe(**inputs).frames[0]
+
+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "num_images_per_prompt": 1,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
@@ -508,9 +508,14 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

+    @unittest.skip("Test not supported.")
     def test_sequential_cpu_offload_forward_pass(self):
         pass

+    @unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
+

 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
...
@@ -5,9 +5,6 @@ import torch
 from transformers import AutoTokenizer, UMT5EncoderModel

 from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import (
-    torch_device,
-)

 from ..test_pipelines_common import (
     PipelineTesterMixin,
@@ -90,37 +87,6 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
         }
         return inputs

-    def test_aura_flow_prompt_embeds(self):
-        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-        inputs = self.get_dummy_inputs(torch_device)
-
-        output_with_prompt = pipe(**inputs).images[0]
-
-        inputs = self.get_dummy_inputs(torch_device)
-        prompt = inputs.pop("prompt")
-
-        do_classifier_free_guidance = inputs["guidance_scale"] > 1
-        (
-            prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
-        ) = pipe.encode_prompt(
-            prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            device=torch_device,
-        )
-
-        output_with_embeds = pipe(
-            prompt_embeds=prompt_embeds,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
-            **inputs,
-        ).images[0]
-
-        max_diff = np.abs(output_with_prompt - output_with_embeds).max()
-        assert max_diff < 1e-4
-
     def test_attention_slicing_forward_pass(self):
         # Attention slicing needs to implemented differently for this because how single DiT and MMDiT
         # blocks interfere with each other.
...
@@ -198,3 +198,7 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert (
             np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
+
+    @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
@@ -232,6 +232,9 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "Attention slicing should not affect the inference results",
         )

+    def test_encode_prompt_works_in_isolation(self):
+        return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
+

 @slow
 @require_torch_accelerator
...
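CogView3Plus only loosens the comparison tolerances rather than skipping. For reference, the numpy-style closeness test those `atol`/`rtol` values presumably feed into is `|a - b| <= atol + rtol * |b|`, so `1e-3`/`1e-3` tolerates roughly 0.1% relative drift (illustrative check only):

import numpy as np

a, b = np.float32(1.0005), np.float32(1.0)
print(np.isclose(a, b, atol=1e-3, rtol=1e-3))  # True: within the relaxed bounds
print(np.isclose(a, b, atol=1e-4, rtol=1e-4))  # False: too tight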
@@ -288,6 +288,13 @@ class ControlNetPipelineFastTests(
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 class StableDiffusionMultiControlNetPipelineFastTests(
     IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -522,6 +529,13 @@ class StableDiffusionMultiControlNetPipelineFastTests(
         assert image.shape == (4, 64, 64, 3)

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -707,6 +721,13 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
         except NotImplementedError:
             pass

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 @slow
 @require_torch_accelerator
...
@@ -222,3 +222,7 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert (
             np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+
+    @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
+    def test_encode_prompt_works_in_isolation(self):
+        pass
@@ -189,6 +189,13 @@ class ControlNetImg2ImgPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 class StableDiffusionMultiControlNetPipelineFastTests(
     IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -391,6 +398,13 @@ class StableDiffusionMultiControlNetPipelineFastTests(
         except NotImplementedError:
             pass

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 @slow
 @require_torch_accelerator
...
@@ -176,6 +176,13 @@ class ControlNetInpaintPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests):
     pipeline_class = StableDiffusionControlNetInpaintPipeline
@@ -443,6 +450,13 @@ class MultiControlNetInpaintPipelineFastTests(
         except NotImplementedError:
             pass

+    def test_encode_prompt_works_in_isolation(self):
+        extra_required_param_value_dict = {
+            "device": torch.device(torch_device).type,
+            "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
+        }
+        return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+

 @slow
 @require_torch_accelerator
...
@@ -55,7 +55,6 @@ from ..test_pipelines_common import (
     PipelineKarrasSchedulerTesterMixin,
     PipelineLatentTesterMixin,
     PipelineTesterMixin,
-    SDXLOptionalComponentsTesterMixin,
 )
@@ -67,7 +66,6 @@ class StableDiffusionXLControlNetPipelineFastTests(
     PipelineLatentTesterMixin,
     PipelineKarrasSchedulerTesterMixin,
     PipelineTesterMixin,
-    SDXLOptionalComponentsTesterMixin,
     unittest.TestCase,
 ):
     pipeline_class = StableDiffusionXLControlNetPipeline
@@ -212,8 +210,9 @@ class StableDiffusionXLControlNetPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    @unittest.skip("We test this functionality elsewhere already.")
     def test_save_load_optional_components(self):
-        self._test_save_load_optional_components()
+        pass

     @require_torch_accelerator
     def test_stable_diffusion_xl_offloads(self):
@@ -297,45 +296,6 @@ class StableDiffusionXLControlNetPipelineFastTests(
         # ensure the results are not equal
         assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

-    # Copied from test_stable_diffusion_xl.py
-    def test_stable_diffusion_xl_prompt_embeds(self):
-        components = self.get_dummy_components()
-        sd_pipe = self.pipeline_class(**components)
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # forward without prompt embeds
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["prompt"] = 2 * [inputs["prompt"]]
-        inputs["num_images_per_prompt"] = 2
-
-        output = sd_pipe(**inputs)
-        image_slice_1 = output.images[0, -3:, -3:, -1]
-
-        # forward with prompt embeds
-        inputs = self.get_dummy_inputs(torch_device)
-        prompt = 2 * [inputs.pop("prompt")]
-
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = sd_pipe.encode_prompt(prompt)
-
-        output = sd_pipe(
-            **inputs,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-        )
-        image_slice_2 = output.images[0, -3:, -3:, -1]
-
-        # make sure that it's equal
-        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
     def test_controlnet_sdxl_guess(self):
         device = "cpu"
@@ -483,7 +443,7 @@ class StableDiffusionXLControlNetPipelineFastTests(

 class StableDiffusionXLMultiControlNetPipelineFastTests(
-    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionXLControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -685,12 +645,13 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    @unittest.skip("We test this functionality elsewhere already.")
     def test_save_load_optional_components(self):
-        return self._test_save_load_optional_components()
+        pass

 class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
-    PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+    PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionXLControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -862,6 +823,10 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)

+    @unittest.skip("We test this functionality elsewhere already.")
+    def test_save_load_optional_components(self):
+        pass
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -872,9 +837,6 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

-    def test_save_load_optional_components(self):
-        self._test_save_load_optional_components()
-
     def test_negative_conditions(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
...
@@ -327,42 +327,3 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
         # ensure the results are not equal
         assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

-    # Copied from test_stable_diffusion_xl.py
-    def test_stable_diffusion_xl_prompt_embeds(self):
-        components = self.get_dummy_components()
-        sd_pipe = self.pipeline_class(**components)
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # forward without prompt embeds
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["prompt"] = 2 * [inputs["prompt"]]
-        inputs["num_images_per_prompt"] = 2
-
-        output = sd_pipe(**inputs)
-        image_slice_1 = output.images[0, -3:, -3:, -1]
-
-        # forward with prompt embeds
-        inputs = self.get_dummy_inputs(torch_device)
-        prompt = 2 * [inputs.pop("prompt")]
-
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = sd_pipe.encode_prompt(prompt)
-
-        output = sd_pipe(
-            **inputs,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-        )
-        image_slice_2 = output.images[0, -3:, -3:, -1]
-
-        # make sure that it's equal
-        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
@@ -178,6 +178,12 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
         # TODO(YiYi) need to fix later
         pass

+    @unittest.skip(
+        "Test not supported as `encode_prompt` is called two times separately, which deviates from about 99% of the pipelines we have."
+    )
+    def test_encode_prompt_works_in_isolation(self):
+        pass
+

 @slow
 @require_torch_accelerator
...