Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinksy

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffsuion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
...@@ -22,7 +22,8 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT ...@@ -22,7 +22,8 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, randn_tensor, replace_example_docstring from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
is_transformers_version, is_transformers_version,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ( from ...utils import dummy_torch_and_transformers_objects
AudioLDM2Pipeline,
AudioLDM2ProjectionModel, _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
AudioLDM2UNet2DConditionModel,
)
else: else:
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
from .pipeline_audioldm2 import AudioLDM2Pipeline _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -36,9 +36,9 @@ from ...utils import ( ...@@ -36,9 +36,9 @@ from ...utils import (
is_accelerate_version, is_accelerate_version,
is_librosa_available, is_librosa_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
......
from .pipeline_consistency_models import ConsistencyModelPipeline from ...utils import (
_LazyModule,
)
_import_structure = {}
_import_structure["pipeline_consistency_models"] = ["ConsistencyModelPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -8,9 +8,9 @@ from ...utils import ( ...@@ -8,9 +8,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available, is_flax_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .multicontrolnet import MultiControlNetModel _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
from .pipeline_controlnet import StableDiffusionControlNetPipeline _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
if is_transformers_available() and is_flax_available(): for name, value in _dummy_objects.items():
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline setattr(sys.modules[__name__], name, value)
...@@ -31,11 +31,10 @@ from ...utils import ( ...@@ -31,11 +31,10 @@ from ...utils import (
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
...@@ -30,11 +30,10 @@ from ...utils import ( ...@@ -30,11 +30,10 @@ from ...utils import (
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
...@@ -32,11 +32,10 @@ from ...utils import ( ...@@ -32,11 +32,10 @@ from ...utils import (
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
...@@ -38,12 +38,11 @@ from ...schedulers import KarrasDiffusionSchedulers ...@@ -38,12 +38,11 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module,
is_invisible_watermark_available, is_invisible_watermark_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .multicontrolnet import MultiControlNetModel from .multicontrolnet import MultiControlNetModel
......
...@@ -39,11 +39,10 @@ from ...schedulers import KarrasDiffusionSchedulers ...@@ -39,11 +39,10 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
......
...@@ -38,11 +38,10 @@ from ...schedulers import KarrasDiffusionSchedulers ...@@ -38,11 +38,10 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_compiled_module,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
......
from .pipeline_dance_diffusion import DanceDiffusionPipeline from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_dance_diffusion"] = ["DanceDiffusionPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -17,7 +17,8 @@ from typing import List, Optional, Tuple, Union ...@@ -17,7 +17,8 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...utils import logging, randn_tensor from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
......
from .pipeline_ddim import DDIMPipeline from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_ddim"] = ["DDIMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union ...@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...schedulers import DDIMScheduler from ...schedulers import DDIMScheduler
from ...utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from .pipeline_ddpm import DDPMPipeline from ...utils import (
_LazyModule,
)
_import_structure = {}
_import_structure["pipeline_ddpm"] = ["DDPMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union ...@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from dataclasses import dataclass from ...utils import (
from typing import List, Optional, Union OptionalDependencyNotAvailable,
_LazyModule,
import numpy as np get_objects_from_module,
import PIL is_torch_available,
is_transformers_available,
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from .timesteps import (
fast27_timesteps,
smart27_timesteps,
smart50_timesteps,
smart100_timesteps,
smart185_timesteps,
super27_timesteps,
super40_timesteps,
super100_timesteps,
) )
@dataclass _import_structure = {}
class IFPipelineOutput(BaseOutput): _dummy_objects = {}
"""
Args:
Output class for Stable Diffusion pipelines.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content or a watermark. `None` if safety checking could not be performed.
watermark_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_detected: Optional[List[bool]]
watermark_detected: Optional[List[bool]]
_import_structure["timesteps"] = [
"fast27_timesteps",
"smart27_timesteps",
"smart50_timesteps",
"smart100_timesteps",
"smart185_timesteps",
"super27_timesteps",
"super40_timesteps",
"super100_timesteps",
]
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .pipeline_if import IFPipeline _import_structure["pipeline_output"] = ["IFPipelineOutput"]
from .pipeline_if_img2img import IFImg2ImgPipeline _import_structure["pipeline_if"] = ["IFPipeline"]
from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline _import_structure["pipeline_if_img2img"] = ["IFImg2ImgPipeline"]
from .pipeline_if_inpainting import IFInpaintingPipeline _import_structure["pipeline_if_img2img_superresolution"] = ["IFImg2ImgSuperResolutionPipeline"]
from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline _import_structure["pipeline_if_inpainting"] = ["IFInpaintingPipeline"]
from .pipeline_if_superresolution import IFSuperResolutionPipeline _import_structure["pipeline_if_inpainting_superresolution"] = ["IFInpaintingSuperResolutionPipeline"]
from .safety_checker import IFSafetyChecker _import_structure["pipeline_if_superresolution"] = ["IFSuperResolutionPipeline"]
from .watermark import IFWatermarker _import_structure["safety_checker"] = ["IFSafetyChecker"]
_import_structure["watermark"] = ["IFWatermarker"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -17,9 +17,9 @@ from ...utils import ( ...@@ -17,9 +17,9 @@ from ...utils import (
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker from .safety_checker import IFSafetyChecker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment