Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinksy

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffsuion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
...@@ -20,9 +20,9 @@ from ...utils import ( ...@@ -20,9 +20,9 @@ from ...utils import (
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker from .safety_checker import IFSafetyChecker
......
...@@ -21,9 +21,9 @@ from ...utils import ( ...@@ -21,9 +21,9 @@ from ...utils import (
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker from .safety_checker import IFSafetyChecker
......
...@@ -20,9 +20,9 @@ from ...utils import ( ...@@ -20,9 +20,9 @@ from ...utils import (
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker from .safety_checker import IFSafetyChecker
......
...@@ -21,9 +21,9 @@ from ...utils import ( ...@@ -21,9 +21,9 @@ from ...utils import (
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker from .safety_checker import IFSafetyChecker
......
...@@ -20,9 +20,9 @@ from ...utils import ( ...@@ -20,9 +20,9 @@ from ...utils import (
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import IFPipelineOutput from . import IFPipelineOutput
from .safety_checker import IFSafetyChecker from .safety_checker import IFSafetyChecker
......
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from ...utils import BaseOutput
@dataclass
class IFPipelineOutput(BaseOutput):
"""
Args:
Output class for Stable Diffusion pipelines.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content or a watermark. `None` if safety checking could not be performed.
watermark_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_detected: Optional[List[bool]]
watermark_detected: Optional[List[bool]]
from .pipeline_dit import DiTPipeline from ...utils import _LazyModule
_import_structure = {}
_import_structure["pipeline_dit"] = ["DiTPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -24,7 +24,7 @@ import torch ...@@ -24,7 +24,7 @@ import torch
from ...models import AutoencoderKL, Transformer2DModel from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .pipeline_kandinsky import KandinskyPipeline _import_structure["pipeline_kandinsky"] = ["KandinskyPipeline"]
from .pipeline_kandinsky_combined import ( _import_structure["pipeline_kandinsky_combined"] = [
KandinskyCombinedPipeline, "KandinskyCombinedPipeline",
KandinskyImg2ImgCombinedPipeline, "KandinskyImg2ImgCombinedPipeline",
KandinskyInpaintCombinedPipeline, "KandinskyInpaintCombinedPipeline",
) ]
from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline _import_structure["pipeline_kandinsky_img2img"] = ["KandinskyImg2ImgPipeline"]
from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline _import_structure["pipeline_kandinsky_inpaint"] = ["KandinskyInpaintPipeline"]
from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput _import_structure["pipeline_kandinsky_prior"] = ["KandinskyPriorPipeline", "KandinskyPriorPipelineOutput"]
from .text_encoder import MultilingualCLIP _import_structure["text_encoder"] = ["MultilingualCLIP"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -25,9 +25,9 @@ from ...utils import ( ...@@ -25,9 +25,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP from .text_encoder import MultilingualCLIP
......
...@@ -28,9 +28,9 @@ from ...utils import ( ...@@ -28,9 +28,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP from .text_encoder import MultilingualCLIP
......
...@@ -32,9 +32,9 @@ from ...utils import ( ...@@ -32,9 +32,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP from .text_encoder import MultilingualCLIP
......
...@@ -27,9 +27,9 @@ from ...utils import ( ...@@ -27,9 +27,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .pipeline_kandinsky2_2 import KandinskyV22Pipeline _import_structure["pipeline_kandinsky2_2"] = ["KandinskyV22Pipeline"]
from .pipeline_kandinsky2_2_combined import ( _import_structure["pipeline_kandinsky2_2_combined"] = [
KandinskyV22CombinedPipeline, "KandinskyV22CombinedPipeline",
KandinskyV22Img2ImgCombinedPipeline, "KandinskyV22Img2ImgCombinedPipeline",
KandinskyV22InpaintCombinedPipeline, "KandinskyV22InpaintCombinedPipeline",
) ]
from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline _import_structure["pipeline_kandinsky2_2_controlnet"] = ["KandinskyV22ControlnetPipeline"]
from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline _import_structure["pipeline_kandinsky2_2_controlnet_img2img"] = ["KandinskyV22ControlnetImg2ImgPipeline"]
from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline _import_structure["pipeline_kandinsky2_2_img2img"] = ["KandinskyV22Img2ImgPipeline"]
from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline _import_structure["pipeline_kandinsky2_2_inpainting"] = ["KandinskyV22InpaintPipeline"]
from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline _import_structure["pipeline_kandinsky2_2_prior"] = ["KandinskyV22PriorPipeline"]
from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline _import_structure["pipeline_kandinsky2_2_prior_emb2emb"] = ["KandinskyV22PriorEmb2EmbPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -22,9 +22,9 @@ from ...utils import ( ...@@ -22,9 +22,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -22,9 +22,9 @@ from ...utils import ( ...@@ -22,9 +22,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -25,9 +25,9 @@ from ...utils import ( ...@@ -25,9 +25,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -25,9 +25,9 @@ from ...utils import ( ...@@ -25,9 +25,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -29,9 +29,9 @@ from ...utils import ( ...@@ -29,9 +29,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -10,9 +10,9 @@ from ...utils import ( ...@@ -10,9 +10,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..kandinsky import KandinskyPriorPipelineOutput from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment