Unverified commit bb1b76d3, authored by Aryan, committed by GitHub

IPAdapterTesterMixin (#6862)



* begin IPAdapterTesterMixin



---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent e4b8f173
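
For orientation before the diff: most of this commit is mechanical, with each pipeline's fast-test class adding the new IPAdapterTesterMixin to its bases so it inherits the IP-Adapter checks defined in test_pipelines_common.py. A minimal, hypothetical sketch of that pattern (MyPipeline and MyPipelineFastTests are placeholders, not names from this commit):

import unittest

from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin


class MyPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # The pipeline under test must subclass diffusers.loaders.IPAdapterMixin and accept
    # `ip_adapter_image` / `ip_adapter_image_embeds` in __call__, which is what
    # IPAdapterTesterMixin.test_pipeline_signature verifies.
    pipeline_class = MyPipeline  # placeholder for the concrete pipeline class

    def get_dummy_components(self):
        ...  # tiny models, as the existing fast tests already provide

    def get_dummy_inputs(self, device, seed=0):
        ...  # the pipeline's usual dummy call kwargs
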
@@ -27,14 +27,14 @@ from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin

 enable_full_determinism()


 class LatentConsistencyModelImg2ImgPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
 ):
     pipeline_class = LatentConsistencyModelImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "negative_prompt", "negative_prompt_embeds"}
...
@@ -17,7 +17,7 @@ from diffusers import (
 from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import floats_tensor, torch_device
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin


 def to_np(tensor):
@@ -27,7 +27,7 @@ def to_np(tensor):
     return tensor


-class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = PIAPipeline
     params = frozenset(
         [
...
@@ -23,7 +23,11 @@ import unittest
 import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import (
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTokenizer,
+)

 from diffusers import (
     AutoencoderKL,
@@ -60,7 +64,12 @@ from ..pipeline_params import (
     TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)

 enable_full_determinism()

@@ -100,7 +109,11 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
 class StableDiffusionPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -177,7 +190,7 @@ class StableDiffusionPipelineFastTests(
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
         }
         return inputs
...
@@ -55,7 +55,12 @@ from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)

 enable_full_determinism()

@@ -94,7 +99,11 @@ def _test_img2img_compile(in_queue, out_queue, timeout):
 class StableDiffusionImg2ImgPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
...
@@ -57,7 +57,12 @@ from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)

 enable_full_determinism()

@@ -98,7 +103,11 @@ def _test_inpaint_compile(in_queue, out_queue, timeout):
 class StableDiffusionInpaintPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionInpaintPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
...
@@ -47,7 +47,11 @@ from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)

 enable_full_determinism()
...
@@ -49,14 +49,23 @@ from ..pipeline_params import (
     TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    SDXLOptionalComponentsTesterMixin,
+)

 enable_full_determinism()


 class StableDiffusionXLPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    SDXLOptionalComponentsTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionXLPipeline
     params = TEXT_TO_IMAGE_PARAMS
...
@@ -44,6 +44,7 @@ from diffusers.utils.testing_utils import (
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
 from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
     PipelineTesterMixin,
     SDXLOptionalComponentsTesterMixin,
     assert_mean_pixel_difference,
@@ -54,7 +55,7 @@ enable_full_determinism()
 class StableDiffusionXLAdapterPipelineFastTests(
-    PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionXLAdapterPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
...
@@ -54,13 +54,20 @@ from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    SDXLOptionalComponentsTesterMixin,
+)

 enable_full_determinism()


-class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionXLImg2ImgPipelineFastTests(
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionXLImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
...
@@ -48,13 +48,15 @@ from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin

 enable_full_determinism()


-class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionXLInpaintPipelineFastTests(
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionXLInpaintPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
...
@@ -8,7 +8,7 @@ import re
 import tempfile
 import unittest
 import uuid
-from typing import Callable, Union
+from typing import Any, Callable, Dict, Union

 import numpy as np
 import PIL.Image
@@ -29,6 +29,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import IPAdapterMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
@@ -44,6 +45,7 @@ from ..models.autoencoders.test_models_vae import (
     get_autoencoder_tiny_config,
     get_consistency_vae_config,
 )
+from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -59,6 +61,118 @@ def check_same_shape(tensor_list):
     return all(shape == shapes[0] for shape in shapes[1:])
+class IPAdapterTesterMixin:
+    """
+    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+    It provides a set of common tests for pipelines that support IP Adapters.
+    """
+
+    def test_pipeline_signature(self):
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+
+        assert issubclass(self.pipeline_class, IPAdapterMixin)
+        self.assertIn(
+            "ip_adapter_image",
+            parameters,
+            "`ip_adapter_image` argument must be supported by the `__call__` method",
+        )
+        self.assertIn(
+            "ip_adapter_image_embeds",
+            parameters,
+            "`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
+        )
+
+    def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
+        return torch.randn((2, 1, cross_attention_dim), device=torch_device)
+
+    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+        if "image" in parameters.keys() and "strength" in parameters.keys():
+            inputs["num_inference_steps"] = 4
+
+        inputs["output_type"] = "np"
+        inputs["return_dict"] = False
+        return inputs
+
+    def test_ip_adapter_single(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+
+        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+        # forward pass with single ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(0.0)
+        output_without_adapter_scale = pipe(**inputs)[0]
+
+        # forward pass with single ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(42.0)
+        output_with_adapter_scale = pipe(**inputs)[0]
+
+        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_adapter_scale,
+            expected_max_diff,
+            "Output without ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
+        )
+
+    def test_ip_adapter_multi(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+
+        adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
+        adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
+
+        # forward pass with multi ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+        pipe.set_ip_adapter_scale([0.0, 0.0])
+        output_without_multi_adapter_scale = pipe(**inputs)[0]
+
+        # forward pass with multi ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+        pipe.set_ip_adapter_scale([42.0, 42.0])
+        output_with_multi_adapter_scale = pipe(**inputs)[0]
+
+        max_diff_without_multi_adapter_scale = np.abs(
+            output_without_multi_adapter_scale - output_without_adapter
+        ).max()
+        max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_multi_adapter_scale,
+            expected_max_diff,
+            "Output without multi-ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_multi_adapter_scale,
+            1e-2,
+            "Output with multi-ip-adapter scale must be different from normal inference",
+        )
+
+
 class PipelineLatentTesterMixin:
     """
     This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
...
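
A closing note on why the scale-0 versus scale-42 checks in the mixin are sound: the IP-Adapter attention processors in diffusers blend the image-conditioned attention output into the text-conditioned one as an additive term weighted by the adapter scale, so a scale of 0.0 should reproduce the baseline output exactly while a large scale should move it. A simplified, self-contained sketch of that blending (not the actual attention-processor code):

import torch


def blend_ip_adapter(text_attn_out: torch.Tensor, ip_attn_out: torch.Tensor, scale: float) -> torch.Tensor:
    # The IP-Adapter contribution enters as a scale-weighted additive term.
    return text_attn_out + scale * ip_attn_out


text_out = torch.randn(2, 64, 32)
ip_out = torch.randn(2, 64, 32)

assert torch.equal(blend_ip_adapter(text_out, ip_out, 0.0), text_out)       # scale 0 -> matches baseline
assert not torch.equal(blend_ip_adapter(text_out, ip_out, 42.0), text_out)  # large scale -> output differs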