Commit ff885b0e authored by Patrick von Platen's avatar Patrick von Platen
Browse files

add dummy imports

parent b4e6a740
...@@ -74,8 +74,8 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency ...@@ -74,8 +74,8 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
# Make marked copies of snippets of codes conform to the original # Make marked copies of snippets of codes conform to the original
fix-copies: fix-copies:
python utils/check_table.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite
python utils/check_table.py --fix_and_overwrite
python utils/check_copies.py --fix_and_overwrite python utils/check_copies.py --fix_and_overwrite
# Run tests for the library # Run tests for the library
......
...@@ -3,18 +3,21 @@ ...@@ -3,18 +3,21 @@
# module, but to preserve other warnings. So, don't check this module at all. # module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_transformers_available from .utils import is_transformers_available
__version__ = "0.0.4" __version__ = "0.0.4"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.unet import UNetModel from .models.unet import UNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion from .pipelines import BDDM, DDIM, DDPM, PNDM
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
if is_transformers_available(): if is_transformers_available():
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
from .models.unet_grad_tts import UNetGradTTSModel
from .pipelines import GLIDE, GradTTS, LatentDiffusion
else: else:
from .utils.dummy_transformers_objects import * from .utils.dummy_transformers_objects import *
from ..utils import is_transformers_available
from .pipeline_bddm import BDDM from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM from .pipeline_ddpm import DDPM
from .pipeline_grad_tts import GradTTS
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_pndm import PNDM from .pipeline_pndm import PNDM
if is_transformers_available():
from .pipeline_glide import GLIDE
from .pipeline_grad_tts import GradTTS
from .pipeline_latent_diffusion import LatentDiffusion
...@@ -6,11 +6,8 @@ from shutil import copyfile ...@@ -6,11 +6,8 @@ from shutil import copyfile
import torch import torch
from transformers import PreTrainedTokenizer
try:
from transformers import PreTrainedTokenizer
except:
print("transformers is not installed")
try: try:
from unidecode import unidecode from unidecode import unidecode
......
...@@ -24,17 +24,11 @@ import torch.utils.checkpoint ...@@ -24,17 +24,11 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
import tqdm import tqdm
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
try: from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
except:
print("Transformers is not installed")
pass
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
......
...@@ -11,11 +11,17 @@ ...@@ -11,11 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from requests.exceptions import HTTPError
import importlib import importlib
import importlib_metadata
import os import os
from .logging import logger from collections import OrderedDict
import importlib_metadata
from requests.exceptions import HTTPError
from .logging import get_logger
logger = get_logger(__name__)
hf_cache_home = os.path.expanduser( hf_cache_home = os.path.expanduser(
...@@ -56,3 +62,39 @@ class EntryNotFoundError(HTTPError): ...@@ -56,3 +62,39 @@ class EntryNotFoundError(HTTPError):
class RevisionNotFoundError(HTTPError): class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision.""" """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip:
`pip install accelerate`
"""
BACKENDS_MAPPING = OrderedDict(
[
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
]
)
def requires_backends(obj, backends):
if not isinstance(backends, (list, tuple)):
backends = [backends]
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
raise ImportError("".join(failed))
class DummyObject(type):
"""
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
`requires_backend` each time a user tries to access any method of that class.
"""
def __getattr__(cls, key):
if key.startswith("_"):
return super().__getattr__(cls, key)
requires_backends(cls, cls._backends)
...@@ -22,3 +22,27 @@ class GLIDEUNetModel(metaclass=DummyObject): ...@@ -22,3 +22,27 @@ class GLIDEUNetModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"]) requires_backends(self, ["transformers"])
class UNetGradTTSModel(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
GLIDE = None
class GradTTS(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
class LatentDiffusion(metaclass=DummyObject):
_backends = ["transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment