Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinsky

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffusion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
...@@ -33,9 +33,9 @@ from ...utils import ( ...@@ -33,9 +33,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
...@@ -38,9 +38,9 @@ from ...utils import ( ...@@ -38,9 +38,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
......
from dataclasses import dataclass from ...utils import (
from typing import List, Optional, Union OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
import numpy as np
import torch
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
@dataclass
class TextToVideoSDPipelineOutput(BaseOutput):
"""
Output class for text-to-video pipelines.
Args:
frames (`List[np.ndarray]` or `torch.FloatTensor`)
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
a `torch` tensor. The length of the list denotes the video length (the number of frames).
"""
frames: Union[List[np.ndarray], torch.FloatTensor]
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .pipeline_text_to_video_synth import TextToVideoSDPipeline _import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"]
from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401 _import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"]
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline _import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"]
import sys

# Replace this package module in sys.modules with a _LazyModule proxy so that
# the submodules/classes listed in _import_structure are only imported on
# first attribute access (keeps `import diffusers` fast).
sys.modules[__name__] = _LazyModule(
    __name__,
    globals()["__file__"],
    _import_structure,
    module_spec=__spec__,
)
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import torch
from ...utils import (
BaseOutput,
)
@dataclass
class TextToVideoSDPipelineOutput(BaseOutput):
    """
    Output class for text-to-video pipelines.

    Args:
        frames (`List[np.ndarray]` or `torch.FloatTensor`):
            List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
            a `torch` tensor. The length of the list denotes the video length (the number of frames).
    """

    # Denoised video frames: either a list of per-frame NumPy arrays or a single torch tensor.
    frames: Union[List[np.ndarray], torch.FloatTensor]
...@@ -28,9 +28,9 @@ from ...utils import ( ...@@ -28,9 +28,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput from . import TextToVideoSDPipelineOutput
......
...@@ -29,9 +29,9 @@ from ...utils import ( ...@@ -29,9 +29,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput from . import TextToVideoSDPipelineOutput
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
is_transformers_version, is_transformers_version,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline
_dummy_objects.update(
{"UnCLIPImageVariationPipeline": UnCLIPImageVariationPipeline, "UnCLIPPipeline": UnCLIPPipeline}
)
else: else:
from .pipeline_unclip import UnCLIPPipeline _import_structure["pipeline_unclip"] = ["UnCLIPPipeline"]
from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline _import_structure["pipeline_unclip_image_variation"] = ["UnCLIPImageVariationPipeline"]
from .text_proj import UnCLIPTextProjModel _import_structure["text_proj"] = ["UnCLIPTextProjModel"]
import sys

# Swap this package module for a _LazyModule proxy: entries in
# _import_structure are imported lazily, on first attribute access.
sys.modules[__name__] = _LazyModule(
    __name__,
    globals()["__file__"],
    _import_structure,
    module_spec=__spec__,
)

# When the optional torch/transformers dependencies are missing, expose the
# collected dummy placeholder objects directly on the lazy module so that
# attribute access still succeeds (and raises a helpful error when used).
for name, value in _dummy_objects.items():
    setattr(sys.modules[__name__], name, value)
...@@ -22,7 +22,8 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput ...@@ -22,7 +22,8 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
......
...@@ -27,7 +27,8 @@ from transformers import ( ...@@ -27,7 +27,8 @@ from transformers import (
from ...models import UNet2DConditionModel, UNet2DModel from ...models import UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel from .text_proj import UnCLIPTextProjModel
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
is_transformers_version,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
...@@ -14,7 +18,25 @@ except OptionalDependencyNotAvailable: ...@@ -14,7 +18,25 @@ except OptionalDependencyNotAvailable:
ImageTextPipelineOutput, ImageTextPipelineOutput,
UniDiffuserPipeline, UniDiffuserPipeline,
) )
_dummy_objects.update(
{"ImageTextPipelineOutput": ImageTextPipelineOutput, "UniDiffuserPipeline": UniDiffuserPipeline}
)
else: else:
from .modeling_text_decoder import UniDiffuserTextDecoder _import_structure["modeling_text_decoder"] = ["UniDiffuserTextDecoder"]
from .modeling_uvit import UniDiffuserModel, UTransformer2DModel _import_structure["modeling_uvit"] = ["UniDiffuserModel", "UTransformer2DModel"]
from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline _import_structure["pipeline_unidiffuser"] = ["ImageTextPipelineOutput", "UniDiffuserPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -15,15 +15,9 @@ from transformers import ( ...@@ -15,15 +15,9 @@ from transformers import (
from ...models import AutoencoderKL from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
PIL_INTERPOLATION,
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
)
from ...utils.outputs import BaseOutput from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .modeling_text_decoder import UniDiffuserTextDecoder from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel from .modeling_uvit import UniDiffuserModel
......
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
is_transformers_version, is_transformers_version,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
...@@ -16,9 +21,31 @@ except OptionalDependencyNotAvailable: ...@@ -16,9 +21,31 @@ except OptionalDependencyNotAvailable:
VersatileDiffusionPipeline, VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline, VersatileDiffusionTextToImagePipeline,
) )
_dummy_objects.update(
{
"VersatileDiffusionDualGuidedPipeline": VersatileDiffusionDualGuidedPipeline,
"VersatileDiffusionImageVariationPipeline": VersatileDiffusionImageVariationPipeline,
"VersatileDiffusionPipeline": VersatileDiffusionPipeline,
"VersatileDiffusionTextToImagePipeline": VersatileDiffusionTextToImagePipeline,
}
)
else: else:
from .modeling_text_unet import UNetFlatConditionModel _import_structure["modeling_text_unet"] = ["UNetFlatConditionModel"]
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline _import_structure["pipeline_versatile_diffusion"] = ["VersatileDiffusionPipeline"]
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline _import_structure["pipeline_versatile_diffusion_dual_guided"] = ["VersatileDiffusionDualGuidedPipeline"]
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline _import_structure["pipeline_versatile_diffusion_image_variation"] = ["VersatileDiffusionImageVariationPipeline"]
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline _import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -29,7 +29,8 @@ from transformers import ( ...@@ -29,7 +29,8 @@ from transformers import (
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
......
...@@ -24,7 +24,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection ...@@ -24,7 +24,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -22,7 +22,8 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo ...@@ -22,7 +22,8 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel from .modeling_text_unet import UNetFlatConditionModel
......
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
)
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils.dummy_torch_and_transformers_objects import (
LearnedClassifierFreeSamplingEmbeddings,
VQDiffusionPipeline,
)
_dummy_objects.update(
{
"LearnedClassifierFreeSamplingEmbeddings": LearnedClassifierFreeSamplingEmbeddings,
"VQDiffusionPipeline": VQDiffusionPipeline,
}
)
else: else:
from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline _import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
from ...utils import is_torch_available, is_transformers_available from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
if is_transformers_available() and is_torch_available(): _import_structure = {}
from .modeling_paella_vq_model import PaellaVQModel _dummy_objects = {}
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt try:
from .modeling_wuerstchen_prior import WuerstchenPrior if not (is_transformers_available() and is_torch_available()):
from .pipeline_wuerstchen import WuerstchenDecoderPipeline raise OptionalDependencyNotAvailable()
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"]
_import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"]
_import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]
_import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]
_import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]
_import_structure["pipeline_wuerstchen_prior"] = ["WuerstchenPriorPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
...@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin from ...models.modeling_utils import ModelMixin
from ...models.vae import DecoderOutput, VectorQuantizer from ...models.vae import DecoderOutput, VectorQuantizer
from ...models.vq_model import VQEncoderOutput from ...models.vq_model import VQEncoderOutput
from ...utils import apply_forward_hook from ...utils.accelerate_utils import apply_forward_hook
class MixingResidualBlock(nn.Module): class MixingResidualBlock(nn.Module):
......
...@@ -19,7 +19,8 @@ import torch ...@@ -19,7 +19,8 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ...utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
......
...@@ -26,9 +26,9 @@ from ...utils import ( ...@@ -26,9 +26,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior from .modeling_wuerstchen_prior import WuerstchenPrior
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment