Unverified Commit 3aa64128 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Import] Add missing settings / Correct some dummy imports (#5036)

* [Import] Add missing settings

* up

* up

* up
parent ef29b24f
......@@ -50,12 +50,25 @@ class SafetyConfig(object):
_dummy_objects = {}
_additional_imports = {}
_import_structure = {
_import_structure = {}
_additional_imports.update({"SafetyConfig": SafetyConfig})
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure.update(
{
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
_additional_imports.update({"SafetyConfig": SafetyConfig})
}
)
if TYPE_CHECKING:
......@@ -70,15 +83,6 @@ if TYPE_CHECKING:
from .safety_checker import SafeStableDiffusionSafetyChecker
else:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
import sys
sys.modules[__name__] = _LazyModule(
......
......@@ -47,3 +47,5 @@ else:
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -51,3 +51,6 @@ else:
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
......@@ -41,7 +41,6 @@ if TYPE_CHECKING:
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
else:
import sys
......@@ -51,3 +50,6 @@ else:
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment