Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b4e6a740
Commit
b4e6a740
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
save intermediate
parent
1997b908
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
47 deletions
+68
-47
Makefile
Makefile
+1
-1
src/diffusers/__init__.py
src/diffusers/__init__.py
+6
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-10
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+16
-9
src/diffusers/utils/dummy_transformers_objects.py
src/diffusers/utils/dummy_transformers_objects.py
+24
-0
utils/check_dummies.py
utils/check_dummies.py
+12
-18
utils/check_table.py
utils/check_table.py
+8
-8
No files found.
Makefile
View file @
b4e6a740
...
@@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
...
@@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
# Make marked copies of snippets of codes conform to the original
# Make marked copies of snippets of codes conform to the original
fix-copies
:
fix-copies
:
python utils/check_copies.py
--fix_and_overwrite
python utils/check_table.py
--fix_and_overwrite
python utils/check_table.py
--fix_and_overwrite
python utils/check_dummies.py
--fix_and_overwrite
python utils/check_dummies.py
--fix_and_overwrite
python utils/check_copies.py
--fix_and_overwrite
# Run tests for the library
# Run tests for the library
...
...
src/diffusers/__init__.py
View file @
b4e6a740
# flake8: noqa
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# module, but to preserve other warnings. So, don't check this module at all.
from
.utils
import
is_transformers_available
__version__
=
"0.0.4"
__version__
=
"0.0.4"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
GLIDEUNetModel
from
.models.unet_grad_tts
import
UNetGradTTSModel
from
.models.unet_grad_tts
import
UNetGradTTSModel
from
.models.unet_ldm
import
UNetLDMModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDM
,
DDIM
,
DDPM
,
GLIDE
,
PNDM
,
GradTTS
,
LatentDiffusion
from
.pipelines
import
BDDM
,
DDIM
,
DDPM
,
GLIDE
,
PNDM
,
GradTTS
,
LatentDiffusion
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
GradTTSScheduler
,
PNDMScheduler
,
SchedulerMixin
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
if
is_transformers_available
():
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
,
GLIDEUNetModel
else
:
from
.utils.dummy_transformers_objects
import
*
src/diffusers/pipelines/__init__.py
View file @
b4e6a740
...
@@ -2,15 +2,6 @@ from .pipeline_bddm import BDDM
...
@@ -2,15 +2,6 @@ from .pipeline_bddm import BDDM
from
.pipeline_ddim
import
DDIM
from
.pipeline_ddim
import
DDIM
from
.pipeline_ddpm
import
DDPM
from
.pipeline_ddpm
import
DDPM
from
.pipeline_grad_tts
import
GradTTS
from
.pipeline_grad_tts
import
GradTTS
from
.pipeline_glide
import
GLIDE
try
:
from
.pipeline_glide
import
GLIDE
except
(
NameError
,
ImportError
):
class
GLIDE
:
pass
from
.pipeline_latent_diffusion
import
LatentDiffusion
from
.pipeline_latent_diffusion
import
LatentDiffusion
from
.pipeline_pndm
import
PNDM
from
.pipeline_pndm
import
PNDM
src/diffusers/utils/__init__.py
View file @
b4e6a740
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
import
os
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -21,6 +12,10 @@ import os
...
@@ -21,6 +12,10 @@ import os
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
import
importlib
import
importlib_metadata
import
os
from
.logging
import
logger
hf_cache_home
=
os
.
path
.
expanduser
(
hf_cache_home
=
os
.
path
.
expanduser
(
...
@@ -36,6 +31,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
...
@@ -36,6 +31,18 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE
=
os
.
getenv
(
"HF_MODULES_CACHE"
,
os
.
path
.
join
(
hf_cache_home
,
"modules"
))
HF_MODULES_CACHE
=
os
.
getenv
(
"HF_MODULES_CACHE"
,
os
.
path
.
join
(
hf_cache_home
,
"modules"
))
_transformers_available
=
importlib
.
util
.
find_spec
(
"transformers"
)
is
not
None
try
:
_transformers_version
=
importlib_metadata
.
version
(
"transformers"
)
logger
.
debug
(
f
"Successfully imported transformers version
{
_transformers_version
}
"
)
except
importlib_metadata
.
PackageNotFoundError
:
_transformers_available
=
False
def
is_transformers_available
():
return
_transformers_available
class
RepositoryNotFoundError
(
HTTPError
):
class
RepositoryNotFoundError
(
HTTPError
):
"""
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
...
...
src/diffusers/utils/dummy_transformers_objects.py
0 → 100644
View file @
b4e6a740
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from
..utils
import
DummyObject
,
requires_backends
class
GLIDESuperResUNetModel
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
class
GLIDETextToImageUNetModel
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
class
GLIDEUNetModel
(
metaclass
=
DummyObject
):
_backends
=
[
"transformers"
]
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"transformers"
])
utils/check_dummies.py
View file @
b4e6a740
...
@@ -20,10 +20,10 @@ import re
...
@@ -20,10 +20,10 @@ import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
# python utils/check_dummies.py
PATH_TO_
TRANSFORM
ERS
=
"src/
transform
ers"
PATH_TO_
DIFFUS
ERS
=
"src/
diffus
ers"
# Matches is_xxx_available()
# Matches is_xxx_available()
_re_backend
=
re
.
compile
(
r
"is\_([a-z_]*)_available
(
)"
)
_re_backend
=
re
.
compile
(
r
"
if
is\_([a-z_]*)_available
\(\
)"
)
# Matches from xxx import bla
# Matches from xxx import bla
_re_single_line_import
=
re
.
compile
(
r
"\s+from\s+\S*\s+import\s+([^\(\s].*)\n"
)
_re_single_line_import
=
re
.
compile
(
r
"\s+from\s+\S*\s+import\s+([^\(\s].*)\n"
)
_re_test_backend
=
re
.
compile
(
r
"^\s+if\s+not\s+is\_[a-z]*\_available\(\)"
)
_re_test_backend
=
re
.
compile
(
r
"^\s+if\s+not\s+is\_[a-z]*\_available\(\)"
)
...
@@ -50,36 +50,30 @@ def {0}(*args, **kwargs):
...
@@ -50,36 +50,30 @@ def {0}(*args, **kwargs):
def
find_backend
(
line
):
def
find_backend
(
line
):
"""Find one (or multiple) backend in a code line of the init."""
"""Find one (or multiple) backend in a code line of the init."""
if
_re_test_backend
.
search
(
line
)
is
None
:
backends
=
_re_backend
.
findall
(
line
)
if
len
(
backends
)
==
0
:
return
None
return
None
backends
=
[
b
[
0
]
for
b
in
_re_backend
.
findall
(
line
)]
backends
.
sort
()
return
backends
[
0
]
return
"_and_"
.
join
(
backends
)
def
read_init
():
def
read_init
():
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
with
open
(
os
.
path
.
join
(
PATH_TO_
TRANSFORM
ERS
,
"__init__.py"
),
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
os
.
path
.
join
(
PATH_TO_
DIFFUS
ERS
,
"__init__.py"
),
"r"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
# Get to the point we do the actual imports for type checking
# Get to the point we do the actual imports for type checking
line_index
=
0
line_index
=
0
while
not
lines
[
line_index
].
startswith
(
"if TYPE_CHECKING"
):
line_index
+=
1
backend_specific_objects
=
{}
backend_specific_objects
=
{}
# Go through the end of the file
# Go through the end of the file
while
line_index
<
len
(
lines
):
while
line_index
<
len
(
lines
):
# If the line is an if is_backend_available, we grab all objects associated.
# If the line is an if is_backend_available, we grab all objects associated.
backend
=
find_backend
(
lines
[
line_index
])
backend
=
find_backend
(
lines
[
line_index
])
if
backend
is
not
None
:
if
backend
is
not
None
:
while
not
lines
[
line_index
].
startswith
(
" else:"
):
line_index
+=
1
line_index
+=
1
objects
=
[]
objects
=
[]
line_index
+=
1
# Until we unindent, add backend objects to the list
# Until we unindent, add backend objects to the list
while
len
(
lines
[
line_index
])
<=
1
or
lines
[
line_index
].
startswith
(
"
"
*
8
):
while
not
lines
[
line_index
].
startswith
(
"
else:"
):
line
=
lines
[
line_index
]
line
=
lines
[
line_index
]
single_line_import_search
=
_re_single_line_import
.
search
(
line
)
single_line_import_search
=
_re_single_line_import
.
search
(
line
)
if
single_line_import_search
is
not
None
:
if
single_line_import_search
is
not
None
:
...
@@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
...
@@ -129,7 +123,7 @@ def check_dummies(overwrite=False):
short_names
=
{
"torch"
:
"pt"
}
short_names
=
{
"torch"
:
"pt"
}
# Locate actual dummy modules and read their content.
# Locate actual dummy modules and read their content.
path
=
os
.
path
.
join
(
PATH_TO_
TRANSFORM
ERS
,
"utils"
)
path
=
os
.
path
.
join
(
PATH_TO_
DIFFUS
ERS
,
"utils"
)
dummy_file_paths
=
{
dummy_file_paths
=
{
backend
:
os
.
path
.
join
(
path
,
f
"dummy_
{
short_names
.
get
(
backend
,
backend
)
}
_objects.py"
)
backend
:
os
.
path
.
join
(
path
,
f
"dummy_
{
short_names
.
get
(
backend
,
backend
)
}
_objects.py"
)
for
backend
in
dummy_files
.
keys
()
for
backend
in
dummy_files
.
keys
()
...
@@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
...
@@ -147,7 +141,7 @@ def check_dummies(overwrite=False):
if
dummy_files
[
backend
]
!=
actual_dummies
[
backend
]:
if
dummy_files
[
backend
]
!=
actual_dummies
[
backend
]:
if
overwrite
:
if
overwrite
:
print
(
print
(
f
"Updating
transform
ers.utils.dummy_
{
short_names
.
get
(
backend
,
backend
)
}
_objects.py as the main "
f
"Updating
diffus
ers.utils.dummy_
{
short_names
.
get
(
backend
,
backend
)
}
_objects.py as the main "
"__init__ has new objects."
"__init__ has new objects."
)
)
with
open
(
dummy_file_paths
[
backend
],
"w"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
with
open
(
dummy_file_paths
[
backend
],
"w"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
...
@@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
...
@@ -155,7 +149,7 @@ def check_dummies(overwrite=False):
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"The main __init__ has objects that are not present in "
"The main __init__ has objects that are not present in "
f
"
transform
ers.utils.dummy_
{
short_names
.
get
(
backend
,
backend
)
}
_objects.py. Run `make fix-copies` "
f
"
diffus
ers.utils.dummy_
{
short_names
.
get
(
backend
,
backend
)
}
_objects.py. Run `make fix-copies` "
"to fix this."
"to fix this."
)
)
...
...
utils/check_table.py
View file @
b4e6a740
...
@@ -22,7 +22,7 @@ import re
...
@@ -22,7 +22,7 @@ import re
# All paths are set with the intent you should run this script from the root of the repo with the command
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_table.py
# python utils/check_table.py
TRANSFORMERS_PATH
=
"src/
transform
ers"
TRANSFORMERS_PATH
=
"src/
diffus
ers"
PATH_TO_DOCS
=
"docs/source/en"
PATH_TO_DOCS
=
"docs/source/en"
REPO_PATH
=
"."
REPO_PATH
=
"."
...
@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
...
@@ -62,13 +62,13 @@ _re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGe
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
_re_pt_models
=
re
.
compile
(
r
"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)"
)
# This is to make sure the
transform
ers module imported is the one in the repo.
# This is to make sure the
diffus
ers module imported is the one in the repo.
spec
=
importlib
.
util
.
spec_from_file_location
(
spec
=
importlib
.
util
.
spec_from_file_location
(
"
transform
ers"
,
"
diffus
ers"
,
os
.
path
.
join
(
TRANSFORMERS_PATH
,
"__init__.py"
),
os
.
path
.
join
(
TRANSFORMERS_PATH
,
"__init__.py"
),
submodule_search_locations
=
[
TRANSFORMERS_PATH
],
submodule_search_locations
=
[
TRANSFORMERS_PATH
],
)
)
transform
ers_module
=
spec
.
loader
.
load_module
()
diffus
ers_module
=
spec
.
loader
.
load_module
()
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
# Thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
...
@@ -88,10 +88,10 @@ def _center_text(text, width):
...
@@ -88,10 +88,10 @@ def _center_text(text, width):
def
get_model_table_from_auto_modules
():
def
get_model_table_from_auto_modules
():
"""Generates an up-to-date model table from the content of the auto modules."""
"""Generates an up-to-date model table from the content of the auto modules."""
# Dictionary model names to config.
# Dictionary model names to config.
config_maping_names
=
transform
ers_module
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING_NAMES
config_maping_names
=
diffus
ers_module
.
models
.
auto
.
configuration_auto
.
CONFIG_MAPPING_NAMES
model_name_to_config
=
{
model_name_to_config
=
{
name
:
config_maping_names
[
code
]
name
:
config_maping_names
[
code
]
for
code
,
name
in
transform
ers_module
.
MODEL_NAMES_MAPPING
.
items
()
for
code
,
name
in
diffus
ers_module
.
MODEL_NAMES_MAPPING
.
items
()
if
code
in
config_maping_names
if
code
in
config_maping_names
}
}
model_name_to_prefix
=
{
name
:
config
.
replace
(
"ConfigMixin"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()}
model_name_to_prefix
=
{
name
:
config
.
replace
(
"ConfigMixin"
,
""
)
for
name
,
config
in
model_name_to_config
.
items
()}
...
@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
...
@@ -103,8 +103,8 @@ def get_model_table_from_auto_modules():
tf_models
=
collections
.
defaultdict
(
bool
)
tf_models
=
collections
.
defaultdict
(
bool
)
flax_models
=
collections
.
defaultdict
(
bool
)
flax_models
=
collections
.
defaultdict
(
bool
)
# Let's lookup through all
transform
ers object (once).
# Let's lookup through all
diffus
ers object (once).
for
attr_name
in
dir
(
transform
ers_module
):
for
attr_name
in
dir
(
diffus
ers_module
):
lookup_dict
=
None
lookup_dict
=
None
if
attr_name
.
endswith
(
"Tokenizer"
):
if
attr_name
.
endswith
(
"Tokenizer"
):
lookup_dict
=
slow_tokenizers
lookup_dict
=
slow_tokenizers
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment