Commit ff885b0e authored by Patrick von Platen

add dummy imports

parent b4e6a740
@@ -74,8 +74,8 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
 # Make marked copies of snippets of codes conform to the original
 fix-copies:
-	python utils/check_table.py --fix_and_overwrite
+	python utils/check_dummies.py --fix_and_overwrite
 	python utils/check_table.py --fix_and_overwrite
 	python utils/check_copies.py --fix_and_overwrite
 
 # Run tests for the library
@@ -3,18 +3,21 @@
 # module, but to preserve other warnings. So, don't check this module at all.
 
+from .utils import is_transformers_available
+
+
 __version__ = "0.0.4"
 
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
-from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
+from .pipelines import BDDM, DDIM, DDPM, PNDM
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+
+if is_transformers_available():
+    from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
+    from .models.unet_grad_tts import UNetGradTTSModel
+    from .pipelines import GLIDE, GradTTS, LatentDiffusion
+else:
+    from .utils.dummy_transformers_objects import *
+from ..utils import is_transformers_available
+
 from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_grad_tts import GradTTS
-from .pipeline_glide import GLIDE
-from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_pndm import PNDM
+
+if is_transformers_available():
+    from .pipeline_glide import GLIDE
+    from .pipeline_grad_tts import GradTTS
+    from .pipeline_latent_diffusion import LatentDiffusion
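Both `__init__.py` hunks above follow the same optional-dependency pattern: import the transformers-backed classes only when the backend is importable, and otherwise fall back to placeholder objects. A minimal, self-contained sketch of that pattern (the `has_backend` helper is illustrative and not a name from this diff):

# Minimal sketch of the guarded-import pattern used in the two __init__.py hunks above.
# `has_backend` is an illustrative helper, not part of the diffusers codebase.
import importlib.util


def has_backend(name: str) -> bool:
    # True when the package can be imported in the current environment.
    return importlib.util.find_spec(name) is not None


if has_backend("transformers"):
    from transformers import PreTrainedTokenizer  # the real, transformers-backed class
else:
    class PreTrainedTokenizer:  # placeholder that only fails when actually used
        def __init__(self, *args, **kwargs):
            raise ImportError("PreTrainedTokenizer requires the transformers library.")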
@@ -6,11 +6,8 @@ from shutil import copyfile
 import torch
+from transformers import PreTrainedTokenizer
 
-try:
-    from transformers import PreTrainedTokenizer
-except:
-    print("transformers is not installed")
-
 try:
     from unidecode import unidecode
@@ -24,17 +24,11 @@ import torch.utils.checkpoint
 from torch import nn
 
 import tqdm
 
-try:
-    from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
-    from transformers.activations import ACT2FN
-    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-    from transformers.modeling_utils import PreTrainedModel
-    from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
-except:
-    print("Transformers is not installed")
-    pass
+from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
 
 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
@@ -11,11 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from requests.exceptions import HTTPError
 import importlib
-import importlib_metadata
 import os
-from .logging import logger
+from collections import OrderedDict
+
+import importlib_metadata
+from requests.exceptions import HTTPError
+
+from .logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 hf_cache_home = os.path.expanduser(
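`is_transformers_available` is referenced throughout this commit but its body is not part of the diff. Given the `importlib`/`importlib_metadata` imports in the hunk above, it presumably amounts to a check along these lines (an assumption, not the verbatim diffusers implementation):

# Hedged sketch of an availability check in the usual Hugging Face style;
# the exact diffusers implementation is not shown in this commit.
import importlib.util

import importlib_metadata

_transformers_available = importlib.util.find_spec("transformers") is not None
if _transformers_available:
    try:
        _transformers_version = importlib_metadata.version("transformers")
    except importlib_metadata.PackageNotFoundError:
        _transformers_available = False


def is_transformers_available():
    return _transformers_available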
@@ -56,3 +62,39 @@ class EntryNotFoundError(HTTPError):
 
 class RevisionNotFoundError(HTTPError):
     """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
+
+
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip:
+`pip install transformers`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+    [
+        ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+    ]
+)
+
+
+def requires_backends(obj, backends):
+    if not isinstance(backends, (list, tuple)):
+        backends = [backends]
+
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    checks = (BACKENDS_MAPPING[backend] for backend in backends)
+    failed = [msg.format(name) for available, msg in checks if not available()]
+    if failed:
+        raise ImportError("".join(failed))
+
+
+class DummyObject(type):
+    """
+    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+    `requires_backends` each time a user tries to access any method of that class.
+    """
+
+    def __getattr__(cls, key):
+        if key.startswith("_"):
+            return super().__getattr__(cls, key)
+        requires_backends(cls, cls._backends)
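Together, `requires_backends` and `DummyObject` let a placeholder class import cleanly and raise a readable ImportError only when it is actually used. A short usage sketch, assuming the definitions from the hunk above are in scope in the same module:

# Usage sketch: assumes BACKENDS_MAPPING, requires_backends and DummyObject
# from the hunk above have been defined in the current module.
class GLIDETextToImageUNetModel(metaclass=DummyObject):
    _backends = ["transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["transformers"])


# Referring to the class is always fine. In an environment without transformers,
# instantiating it (or looking up an attribute it does not define) raises ImportError
# with TRANSFORMERS_IMPORT_ERROR formatted for the class name.
try:
    GLIDETextToImageUNetModel()
except ImportError as err:
    print(err)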
@@ -22,3 +22,27 @@ class GLIDEUNetModel(metaclass=DummyObject):
 
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["transformers"])
+
+
+class UNetGradTTSModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+GLIDE = None
+
+
+class GradTTS(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class LatentDiffusion(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
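The net effect of the commit, in an environment without transformers, is that the guarded names can still be imported from `diffusers` and only fail at use time; note that `GLIDE` is exported as a plain `None` rather than as a dummy class. A behaviour sketch under that assumption:

# Behaviour sketch, assuming diffusers at this commit is installed and
# transformers is NOT available, so the dummy objects above are star-imported.
from diffusers import GLIDE, GradTTS

print(GLIDE)  # None, per the GLIDE = None placeholder above

try:
    GradTTS()  # fails only now, with the formatted transformers import error
except ImportError as err:
    print(err)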