Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinksy

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffsuion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
......@@ -10,9 +10,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
......
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
_import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"]
_import_structure["pipeline_latent_diffusion_superresolution"] = ["LDMSuperResolutionPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -25,7 +25,7 @@ from transformers.utils import logging
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import randn_tensor
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
......@@ -15,7 +15,8 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, randn_tensor
from ...utils import PIL_INTERPOLATION
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from .pipeline_latent_diffusion_uncond import LDMPipeline
from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_latent_diffusion_uncond"] = ["LDMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -19,7 +19,7 @@ import torch
from ...models import UNet2DModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import randn_tensor
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
MusicLDMPipeline,
)
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .pipeline_musicldm import MusicLDMPipeline
_import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -28,7 +28,8 @@ from transformers import (
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_librosa_available, logging, randn_tensor, replace_example_docstring
from ...utils import is_librosa_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
......
......@@ -5,14 +5,38 @@ import numpy as np
import PIL
from PIL import Image
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .image_encoder import PaintByExampleImageEncoder
from .pipeline_paint_by_example import PaintByExamplePipeline
_import_structure["image_encoder"] = ["PaintByExampleImageEncoder"]
_import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -23,7 +23,8 @@ from transformers import CLIPImageProcessor
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging, randn_tensor
from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
......@@ -51,12 +51,12 @@ from ..utils import (
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
is_torch_version,
is_transformers_available,
logging,
numpy_to_pil,
)
from ..utils.torch_utils import is_compiled_module
if is_transformers_available():
......
from .pipeline_pndm import PNDMPipeline
from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_pndm"] = ["PNDMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -19,7 +19,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import PNDMScheduler
from ...utils import randn_tensor
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from .pipeline_repaint import RePaintPipeline
from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_repaint"] = ["RePaintPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -21,7 +21,8 @@ import torch
from ...models import UNet2DModel
from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_score_sde_ve"] = ["ScoreSdeVePipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -18,7 +18,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler
from ...utils import randn_tensor
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
@dataclass
class SemanticStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
`None` if safety checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
_import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"]
_import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from ...utils import BaseOutput
@dataclass
class SemanticStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
`None` if safety checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
......@@ -9,7 +9,8 @@ from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor
from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import SemanticStableDiffusionPipelineOutput
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment