Unverified Commit 144c3a8b authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Imports] Fix many import bugs and make sure that doc builder CI test works correctly (#5176)

* [Doc builder] Ensure slow import for doc builder

* Apply suggestions from code review

* env for doc builder

* fix more

* [Diffusers] Set import to slow as env variable

* fix docs

* fix docs

* Apply suggestions from code review

* Apply suggestions from code review

* fix docs

* fix docs
parent 30a512ea
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
get_objects_from_module, get_objects_from_module,
...@@ -34,7 +35,7 @@ else: ...@@ -34,7 +35,7 @@ else:
"VoidNeRFModel", "VoidNeRFModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
# flake8: noqa # flake8: noqa
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import DIFFUSERS_SLOW_IMPORT
from ...utils import ( from ...utils import (
_LazyModule, _LazyModule,
is_note_seq_available, is_note_seq_available,
...@@ -38,7 +39,7 @@ else: ...@@ -38,7 +39,7 @@ else:
_import_structure["midi_utils"] = ["MidiProcessor"] _import_structure["midi_utils"] = ["MidiProcessor"]
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
get_objects_from_module, get_objects_from_module,
...@@ -120,7 +121,7 @@ if is_transformers_available() and is_flax_available(): ...@@ -120,7 +121,7 @@ if is_transformers_available() and is_flax_available():
_import_structure["pipeline_flax_stable_diffusion_inpaint"] = ["FlaxStableDiffusionInpaintPipeline"] _import_structure["pipeline_flax_stable_diffusion_inpaint"] = ["FlaxStableDiffusionInpaintPipeline"]
_import_structure["safety_checker_flax"] = ["FlaxStableDiffusionSafetyChecker"] _import_structure["safety_checker_flax"] = ["FlaxStableDiffusionSafetyChecker"]
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
...@@ -130,6 +131,7 @@ if TYPE_CHECKING: ...@@ -130,6 +131,7 @@ if TYPE_CHECKING:
else: else:
from .clip_image_project_model import CLIPImageProjection from .clip_image_project_model import CLIPImageProjection
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import ( from .pipeline_stable_diffusion import (
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionPipelineOutput, StableDiffusionPipelineOutput,
...@@ -195,14 +197,11 @@ if TYPE_CHECKING: ...@@ -195,14 +197,11 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import * from ...utils.dummy_onnx_objects import *
else: else:
from .pipeline_onnx_stable_diffusion import ( from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
OnnxStableDiffusionImg2ImgPipeline, from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
OnnxStableDiffusionInpaintPipeline, from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
OnnxStableDiffusionInpaintPipelineLegacy, from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
OnnxStableDiffusionPipeline, from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
OnnxStableDiffusionUpscalePipeline,
StableDiffusionOnnxPipeline,
)
try: try:
if not (is_transformers_available() and is_flax_available()): if not (is_transformers_available() and is_flax_available()):
...@@ -210,13 +209,11 @@ if TYPE_CHECKING: ...@@ -210,13 +209,11 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_objects import * from ...utils.dummy_flax_objects import *
else: else:
from .pipeline_flax_stable_diffusion import ( from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
FlaxStableDiffusionImg2ImgPipeline, from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
FlaxStableDiffusionInpaintPipeline, from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
FlaxStableDiffusionPipeline,
FlaxStableDiffusionSafetyChecker,
)
from .pipeline_output import FlaxStableDiffusionPipelineOutput from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
else: else:
import sys import sys
......
...@@ -30,7 +30,7 @@ from ...schedulers import DDIMScheduler ...@@ -30,7 +30,7 @@ from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from .pipeline_output import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -35,7 +35,7 @@ from ...schedulers import ( ...@@ -35,7 +35,7 @@ from ...schedulers import (
) )
from ...utils import deprecate, logging, replace_example_docstring from ...utils import deprecate, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
......
...@@ -34,7 +34,7 @@ from ...schedulers import ( ...@@ -34,7 +34,7 @@ from ...schedulers import (
) )
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
......
...@@ -35,7 +35,7 @@ from ...schedulers import ( ...@@ -35,7 +35,7 @@ from ...schedulers import (
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
......
...@@ -4,11 +4,7 @@ from typing import List, Optional, Union ...@@ -4,11 +4,7 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
from ...utils import ( from ...utils import BaseOutput, is_flax_available
BaseOutput,
is_flax_available,
is_transformers_available,
)
@dataclass @dataclass
...@@ -29,7 +25,7 @@ class StableDiffusionPipelineOutput(BaseOutput): ...@@ -29,7 +25,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected: Optional[List[bool]] nsfw_content_detected: Optional[List[bool]]
if is_transformers_available() and is_flax_available(): if is_flax_available():
import flax import flax
@flax.struct.dataclass @flax.struct.dataclass
......
...@@ -32,7 +32,7 @@ from ...utils import ( ...@@ -32,7 +32,7 @@ from ...utils import (
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput from .pipeline_output import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
......
...@@ -7,6 +7,7 @@ import PIL ...@@ -7,6 +7,7 @@ import PIL
from PIL import Image from PIL import Image
from ...utils import ( from ...utils import (
DIFFUSERS_SLOW_IMPORT,
BaseOutput, BaseOutput,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
...@@ -71,7 +72,7 @@ else: ...@@ -71,7 +72,7 @@ else:
) )
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
get_objects_from_module, get_objects_from_module,
...@@ -36,7 +37,7 @@ if is_transformers_available() and is_flax_available(): ...@@ -36,7 +37,7 @@ if is_transformers_available() and is_flax_available():
_import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"] _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
...@@ -30,7 +30,7 @@ from ...schedulers import ( ...@@ -30,7 +30,7 @@ from ...schedulers import (
FlaxPNDMScheduler, FlaxPNDMScheduler,
) )
from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionXLPipelineOutput from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -4,11 +4,7 @@ from typing import List, Union ...@@ -4,11 +4,7 @@ from typing import List, Union
import numpy as np import numpy as np
import PIL import PIL
from ...utils import ( from ...utils import BaseOutput, is_flax_available
BaseOutput,
is_flax_available,
is_transformers_available,
)
@dataclass @dataclass
...@@ -25,7 +21,7 @@ class StableDiffusionXLPipelineOutput(BaseOutput): ...@@ -25,7 +21,7 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray] images: Union[List[PIL.Image.Image], np.ndarray]
if is_transformers_available() and is_flax_available(): if is_flax_available():
import flax import flax
@flax.struct.dataclass @flax.struct.dataclass
......
...@@ -40,7 +40,7 @@ from ...utils import ( ...@@ -40,7 +40,7 @@ from ...utils import (
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available(): if is_invisible_watermark_available():
......
...@@ -37,7 +37,7 @@ from ...utils import ( ...@@ -37,7 +37,7 @@ from ...utils import (
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available(): if is_invisible_watermark_available():
......
...@@ -39,7 +39,7 @@ from ...utils import ( ...@@ -39,7 +39,7 @@ from ...utils import (
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available(): if is_invisible_watermark_available():
......
...@@ -38,7 +38,7 @@ from ...utils import ( ...@@ -38,7 +38,7 @@ from ...utils import (
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionXLPipelineOutput from .pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available(): if is_invisible_watermark_available():
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule
_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]} _import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]}
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stochastic_karras_ve import KarrasVePipeline from .pipeline_stochastic_karras_ve import KarrasVePipeline
else: else:
......
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
get_objects_from_module, get_objects_from_module,
...@@ -24,7 +25,7 @@ else: ...@@ -24,7 +25,7 @@ else:
_import_structure["pipeline_stable_diffusion_xl_adapter"] = ["StableDiffusionXLAdapterPipeline"] _import_structure["pipeline_stable_diffusion_xl_adapter"] = ["StableDiffusionXLAdapterPipeline"]
if TYPE_CHECKING: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
......
...@@ -20,8 +20,6 @@ import PIL ...@@ -20,8 +20,6 @@ import PIL
import torch import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
...@@ -40,6 +38,7 @@ from ...utils import ( ...@@ -40,6 +38,7 @@ from ...utils import (
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment