Unverified Commit b6e0b016 authored by Dhruv Nair, committed by GitHub

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinsky

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffusion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistaken import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
@@ -10,9 +10,9 @@ from ...utils import (
     is_accelerate_available,
     is_accelerate_version,
     logging,
-    randn_tensor,
     replace_example_docstring,
 )
+from ...utils.torch_utils import randn_tensor
 from ..kandinsky import KandinskyPriorPipelineOutput
 from ..pipeline_utils import DiffusionPipeline
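This first hunk shows a change that recurs throughout the commit: randn_tensor is no longer re-exported from diffusers.utils and instead lives in diffusers.utils.torch_utils, so importing the utils package no longer pulls in torch. Downstream code only needs to switch the import path; a minimal usage sketch (argument values are illustrative):

import torch

from diffusers.utils.torch_utils import randn_tensor  # previously: from diffusers.utils import randn_tensor

# Draw a reproducible latent tensor on a specific device/dtype.
latents = randn_tensor(
    (1, 4, 64, 64),
    generator=torch.Generator("cpu").manual_seed(0),
    device=torch.device("cpu"),
    dtype=torch.float32,
)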
-from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
-from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+_import_structure = {}
+_dummy_objects = {}

 try:
     if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
+    _import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"]
+    _import_structure["pipeline_latent_diffusion_superresolution"] = ["LDMSuperResolutionPipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
+
+for name, value in _dummy_objects.items():
+    setattr(sys.modules[__name__], name, value)
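The lazy __init__ files above lean on get_objects_from_module to harvest the dummy placeholders when torch or transformers is missing. The helper itself is not shown in this diff; a minimal sketch of what such a helper could look like (an illustration, not the actual diffusers implementation):

import types
from typing import Any, Dict


def get_objects_from_module(module: types.ModuleType) -> Dict[str, Any]:
    """Collect the public attributes of `module` as a name -> object mapping."""
    objects = {}
    for name in dir(module):
        if name.startswith("_"):
            continue  # skip private helpers and dunder attributes
        objects[name] = getattr(module, name)
    return objects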
@@ -25,7 +25,7 @@ from transformers.utils import logging
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import randn_tensor
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -15,7 +15,8 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, randn_tensor
+from ...utils import PIL_INTERPOLATION
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_latent_diffusion_uncond import LDMPipeline
+from ...utils import _LazyModule
+
+_import_structure = {}
+_import_structure["pipeline_latent_diffusion_uncond"] = ["LDMPipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
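Single-pipeline packages like the one above show the pattern at its smallest: the package module is swapped out in sys.modules for a _LazyModule built from _import_structure, and the submodule is only imported when one of its names is first accessed. A stripped-down illustration of that idea (not the _LazyModule class actually shipped in diffusers.utils):

import importlib
import types


class LazyShim(types.ModuleType):
    """Resolve exported names to objects in submodules on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported object name to the submodule that defines it.
        self._object_to_module = {
            obj: submodule for submodule, objs in import_structure.items() for obj in objs
        }
        self.__all__ = list(self._object_to_module)

    def __getattr__(self, item):
        if item not in self._object_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {item!r}")
        submodule = importlib.import_module(f"{self.__name__}.{self._object_to_module[item]}")
        value = getattr(submodule, item)
        setattr(self, item, value)  # cache so later lookups bypass __getattr__
        return value

Replacing the package with such a shim (sys.modules[__name__] = LazyShim(__name__, _import_structure)) keeps `from package import Thing` working while deferring the heavy imports.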
@@ -19,7 +19,7 @@ import torch
 from ...models import UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler
-from ...utils import randn_tensor
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import (
     OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
     is_torch_available,
     is_transformers_available,
     is_transformers_version,
 )

+_import_structure = {}
+_dummy_objects = {}
+
 try:
     if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_objects import (
-        MusicLDMPipeline,
-    )
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    from .pipeline_musicldm import MusicLDMPipeline
+    _import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
+
+for name, value in _dummy_objects.items():
+    setattr(sys.modules[__name__], name, value)
@@ -28,7 +28,8 @@ from transformers import (
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_librosa_available, logging, randn_tensor, replace_example_docstring
+from ...utils import is_librosa_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -5,14 +5,38 @@ import numpy as np
 import PIL
 from PIL import Image

-from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+_import_structure = {}
+_dummy_objects = {}

 try:
     if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    from .image_encoder import PaintByExampleImageEncoder
-    from .pipeline_paint_by_example import PaintByExamplePipeline
+    _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"]
+    _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
+
+for name, value in _dummy_objects.items():
+    setattr(sys.modules[__name__], name, value)
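When the optional backends are missing, the names gathered into _dummy_objects come from dummy_torch_and_transformers_objects, so attributes like PaintByExamplePipeline still exist on the package but fail with an informative error the moment they are used. The generated placeholders follow roughly this shape (a paraphrased sketch, not copied from the generated file; the real dummies also stub out from_config and from_pretrained):

from diffusers.utils import DummyObject, requires_backends


class PaintByExamplePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Raises an error telling the user which backends need to be installed.
        requires_backends(self, ["torch", "transformers"])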
@@ -23,7 +23,8 @@ from transformers import CLIPImageProcessor
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging, randn_tensor
+from ...utils import deprecate, logging
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -51,12 +51,12 @@ from ..utils import (
     get_class_from_dynamic_module,
     is_accelerate_available,
     is_accelerate_version,
-    is_compiled_module,
     is_torch_version,
     is_transformers_available,
     logging,
     numpy_to_pil,
 )
+from ..utils.torch_utils import is_compiled_module

 if is_transformers_available():
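is_compiled_module moves into diffusers.utils.torch_utils for the same reason as randn_tensor: it needs torch to do its job. Conceptually, the helper just checks whether a module has been wrapped by torch.compile; a sketch under that assumption (close to, but not necessarily identical with, the diffusers implementation):

import torch


def is_compiled_module(module) -> bool:
    """Return True if `module` is a torch.compile (TorchDynamo) wrapped module."""
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)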
-from .pipeline_pndm import PNDMPipeline
+from ...utils import _LazyModule
+
+_import_structure = {}
+_import_structure["pipeline_pndm"] = ["PNDMPipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
@@ -19,7 +19,7 @@ import torch
 from ...models import UNet2DModel
 from ...schedulers import PNDMScheduler
-from ...utils import randn_tensor
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_repaint import RePaintPipeline
+from ...utils import _LazyModule
+
+_import_structure = {}
+_import_structure["pipeline_repaint"] = ["RePaintPipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
@@ -21,7 +21,8 @@ import torch
 from ...models import UNet2DModel
 from ...schedulers import RePaintScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_score_sde_ve import ScoreSdeVePipeline
+from ...utils import _LazyModule
+
+_import_structure = {}
+_import_structure["pipeline_score_sde_ve"] = ["ScoreSdeVePipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
@@ -18,7 +18,7 @@ import torch
 from ...models import UNet2DModel
 from ...schedulers import ScoreSdeVeScheduler
-from ...utils import randn_tensor
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from dataclasses import dataclass
-from enum import Enum
-from typing import List, Optional, Union
-
-import numpy as np
-import PIL
-from PIL import Image
-
-from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
-
-
-@dataclass
-class SemanticStableDiffusionPipelineOutput(BaseOutput):
-    """
-    Output class for Stable Diffusion pipelines.
-
-    Args:
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
-            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
-            num_channels)`.
-        nsfw_content_detected (`List[bool]`)
-            List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
-            `None` if safety checking could not be performed.
-    """
-
-    images: Union[List[PIL.Image.Image], np.ndarray]
-    nsfw_content_detected: Optional[List[bool]]
-
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+_import_structure = {}
+_dummy_objects = {}

 try:
     if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
+    _import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"]
+    _import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
+
+import sys
+
+sys.modules[__name__] = _LazyModule(
+    __name__,
+    globals()["__file__"],
+    _import_structure,
+    module_spec=__spec__,
+)
+
+for name, value in _dummy_objects.items():
+    setattr(sys.modules[__name__], name, value)
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class SemanticStableDiffusionPipelineOutput(BaseOutput):
+    """
+    Output class for Stable Diffusion pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+            num_channels)`.
+        nsfw_content_detected (`List[bool]`)
+            List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
+            `None` if safety checking could not be performed.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+    nsfw_content_detected: Optional[List[bool]]
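Because the package __init__ maps the "pipeline_output" entry to the class name, moving SemanticStableDiffusionPipelineOutput into its own module does not change the public import path. Illustrative check (assumes torch and transformers are installed):

# The public name still resolves from the package; the lazy mapping points it
# at the new dedicated submodule.
from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipelineOutput

print(SemanticStableDiffusionPipelineOutput.__module__)
# expected: diffusers.pipelines.semantic_stable_diffusion.pipeline_output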
@@ -9,7 +9,8 @@ from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, randn_tensor
+from ...utils import deprecate, logging
+from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import SemanticStableDiffusionPipelineOutput