Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinksy

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffsuion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
......@@ -33,9 +33,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......
......@@ -38,9 +38,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
......
from dataclasses import dataclass
from typing import List, Optional, Union
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
import numpy as np
import torch
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
@dataclass
class TextToVideoSDPipelineOutput(BaseOutput):
"""
Output class for text-to-video pipelines.
Args:
frames (`List[np.ndarray]` or `torch.FloatTensor`)
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
a `torch` tensor. The length of the list denotes the video length (the number of frames).
"""
frames: Union[List[np.ndarray], torch.FloatTensor]
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
from .pipeline_text_to_video_synth import TextToVideoSDPipeline
from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
_import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"]
_import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import torch
from ...utils import (
BaseOutput,
)
@dataclass
class TextToVideoSDPipelineOutput(BaseOutput):
"""
Output class for text-to-video pipelines.
Args:
frames (`List[np.ndarray]` or `torch.FloatTensor`)
List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
a `torch` tensor. The length of the list denotes the video length (the number of frames).
"""
frames: Union[List[np.ndarray], torch.FloatTensor]
......@@ -28,9 +28,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput
......
......@@ -29,9 +29,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput
......
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline
_dummy_objects.update(
{"UnCLIPImageVariationPipeline": UnCLIPImageVariationPipeline, "UnCLIPPipeline": UnCLIPPipeline}
)
else:
from .pipeline_unclip import UnCLIPPipeline
from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline
from .text_proj import UnCLIPTextProjModel
_import_structure["pipeline_unclip"] = ["UnCLIPPipeline"]
_import_structure["pipeline_unclip_image_variation"] = ["UnCLIPImageVariationPipeline"]
_import_structure["text_proj"] = ["UnCLIPTextProjModel"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -22,7 +22,8 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
......
......@@ -27,7 +27,8 @@ from transformers import (
from ...models import UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
from ...utils import logging, randn_tensor
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
......
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
......@@ -14,7 +18,25 @@ except OptionalDependencyNotAvailable:
ImageTextPipelineOutput,
UniDiffuserPipeline,
)
_dummy_objects.update(
{"ImageTextPipelineOutput": ImageTextPipelineOutput, "UniDiffuserPipeline": UniDiffuserPipeline}
)
else:
from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel, UTransformer2DModel
from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline
_import_structure["modeling_text_decoder"] = ["UniDiffuserTextDecoder"]
_import_structure["modeling_uvit"] = ["UniDiffuserModel", "UTransformer2DModel"]
_import_structure["pipeline_unidiffuser"] = ["ImageTextPipelineOutput", "UniDiffuserPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -15,15 +15,9 @@ from transformers import (
from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
)
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel
......
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
......@@ -16,9 +21,31 @@ except OptionalDependencyNotAvailable:
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
_dummy_objects.update(
{
"VersatileDiffusionDualGuidedPipeline": VersatileDiffusionDualGuidedPipeline,
"VersatileDiffusionImageVariationPipeline": VersatileDiffusionImageVariationPipeline,
"VersatileDiffusionPipeline": VersatileDiffusionPipeline,
"VersatileDiffusionTextToImagePipeline": VersatileDiffusionTextToImagePipeline,
}
)
else:
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
_import_structure["modeling_text_unet"] = ["UNetFlatConditionModel"]
_import_structure["pipeline_versatile_diffusion"] = ["VersatileDiffusionPipeline"]
_import_structure["pipeline_versatile_diffusion_dual_guided"] = ["VersatileDiffusionDualGuidedPipeline"]
_import_structure["pipeline_versatile_diffusion_image_variation"] = ["VersatileDiffusionImageVariationPipeline"]
_import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -29,7 +29,8 @@ from transformers import (
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor
from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
......
......@@ -24,7 +24,8 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor
from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
......@@ -22,7 +22,8 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor
from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
......
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_transformers_available,
)
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
from ...utils.dummy_torch_and_transformers_objects import (
LearnedClassifierFreeSamplingEmbeddings,
VQDiffusionPipeline,
)
_dummy_objects.update(
{
"LearnedClassifierFreeSamplingEmbeddings": LearnedClassifierFreeSamplingEmbeddings,
"VQDiffusionPipeline": VQDiffusionPipeline,
}
)
else:
from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline
_import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
from ...utils import is_torch_available, is_transformers_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
if is_transformers_available() and is_torch_available():
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
_import_structure = {}
_dummy_objects = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"]
_import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"]
_import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"]
_import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"]
_import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"]
_import_structure["pipeline_wuerstchen_prior"] = ["WuerstchenPriorPipeline"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
......@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...models.vae import DecoderOutput, VectorQuantizer
from ...models.vq_model import VQEncoderOutput
from ...utils import apply_forward_hook
from ...utils.accelerate_utils import apply_forward_hook
class MixingResidualBlock(nn.Module):
......
......@@ -19,7 +19,8 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ...utils import is_accelerate_available, is_accelerate_version, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
......
......@@ -26,9 +26,9 @@ from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment