Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinksy

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffsuion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
is_transformers_version,
) )
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .camera import create_pan_cameras _import_structure["camera"] = ["create_pan_cameras"]
from .pipeline_shap_e import ShapEPipeline _import_structure["pipeline_shap_e"] = ["ShapEPipeline"]
from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline _import_structure["pipeline_shap_e_img2img"] = ["ShapEImg2ImgPipeline"]
from .renderer import ( _import_structure["renderer"] = [
BoundingBoxVolume, "BoundingBoxVolume",
ImportanceRaySampler, "ImportanceRaySampler",
MLPNeRFModelOutput, "MLPNeRFModelOutput",
MLPNeRSTFModel, "MLPNeRSTFModel",
ShapEParamsProjModel, "ShapEParamsProjModel",
ShapERenderer, "ShapERenderer",
StratifiedRaySampler, "StratifiedRaySampler",
VoidNeRFModel, "VoidNeRFModel",
) ]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -28,9 +28,9 @@ from ...utils import ( ...@@ -28,9 +28,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .renderer import ShapERenderer from .renderer import ShapERenderer
......
...@@ -25,9 +25,9 @@ from ...schedulers import HeunDiscreteScheduler ...@@ -25,9 +25,9 @@ from ...schedulers import HeunDiscreteScheduler
from ...utils import ( from ...utils import (
BaseOutput, BaseOutput,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from .renderer import ShapERenderer from .renderer import ShapERenderer
......
# flake8: noqa # flake8: noqa
from ...utils import is_note_seq_available, is_transformers_available, is_torch_available from ...utils import (
from ...utils import OptionalDependencyNotAvailable _LazyModule,
is_note_seq_available,
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
get_objects_from_module,
)
_import_structure = {}
_dummy_objects = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .notes_encoder import SpectrogramNotesEncoder _import_structure["notes_encoder"] = ["SpectrogramNotesEncoder"]
from .continous_encoder import SpectrogramContEncoder _import_structure["continous_encoder"] = ["SpectrogramContEncoder"]
from .pipeline_spectrogram_diffusion import ( _import_structure["pipeline_spectrogram_diffusion"] = [
SpectrogramContEncoder, "SpectrogramContEncoder",
SpectrogramDiffusionPipeline, "SpectrogramDiffusionPipeline",
T5FilmDecoder, "T5FilmDecoder",
) ]
try: try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
...@@ -23,4 +35,16 @@ try: ...@@ -23,4 +35,16 @@ try:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else: else:
from .midi_utils import MidiProcessor _import_structure["midi_utils"] = ["MidiProcessor"]
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -21,7 +21,8 @@ import torch ...@@ -21,7 +21,8 @@ import torch
from ...models import T5FilmDecoder from ...models import T5FilmDecoder
from ...schedulers import DDPMScheduler from ...schedulers import DDPMScheduler
from ...utils import is_onnx_available, logging, randn_tensor from ...utils import is_onnx_available, logging
from ...utils.torch_utils import randn_tensor
if is_onnx_available(): if is_onnx_available():
......
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import ( from ...utils import (
BaseOutput,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available, is_flax_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version, is_k_diffusion_version,
...@@ -18,59 +12,56 @@ from ...utils import ( ...@@ -18,59 +12,56 @@ from ...utils import (
) )
@dataclass _import_structure = {}
class StableDiffusionPipelineOutput(BaseOutput): _additional_imports = {}
""" _dummy_objects = {}
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray] _import_structure["pipeline_output"] = ["StableDiffusionPipelineOutput"]
nsfw_content_detected: Optional[List[bool]]
if is_transformers_available() and is_flax_available():
_import_structure["pipeline_output"].extend(["FlaxStableDiffusionPipelineOutput"])
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
from .clip_image_project_model import CLIPImageProjection _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
from .pipeline_cycle_diffusion import CycleDiffusionPipeline _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
from .pipeline_stable_diffusion import StableDiffusionPipeline _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline _import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"]
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"]
from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
from .pipeline_stable_unclip import StableUnCLIPPipeline _import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"]
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline _import_structure["stable_unclip_image_normalizer"] = ["StableUnCLIPImageNormalizer"]
from .safety_checker import StableDiffusionSafetyChecker _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
_import_structure["clip_image_project_model"] = ["CLIPImageProjection"]
try: try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
_dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline})
else: else:
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline _import_structure["pipeline_stable_diffusion_image_variation"] = ["StableDiffusionImageVariationPipeline"]
try: try:
...@@ -82,10 +73,18 @@ except OptionalDependencyNotAvailable: ...@@ -82,10 +73,18 @@ except OptionalDependencyNotAvailable:
StableDiffusionDiffEditPipeline, StableDiffusionDiffEditPipeline,
StableDiffusionPix2PixZeroPipeline, StableDiffusionPix2PixZeroPipeline,
) )
_dummy_objects.update(
{
"StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
"StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline,
"StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
}
)
else: else:
from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"]
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]
try: try:
...@@ -97,43 +96,52 @@ try: ...@@ -97,43 +96,52 @@ try:
): ):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else: else:
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline _import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
try: try:
if not (is_transformers_available() and is_onnx_available()): if not (is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import * # noqa F403 from ...utils import dummy_onnx_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else: else:
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline _import_structure["pipeline_onnx_stable_diffusion"] = [
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline "OnnxStableDiffusionPipeline",
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline "StableDiffusionOnnxPipeline",
from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy ]
from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline _import_structure["pipeline_onnx_stable_diffusion_img2img"] = ["OnnxStableDiffusionImg2ImgPipeline"]
_import_structure["pipeline_onnx_stable_diffusion_inpaint"] = ["OnnxStableDiffusionInpaintPipeline"]
_import_structure["pipeline_onnx_stable_diffusion_inpaint_legacy"] = ["OnnxStableDiffusionInpaintPipelineLegacy"]
_import_structure["pipeline_onnx_stable_diffusion_upscale"] = ["OnnxStableDiffusionUpscalePipeline"]
if is_transformers_available() and is_flax_available(): if is_transformers_available() and is_flax_available():
import flax from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
@flax.struct.dataclass _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Flax-based Stable Diffusion pipelines.
Args: _import_structure["pipeline_flax_stable_diffusion"] = ["FlaxStableDiffusionPipeline"]
images (`np.ndarray`): _import_structure["pipeline_flax_stable_diffusion_img2img"] = ["FlaxStableDiffusionImg2ImgPipeline"]
Denoised images of array shape of `(batch_size, height, width, num_channels)`. _import_structure["pipeline_flax_stable_diffusion_inpaint"] = ["FlaxStableDiffusionInpaintPipeline"]
nsfw_content_detected (`List[bool]`): _import_structure["safety_checker_flax"] = ["FlaxStableDiffusionSafetyChecker"]
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content
or `None` if safety checking could not be performed.
"""
images: np.ndarray import sys
nsfw_content_detected: List[bool]
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline sys.modules[__name__] = _LazyModule(
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline __name__,
from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline globals()["__file__"],
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker _import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
...@@ -29,7 +29,8 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin ...@@ -29,7 +29,8 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from ...utils import (
BaseOutput,
is_flax_available,
is_transformers_available,
)
@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
if is_transformers_available() and is_flax_available():
import flax
@flax.struct.dataclass
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Flax-based Stable Diffusion pipelines.
Args:
images (`np.ndarray`):
Denoised images of array shape of `(batch_size, height, width, num_channels)`.
nsfw_content_detected (`List[bool]`):
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content
or `None` if safety checking could not be performed.
"""
images: np.ndarray
nsfw_content_detected: List[bool]
...@@ -30,9 +30,9 @@ from ...utils import ( ...@@ -30,9 +30,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -27,7 +27,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel ...@@ -27,7 +27,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention from ...models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ...utils import deprecate, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -28,7 +28,8 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin ...@@ -28,7 +28,8 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......
...@@ -35,9 +35,9 @@ from ...utils import ( ...@@ -35,9 +35,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -31,9 +31,9 @@ from ...utils import ( ...@@ -31,9 +31,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -36,9 +36,9 @@ from ...utils import ( ...@@ -36,9 +36,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .clip_image_project_model import CLIPImageProjection from .clip_image_project_model import CLIPImageProjection
......
...@@ -24,7 +24,8 @@ from ...configuration_utils import FrozenDict ...@@ -24,7 +24,8 @@ from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor from ...utils import deprecate, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -33,9 +33,9 @@ from ...utils import ( ...@@ -33,9 +33,9 @@ from ...utils import (
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
logging, logging,
randn_tensor,
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -27,7 +27,8 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa ...@@ -27,7 +27,8 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -27,14 +27,8 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa ...@@ -27,14 +27,8 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
PIL_INTERPOLATION, from ...utils.torch_utils import randn_tensor
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
)
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -24,14 +24,8 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor ...@@ -24,14 +24,8 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
from ...utils import ( from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
PIL_INTERPOLATION, from ...utils.torch_utils import randn_tensor
deprecate,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
)
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -24,7 +24,8 @@ from ...image_processor import VaeImageProcessor ...@@ -24,7 +24,8 @@ from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LMSDiscreteScheduler from ...schedulers import LMSDiscreteScheduler
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment