Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinsky

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffusion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistaken import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
...@@ -31,8 +31,14 @@ from diffusers import ( ...@@ -31,8 +31,14 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.models.attention_processor import AttnProcessor from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils.testing_utils import (
from diffusers.utils.testing_utils import enable_full_determinism, numpy_cosine_similarity_distance, require_torch_gpu enable_full_determinism,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
enable_full_determinism() enable_full_determinism()
......
...@@ -24,8 +24,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer ...@@ -24,8 +24,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
from diffusers.utils import floats_tensor, nightly, torch_device from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
class SafeDiffusionPipelineFastTests(unittest.TestCase): class SafeDiffusionPipelineFastTests(unittest.TestCase):
......
...@@ -32,8 +32,7 @@ from diffusers import ( ...@@ -32,8 +32,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
UniPCMultistepScheduler, UniPCMultistepScheduler,
) )
from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......
...@@ -27,8 +27,7 @@ from diffusers import ( ...@@ -27,8 +27,7 @@ from diffusers import (
T2IAdapter, T2IAdapter,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import floats_tensor from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
......
...@@ -26,8 +26,7 @@ from diffusers import ( ...@@ -26,8 +26,7 @@ from diffusers import (
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.utils import floats_tensor, torch_device from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import ( from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS, IMAGE_TO_IMAGE_IMAGE_PARAMS,
......
...@@ -32,8 +32,7 @@ from diffusers import ( ...@@ -32,8 +32,7 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
UniPCMultistepScheduler, UniPCMultistepScheduler,
) )
from diffusers.utils import floats_tensor, torch_device from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......
...@@ -29,8 +29,7 @@ from diffusers.image_processor import VaeImageProcessor ...@@ -29,8 +29,7 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import ( from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
StableDiffusionXLInstructPix2PixPipeline, StableDiffusionXLInstructPix2PixPipeline,
) )
from diffusers.utils import floats_tensor, torch_device from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import ( from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS, IMAGE_TO_IMAGE_IMAGE_PARAMS,
......
...@@ -62,24 +62,24 @@ from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME ...@@ -62,24 +62,24 @@ from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import ( from diffusers.utils import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
floats_tensor,
is_compiled_module,
nightly,
require_torch_2,
slow,
torch_device,
) )
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
CaptureLogger, CaptureLogger,
enable_full_determinism, enable_full_determinism,
floats_tensor,
get_tests_dir, get_tests_dir,
load_numpy, load_numpy,
nightly,
require_compel, require_compel,
require_flax, require_flax,
require_onnxruntime, require_onnxruntime,
require_torch_2,
require_torch_gpu, require_torch_gpu,
run_test_in_subprocess, run_test_in_subprocess,
slow,
torch_device,
) )
from diffusers.utils.torch_utils import is_compiled_module
enable_full_determinism() enable_full_determinism()
......
...@@ -34,7 +34,7 @@ from diffusers.pipelines.auto_pipeline import ( ...@@ -34,7 +34,7 @@ from diffusers.pipelines.auto_pipeline import (
AUTO_INPAINT_PIPELINES_MAPPING, AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING, AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
) )
from diffusers.utils import slow from diffusers.utils.testing_utils import slow
PRETRAINED_MODEL_REPO_MAPPING = OrderedDict( PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
......
...@@ -455,12 +455,13 @@ class PipelineTesterMixin: ...@@ -455,12 +455,13 @@ class PipelineTesterMixin:
# TODO same as above # TODO same as above
test_mean_pixel_difference = torch_device != "mps" test_mean_pixel_difference = torch_device != "mps"
generator_device = "cpu"
components = self.get_dummy_components() components = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe.to(torch_device) pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(generator_device)
logger = logging.get_logger(pipe.__module__) logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL) logger.setLevel(level=diffusers.logging.FATAL)
...@@ -624,7 +625,8 @@ class PipelineTesterMixin: ...@@ -624,7 +625,8 @@ class PipelineTesterMixin:
for optional_component in pipe._optional_components: for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None) setattr(pipe, optional_component, None)
inputs = self.get_dummy_inputs(torch_device) generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output = pipe(**inputs)[0] output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
...@@ -642,7 +644,7 @@ class PipelineTesterMixin: ...@@ -642,7 +644,7 @@ class PipelineTesterMixin:
f"`{optional_component}` did not stay set to None after loading.", f"`{optional_component}` did not stay set to None after loading.",
) )
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(generator_device)
output_loaded = pipe_loaded(**inputs)[0] output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
......
...@@ -25,8 +25,15 @@ from diffusers import ( ...@@ -25,8 +25,15 @@ from diffusers import (
TextToVideoSDPipeline, TextToVideoSDPipeline,
UNet3DConditionModel, UNet3DConditionModel,
) )
from diffusers.utils import is_xformers_available, load_numpy, require_torch_gpu, skip_mps, slow, torch_device from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism from diffusers.utils.testing_utils import (
enable_full_determinism,
load_numpy,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import torch import torch
from diffusers import DDIMScheduler, TextToVideoZeroPipeline from diffusers import DDIMScheduler, TextToVideoZeroPipeline
from diffusers.utils import load_pt, require_torch_gpu, slow from diffusers.utils.testing_utils import load_pt, require_torch_gpu, slow
from ..test_pipelines_common import assert_mean_pixel_difference from ..test_pipelines_common import assert_mean_pixel_difference
......
...@@ -26,8 +26,14 @@ from diffusers import ( ...@@ -26,8 +26,14 @@ from diffusers import (
UNet3DConditionModel, UNet3DConditionModel,
VideoToVideoSDPipeline, VideoToVideoSDPipeline,
) )
from diffusers.utils import floats_tensor, is_xformers_available, skip_mps from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, slow, torch_device from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
slow,
torch_device,
)
from ..pipeline_params import ( from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
......
...@@ -22,8 +22,15 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni ...@@ -22,8 +22,15 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import load_numpy, nightly, slow, torch_device from diffusers.utils.testing_utils import (
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps enable_full_determinism,
load_numpy,
nightly,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
......
...@@ -36,8 +36,16 @@ from diffusers import ( ...@@ -36,8 +36,16 @@ from diffusers import (
UNet2DModel, UNet2DModel,
) )
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device from diffusers.utils.testing_utils import (
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, skip_mps enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
......
...@@ -20,8 +20,8 @@ from diffusers import ( ...@@ -20,8 +20,8 @@ from diffusers import (
UniDiffuserPipeline, UniDiffuserPipeline,
UniDiffuserTextDecoder, UniDiffuserTextDecoder,
) )
from diffusers.utils import floats_tensor, load_image, nightly, randn_tensor, slow, torch_device from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
......
...@@ -22,8 +22,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer ...@@ -22,8 +22,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
from diffusers.utils import load_numpy, nightly, torch_device from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
......
...@@ -21,8 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer ...@@ -21,8 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
......
...@@ -21,8 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer ...@@ -21,8 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt
from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, skip_mps
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
......
...@@ -21,8 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer ...@@ -21,8 +21,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import WuerstchenPrior from diffusers.pipelines.wuerstchen import WuerstchenPrior
from diffusers.utils import torch_device from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, skip_mps
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
...@@ -146,7 +145,6 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -146,7 +145,6 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_slice = image[0, 0, 0, -10:] image_slice = image[0, 0, 0, -10:]
image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:] image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
assert image.shape == (1, 2, 24, 24) assert image.shape == (1, 2, 24, 24)
expected_slice = np.array( expected_slice = np.array(
...@@ -161,7 +159,7 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -161,7 +159,7 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
218.00089, 218.00089,
-2731.5745, -2731.5745,
-8056.734, -8056.734,
], ]
) )
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
...@@ -176,7 +174,7 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -176,7 +174,7 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_max_difference=test_max_difference, test_max_difference=test_max_difference,
relax_max_difference=relax_max_difference, relax_max_difference=relax_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference, test_mean_pixel_difference=test_mean_pixel_difference,
expected_max_diff=1e-1, expected_max_diff=2e-1,
) )
@skip_mps @skip_mps
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment