".github/vscode:/vscode.git/clone" did not exist on "0dd0c49a5be452e6e7e844a0fa1d5fe68a5ef1bb"
Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinksy

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffsuion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
......@@ -22,7 +22,8 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaT
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, randn_tensor, replace_example_docstring
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
......
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
)
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .pipeline_audioldm2 import AudioLDM2Pipeline
_import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
_import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -36,9 +36,9 @@ from ...utils import (
is_accelerate_version,
is_librosa_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
......
from .pipeline_consistency_models import ConsistencyModelPipeline
from ...utils import (
_LazyModule,
)
_import_structure = {}
_import_structure["pipeline_consistency_models"] = ["ConsistencyModelPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -8,9 +8,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
if is_transformers_available() and is_flax_available():
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -31,11 +31,10 @@ from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
......@@ -30,11 +30,10 @@ from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
......@@ -32,11 +32,10 @@ from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
......@@ -38,12 +38,11 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
is_invisible_watermark_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .multicontrolnet import MultiControlNetModel
......
......@@ -39,11 +39,10 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
......
......@@ -38,11 +38,10 @@ from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput
......
from .pipeline_dance_diffusion import DanceDiffusionPipeline
from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_dance_diffusion"] = ["DanceDiffusionPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -17,7 +17,8 @@ from typing import List, Optional, Tuple, Union
import torch
from ...utils import logging, randn_tensor
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
......
from .pipeline_ddim import DDIMPipeline
from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_ddim"] = ["DDIMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
import torch
from ...schedulers import DDIMScheduler
from ...utils import randn_tensor
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from .pipeline_ddpm import DDPMPipeline
from ...utils import (
_LazyModule,
)
_import_structure = {}
_import_structure["pipeline_ddpm"] = ["DDPMPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
import torch
from ...utils import randn_tensor
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from .timesteps import (
fast27_timesteps,
smart27_timesteps,
smart50_timesteps,
smart100_timesteps,
smart185_timesteps,
super27_timesteps,
super40_timesteps,
super100_timesteps,
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
@dataclass
class IFPipelineOutput(BaseOutput):
"""
Args:
Output class for Stable Diffusion pipelines.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content or a watermark. `None` if safety checking could not be performed.
watermark_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_detected: Optional[List[bool]]
watermark_detected: Optional[List[bool]]
_import_structure = {}
_dummy_objects = {}
_import_structure["timesteps"] = [
"fast27_timesteps",
"smart27_timesteps",
"smart50_timesteps",
"smart100_timesteps",
"smart185_timesteps",
"super27_timesteps",
"super40_timesteps",
"super100_timesteps",
]
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .pipeline_if import IFPipeline
from .pipeline_if_img2img import IFImg2ImgPipeline
from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline
from .pipeline_if_inpainting import IFInpaintingPipeline
from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline
from .pipeline_if_superresolution import IFSuperResolutionPipeline
from .safety_checker import IFSafetyChecker
from .watermark import IFWatermarker
_import_structure["pipeline_output"] = ["IFPipelineOutput"]
_import_structure["pipeline_if"] = ["IFPipeline"]
_import_structure["pipeline_if_img2img"] = ["IFImg2ImgPipeline"]
_import_structure["pipeline_if_img2img_superresolution"] = ["IFImg2ImgSuperResolutionPipeline"]
_import_structure["pipeline_if_inpainting"] = ["IFInpaintingPipeline"]
_import_structure["pipeline_if_inpainting_superresolution"] = ["IFInpaintingSuperResolutionPipeline"]
_import_structure["pipeline_if_superresolution"] = ["IFSuperResolutionPipeline"]
_import_structure["safety_checker"] = ["IFSafetyChecker"]
_import_structure["watermark"] = ["IFWatermarker"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -17,9 +17,9 @@ from ...utils import (
is_bs4_available,
is_ftfy_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment