Commit b4e6a740 authored by Patrick von Platen

save intermediate

parent 1997b908
Makefile
@@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency

 # Make marked copies of snippets of codes conform to the original
 fix-copies:
-	python utils/check_copies.py --fix_and_overwrite
 	python utils/check_table.py --fix_and_overwrite
 	python utils/check_dummies.py --fix_and_overwrite
+	python utils/check_copies.py --fix_and_overwrite

 # Run tests for the library
src/diffusers/__init__.py
 # flake8: noqa
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
+from .utils import is_transformers_available
+
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
-from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
 from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+
+if is_transformers_available():
+    from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
+else:
+    from .utils.dummy_transformers_objects import *
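With this guard in place, `import diffusers` succeeds whether or not `transformers` is installed: the real GLIDE UNet classes are exported when it is, and the star-import of `dummy_transformers_objects` exports stand-ins with the same names when it is not. A sketch of the resulting user-facing behavior (the exact error text comes from `requires_backends` and may differ):

```python
# Hypothetical session with transformers NOT installed.
from diffusers import GLIDEUNetModel  # still succeeds: this resolves to the dummy class

try:
    model = GLIDEUNetModel()  # the dummy raises only when actually used
except ImportError as err:
    print(err)  # e.g. a message saying the transformers backend is required
```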
src/diffusers/pipelines/__init__.py
@@ -2,15 +2,6 @@ from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
 from .pipeline_grad_tts import GradTTS
-
-try:
-    from .pipeline_glide import GLIDE
-except (NameError, ImportError):
-
-    class GLIDE:
-        pass
-
+from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_pndm import PNDM
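The removed `try`/`except` exported a bare `class GLIDE: pass` whenever `pipeline_glide` (which needs transformers) failed to import. That kept `import diffusers` working but pushed the failure to an opaque error at call time; the dummy-object mechanism this commit introduces fails fast with an `ImportError` that names the missing backend instead. A sketch of the old failure mode:

```python
# What the removed except-branch exported: a bare stand-in.
class GLIDE:
    pass


pipe = GLIDE()      # instantiation silently succeeds
# pipe("a prompt")  # -> TypeError: 'GLIDE' object is not callable
```

Note that the plain `from .pipeline_glide import GLIDE` is now unconditional here, so in this intermediate state the pipelines package itself presumably still requires transformers at import time, consistent with the "save intermediate" commit message.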
src/diffusers/utils/__init__.py
-#!/usr/bin/env python
-# coding=utf-8
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all.
-import os
 # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +12,10 @@ import os
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from requests.exceptions import HTTPError
+import importlib
+import importlib_metadata
+import os
+from .logging import logger

 hf_cache_home = os.path.expanduser(
@@ -36,6 +31,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))

+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+    _transformers_version = importlib_metadata.version("transformers")
+    logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _transformers_available = False
+
+
+def is_transformers_available():
+    return _transformers_available
+
+
 class RepositoryNotFoundError(HTTPError):
     """
     Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
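The availability check is deliberately two-step: `importlib.util.find_spec` only proves a module can be located on `sys.path`, while `importlib_metadata.version` proves an installed distribution actually provides it, guarding against stray local folders shadowing the name. (Strictly, `importlib.util` needs its own `import importlib.util`; the plain `import importlib` above relies on the submodule having been loaded elsewhere.) The same recipe extends to any optional backend; a minimal sketch for a hypothetical `scipy` check, not part of this commit:

```python
import importlib.util

import importlib_metadata

# Two-step detection for a hypothetical optional scipy backend.
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
    _scipy_version = importlib_metadata.version("scipy")
except importlib_metadata.PackageNotFoundError:
    _scipy_available = False


def is_scipy_available():
    return _scipy_available
```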
src/diffusers/utils/dummy_transformers_objects.py (new)
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class GLIDESuperResUNetModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class GLIDETextToImageUNetModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
+
+
+class GLIDEUNetModel(metaclass=DummyObject):
+    _backends = ["transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["transformers"])
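Neither `DummyObject` nor `requires_backends` appears in this diff. Mirroring their counterparts in transformers, a minimal self-contained sketch of how they can be implemented (the mapping, names, and message text here are assumptions, not the committed code):

```python
import importlib.util


def is_transformers_available():
    # Simplified stand-in for the check defined in diffusers.utils.
    return importlib.util.find_spec("transformers") is not None


BACKENDS_MAPPING = {
    "transformers": (is_transformers_available, "{0} requires the transformers library."),
}


def requires_backends(obj, backends):
    # Resolve a printable name for classes and instances alike.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    failed = [msg.format(name) for check, msg in (BACKENDS_MAPPING[b] for b in backends) if not check()]
    if failed:
        raise ImportError("".join(failed))


class DummyObject(type):
    """Metaclass for dummy classes: any attribute access on the class (and, via
    the generated __init__, any instantiation) raises the backend ImportError."""

    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)
```

Because the check lives in a metaclass, even class-level access such as `GLIDEUNetModel.from_pretrained(...)` raises the same error before any arguments are processed.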
utils/check_dummies.py
@@ -20,10 +20,10 @@ import re

 # All paths are set with the intent you should run this script from the root of the repo with the command
 # python utils/check_dummies.py

-PATH_TO_TRANSFORMERS = "src/transformers"
+PATH_TO_DIFFUSERS = "src/diffusers"

 # Matches is_xxx_available()
-_re_backend = re.compile(r"is\_([a-z_]*)_available()")
+_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")

 # Matches from xxx import bla
 _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")

 _re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
@@ -50,36 +50,30 @@ def {0}(*args, **kwargs):

 def find_backend(line):
     """Find one (or multiple) backend in a code line of the init."""
-    if _re_test_backend.search(line) is None:
+    backends = _re_backend.findall(line)
+    if len(backends) == 0:
         return None
-    backends = [b[0] for b in _re_backend.findall(line)]
-    backends.sort()
-    return "_and_".join(backends)
+
+    return backends[0]


 def read_init():
     """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
-    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
+    with open(os.path.join(PATH_TO_DIFFUSERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()

     # Get to the point we do the actual imports for type checking
     line_index = 0
-    while not lines[line_index].startswith("if TYPE_CHECKING"):
-        line_index += 1

     backend_specific_objects = {}
     # Go through the end of the file
     while line_index < len(lines):
         # If the line is an if is_backend_available, we grab all objects associated.
         backend = find_backend(lines[line_index])
         if backend is not None:
-            while not lines[line_index].startswith("    else:"):
-                line_index += 1
-            line_index += 1
             objects = []
+            line_index += 1
             # Until we unindent, add backend objects to the list
-            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
+            while not lines[line_index].startswith("else:"):
                 line = lines[line_index]
                 single_line_import_search = _re_single_line_import.search(line)
                 if single_line_import_search is not None:
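The parsing changes above go together: `_re_backend` now matches the whole `if is_xxx_available():` guard line, `find_backend` returns just the first captured name (dropping the sorted `_and_`-joined multi-backend form inherited from transformers), and `read_init` starts scanning at line 0 because the diffusers init has no `if TYPE_CHECKING` section. A quick standalone check of the new matcher:

```python
import re

_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)")


def find_backend(line):
    """Return the backend guarded by an `if is_xxx_available():` line, else None."""
    backends = _re_backend.findall(line)
    if len(backends) == 0:
        return None
    return backends[0]


print(find_backend("if is_transformers_available():"))     # -> transformers
print(find_backend("from .models.unet import UNetModel"))  # -> None
```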
@@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
     short_names = {"torch": "pt"}

     # Locate actual dummy modules and read their content.
-    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
+    path = os.path.join(PATH_TO_DIFFUSERS, "utils")
     dummy_file_paths = {
         backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
         for backend in dummy_files.keys()
@@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
         if dummy_files[backend] != actual_dummies[backend]:
             if overwrite:
                 print(
-                    f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
+                    f"Updating diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
                     "__init__ has new objects."
                 )
                 with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
@@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
             else:
                 raise ValueError(
                     "The main __init__ has objects that are not present in "
-                    f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
+                    f"diffusers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
                     "to fix this."
                 )
utils/check_table.py
@@ -22,7 +22,7 @@ import re

 # All paths are set with the intent you should run this script from the root of the repo with the command
 # python utils/check_table.py

-TRANSFORMERS_PATH = "src/transformers"
+TRANSFORMERS_PATH = "src/diffusers"
 PATH_TO_DOCS = "docs/source/en"
 REPO_PATH = "."
@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
 _re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")

-# This is to make sure the transformers module imported is the one in the repo.
+# This is to make sure the diffusers module imported is the one in the repo.
 spec = importlib.util.spec_from_file_location(
-    "transformers",
+    "diffusers",
     os.path.join(TRANSFORMERS_PATH, "__init__.py"),
     submodule_search_locations=[TRANSFORMERS_PATH],
 )
-transformers_module = spec.loader.load_module()
+diffusers_module = spec.loader.load_module()

 # Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
@@ -88,10 +88,10 @@ def _center_text(text, width):
 def get_model_table_from_auto_modules():
     """Generates an up-to-date model table from the content of the auto modules."""
     # Dictionary model names to config.
-    config_maping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
+    config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
     model_name_to_config = {
         name: config_maping_names[code]
-        for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
+        for code, name in diffusers_module.MODEL_NAMES_MAPPING.items()
         if code in config_maping_names
     }
     model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()}
@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
     tf_models = collections.defaultdict(bool)
     flax_models = collections.defaultdict(bool)

-    # Let's lookup through all transformers object (once).
-    for attr_name in dir(transformers_module):
+    # Let's lookup through all diffusers object (once).
+    for attr_name in dir(diffusers_module):
         lookup_dict = None
         if attr_name.endswith("Tokenizer"):
             lookup_dict = slow_tokenizers