Unverified Commit fa633ed6 authored by elucida's avatar elucida Committed by GitHub
Browse files

refactor: move model helper function in pipeline to a mixin class (#6571)



* move model helper function in pipeline to EfficiencyMixin

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 2cad1a84
...@@ -675,6 +675,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject): ...@@ -675,6 +675,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class StableDiffusionMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AmusedScheduler(metaclass=DummyObject): class AmusedScheduler(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -18,7 +18,7 @@ from diffusers.utils import is_xformers_available, logging ...@@ -18,7 +18,7 @@ from diffusers.utils import is_xformers_available, logging
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin
def to_np(tensor): def to_np(tensor):
...@@ -28,7 +28,9 @@ def to_np(tensor): ...@@ -28,7 +28,9 @@ def to_np(tensor):
return tensor return tensor
class AnimateDiffPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase): class AnimateDiffPipelineFastTests(
IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = AnimateDiffPipeline pipeline_class = AnimateDiffPipeline
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
......
...@@ -46,14 +46,14 @@ from diffusers.utils.testing_utils import ( ...@@ -46,14 +46,14 @@ from diffusers.utils.testing_utils import (
torch_device, torch_device,
) )
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
enable_full_determinism() enable_full_determinism()
@skip_mps @skip_mps
class I2VGenXLPipelineFastTests(PipelineTesterMixin, unittest.TestCase): class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_class = I2VGenXLPipeline pipeline_class = I2VGenXLPipeline
params = frozenset(["prompt", "negative_prompt", "image"]) params = frozenset(["prompt", "negative_prompt", "image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "generator"]) batch_params = frozenset(["prompt", "negative_prompt", "image", "generator"])
......
...@@ -52,14 +52,23 @@ from ..pipeline_params import ( ...@@ -52,14 +52,23 @@ from ..pipeline_params import (
TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS,
) )
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
)
enable_full_determinism() enable_full_determinism()
class StableDiffusion2PipelineFastTests( class StableDiffusion2PipelineFastTests(
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase SDFunctionTesterMixin,
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
unittest.TestCase,
): ):
pipeline_class = StableDiffusionPipeline pipeline_class = StableDiffusionPipeline
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
......
...@@ -53,6 +53,7 @@ from ..test_pipelines_common import ( ...@@ -53,6 +53,7 @@ from ..test_pipelines_common import (
IPAdapterTesterMixin, IPAdapterTesterMixin,
PipelineLatentTesterMixin, PipelineLatentTesterMixin,
PipelineTesterMixin, PipelineTesterMixin,
SDFunctionTesterMixin,
SDXLOptionalComponentsTesterMixin, SDXLOptionalComponentsTesterMixin,
) )
...@@ -61,6 +62,7 @@ enable_full_determinism() ...@@ -61,6 +62,7 @@ enable_full_determinism()
class StableDiffusionXLPipelineFastTests( class StableDiffusionXLPipelineFastTests(
SDFunctionTesterMixin,
IPAdapterTesterMixin, IPAdapterTesterMixin,
PipelineLatentTesterMixin, PipelineLatentTesterMixin,
PipelineTesterMixin, PipelineTesterMixin,
...@@ -948,37 +950,6 @@ class StableDiffusionXLPipelineFastTests( ...@@ -948,37 +950,6 @@ class StableDiffusionXLPipelineFastTests(
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_with_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
sd_pipe.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
sd_pipe.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pipeline_interrupt(self): def test_pipeline_interrupt(self):
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components) sd_pipe = StableDiffusionXLPipeline(**components)
......
...@@ -30,6 +30,10 @@ from diffusers import ( ...@@ -30,6 +30,10 @@ from diffusers import (
) )
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin from diffusers.loaders import IPAdapterMixin
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
...@@ -61,6 +65,149 @@ def check_same_shape(tensor_list): ...@@ -61,6 +65,149 @@ def check_same_shape(tensor_list):
return all(shape == shapes[0] for shape in shapes[1:]) return all(shape == shapes[0] for shape in shapes[1:])
class SDFunctionTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
It provides a set of common tests for PyTorch pipeline that inherit from StableDiffusionMixin, e.g. vae_slicing, vae_tiling, freeu, etc.
"""
def test_vae_slicing(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
image_count = 4
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * image_count
if "image" in inputs: # fix batch size mismatch in I2V_Gen pipeline
inputs["image"] = [inputs["image"]] * image_count
output_1 = pipe(**inputs)
# make sure sliced vae decode yields the same result
pipe.enable_vae_slicing()
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * image_count
if "image" in inputs:
inputs["image"] = [inputs["image"]] * image_count
inputs["return_dict"] = False
output_2 = pipe(**inputs)
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# make sure here that pndm scheduler skips prk
if "safety_checker" in components:
components["safety_checker"] = None
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
# Test that tiled decode at 512x512 yields the same result as the non-tiled decode
output_1 = pipe(**inputs)[0]
# make sure tiled vae decode yields the same result
pipe.enable_vae_tiling()
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
output_2 = pipe(**inputs)[0]
assert np.abs(output_2 - output_1).max() < 5e-1
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
output = pipe(**inputs)[0]
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
output_freeu = pipe(**inputs)[0]
assert not np.allclose(
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
), "Enabling of FreeU should lead to different results."
def test_freeu_disabled(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
output = pipe(**inputs)[0]
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
pipe.disable_freeu()
freeu_keys = {"s1", "s2", "b1", "b2"}
for upsample_block in pipe.unet.up_blocks:
for key in freeu_keys:
assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
output_no_freeu = pipe(**inputs)[0]
assert np.allclose(
output, output_no_freeu, atol=1e-2
), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image = pipe(**inputs)[0]
original_image_slice = image[0, -3:, -3:, -1]
pipe.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_fused = pipe(**inputs)[0]
image_slice_fused = image_fused[0, -3:, -3:, -1]
pipe.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
class IPAdapterTesterMixin: class IPAdapterTesterMixin:
""" """
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
...@@ -1137,6 +1284,18 @@ class PipelineTesterMixin: ...@@ -1137,6 +1284,18 @@ class PipelineTesterMixin:
# accounts for models that modify the number of inference steps based on strength # accounts for models that modify the number of inference steps based on strength
assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)
def test_StableDiffusionMixin_component(self):
"""Any pipeline that have LDMFuncMixin should have vae and unet components."""
if not issubclass(self.pipeline_class, StableDiffusionMixin):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny)))
self.assertTrue(
hasattr(pipe, "unet")
and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel))
)
@is_staging_test @is_staging_test
class PipelinePushToHubTester(unittest.TestCase): class PipelinePushToHubTester(unittest.TestCase):
......
...@@ -37,14 +37,14 @@ from diffusers.utils.testing_utils import ( ...@@ -37,14 +37,14 @@ from diffusers.utils.testing_utils import (
) )
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
enable_full_determinism() enable_full_determinism()
@skip_mps @skip_mps
class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin, unittest.TestCase):
pipeline_class = TextToVideoSDPipeline pipeline_class = TextToVideoSDPipeline
params = TEXT_TO_IMAGE_PARAMS params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment