"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "dc41c5ce17688283aa103ac753b247744d64ed96"
Unverified Commit 3aa64128 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Import] Add missing settings / Correct some dummy imports (#5036)

* [Import] Add missing settings

* up

* up

* up
parent ef29b24f
...@@ -50,13 +50,26 @@ class SafetyConfig(object): ...@@ -50,13 +50,26 @@ class SafetyConfig(object):
_dummy_objects = {} _dummy_objects = {}
_additional_imports = {} _additional_imports = {}
_import_structure = { _import_structure = {}
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
_additional_imports.update({"SafetyConfig": SafetyConfig}) _additional_imports.update({"SafetyConfig": SafetyConfig})
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure.update(
{
"pipeline_output": ["StableDiffusionSafePipelineOutput"],
"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
"safety_checker": ["StableDiffusionSafetyChecker"],
}
)
if TYPE_CHECKING: if TYPE_CHECKING:
try: try:
...@@ -70,25 +83,16 @@ if TYPE_CHECKING: ...@@ -70,25 +83,16 @@ if TYPE_CHECKING:
from .safety_checker import SafeStableDiffusionSafetyChecker from .safety_checker import SafeStableDiffusionSafetyChecker
else: else:
try: import sys
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
import sys
sys.modules[__name__] = _LazyModule( sys.modules[__name__] = _LazyModule(
__name__, __name__,
globals()["__file__"], globals()["__file__"],
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items(): for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value) setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items(): for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value) setattr(sys.modules[__name__], name, value)
...@@ -47,3 +47,5 @@ else: ...@@ -47,3 +47,5 @@ else:
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -51,3 +51,6 @@ else: ...@@ -51,3 +51,6 @@ else:
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
...@@ -41,7 +41,6 @@ if TYPE_CHECKING: ...@@ -41,7 +41,6 @@ if TYPE_CHECKING:
from .pipeline_wuerstchen import WuerstchenDecoderPipeline from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
else: else:
import sys import sys
...@@ -51,3 +50,6 @@ else: ...@@ -51,3 +50,6 @@ else:
_import_structure, _import_structure,
module_spec=__spec__, module_spec=__spec__,
) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment