Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4f761e95
"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "0b208debff6a06305afd40cda3fe4d04a2f9eebb"
Commit
4f761e95
authored
Jun 09, 2022
by
patil-suraj
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
6b66999e
b02d0d6b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
630 additions
and
223 deletions
+630
-223
setup.py
setup.py
+19
-32
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+92
-98
src/diffusers/dynamic_modules_utils.py
src/diffusers/dynamic_modules_utils.py
+4
-12
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+55
-75
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+27
-6
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+49
-0
src/diffusers/utils/logging.py
src/diffusers/utils/logging.py
+344
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+40
-0
No files found.
setup.py
View file @
4f761e95
# Copyright 202
1
The HuggingFace Team. All rights reserved.
# Copyright 202
2
The HuggingFace Team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -52,11 +52,11 @@ To create the package for pypi.
...
@@ -52,11 +52,11 @@ To create the package for pypi.
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
Check that you can install it in a virtualenv by running:
Check that you can install it in a virtualenv by running:
pip install -i https://testpypi.python.org/pypi
transform
ers
pip install -i https://testpypi.python.org/pypi
diffus
ers
Check you can run the following commands:
Check you can run the following commands:
python -c "from
transform
ers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from
diffus
ers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from
transform
ers import *"
python -c "from
diffus
ers import *"
9. Upload the final version to actual pypi:
9. Upload the final version to actual pypi:
twine upload dist/* -r pypi
twine upload dist/* -r pypi
...
@@ -77,36 +77,21 @@ from setuptools import find_packages, setup
...
@@ -77,36 +77,21 @@ from setuptools import find_packages, setup
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps
=
[
_deps
=
[
"Pillow"
,
"Pillow"
,
"accelerate>=0.9.0"
,
"black~=22.0,>=22.3"
,
"black~=22.0,>=22.3"
,
"codecarbon==1.2.0"
,
"filelock"
,
"dataclasses"
,
"flake8>=3.8.3"
,
"datasets"
,
"huggingface-hub"
,
"GitPython<3.1.19"
,
"hf-doc-builder>=0.3.0"
,
"huggingface-hub>=0.1.0,<1.0"
,
"importlib_metadata"
,
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"numpy
>=1.17
"
,
"numpy"
,
"pytest"
,
"pytest"
,
"pytest-timeout"
,
"pytest-xdist"
,
"python>=3.7.0"
,
"regex!=2019.12.17"
,
"requests"
,
"requests"
,
"sagemaker>=2.31.0"
,
"tokenizers>=0.11.1,!=0.11.3,<0.13"
,
"torch>=1.4"
,
"torch>=1.4"
,
"torchaudio"
,
"torchvision"
,
"tqdm>=4.27"
,
"unidic>=1.0.2"
,
"unidic_lite>=1.0.7"
,
"uvicorn"
,
]
]
# this is a lookup table with items like:
# this is a lookup table with items like:
#
#
# tokenizers: "
tokenizers
==0.
9.4
"
# tokenizers: "
huggingface-hub
==0.
8.0
"
# packaging: "packaging"
# packaging: "packaging"
#
#
# some of the values are versioned whereas others aren't.
# some of the values are versioned whereas others aren't.
...
@@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
...
@@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
extras
[
"docs"
]
=
[]
extras
[
"docs"
]
=
[]
extras
[
"test"
]
=
[
extras
[
"test"
]
=
[
"pytest"
,
"pytest"
,
"pytest-xdist"
,
"pytest-subtests"
,
"datasets"
,
"transformers"
,
]
]
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
extras
[
"sagemaker"
]
=
[
install_requires
=
[
"sagemaker"
,
# boto3 is a required package in sagemaker
deps
[
"filelock"
],
deps
[
"huggingface-hub"
],
deps
[
"numpy"
],
deps
[
"requests"
],
deps
[
"torch"
],
deps
[
"torchvision"
],
deps
[
"Pillow"
],
]
]
setup
(
setup
(
...
@@ -201,7 +188,7 @@ setup(
...
@@ -201,7 +188,7 @@ setup(
package_dir
=
{
""
:
"src"
},
package_dir
=
{
""
:
"src"
},
packages
=
find_packages
(
"src"
),
packages
=
find_packages
(
"src"
),
python_requires
=
">=3.6.0"
,
python_requires
=
">=3.6.0"
,
install_requires
=
[
"numpy>=1.17"
,
"packaging>=20.0"
,
"pyyaml"
,
"torch>=1.4.0"
]
,
install_requires
=
install_requires
,
extras_require
=
extras
,
extras_require
=
extras
,
classifiers
=
[
classifiers
=
[
"Development Status :: 5 - Production/Stable"
,
"Development Status :: 5 - Production/Stable"
,
...
...
src/diffusers/configuration_utils.py
View file @
4f761e95
...
@@ -24,18 +24,19 @@ import re
...
@@ -24,18 +24,19 @@ import re
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
requests
import
HTTPError
from
requests
import
HTTPError
from
transformers.utils
import
(
from
huggingface_hub
import
hf_hub_download
from
.utils
import
(
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
DIFFUSERS_CACHE
,
EntryNotFoundError
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
RevisionNotFoundError
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
is_remote_url
,
logging
,
logging
,
)
)
from
.
import
__version__
from
.
import
__version__
...
@@ -56,6 +57,8 @@ class ConfigMixin:
...
@@ -56,6 +57,8 @@ class ConfigMixin:
if
self
.
config_name
is
None
:
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_diffusers_version"
]
=
__version__
for
key
,
value
in
kwargs
.
items
():
for
key
,
value
in
kwargs
.
items
():
try
:
try
:
setattr
(
self
,
key
,
value
)
setattr
(
self
,
key
,
value
)
...
@@ -90,11 +93,26 @@ class ConfigMixin:
...
@@ -90,11 +93,26 @@ class ConfigMixin:
self
.
to_json_file
(
output_config_file
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(
**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
@
classmethod
def
get_config_dict
(
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
)
->
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]:
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
...
@@ -104,85 +122,83 @@ class ConfigMixin:
...
@@ -104,85 +122,83 @@ class ConfigMixin:
user_agent
=
{
"file_type"
:
"config"
}
user_agent
=
{
"file_type"
:
"config"
}
if
is_offline_mode
()
and
not
local_files_only
:
logger
.
info
(
"Offline mode: forcing local_files_only=True"
)
local_files_only
=
True
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
else
:
configuration_file
=
cls
.
config_name
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
cls
.
config_name
is
None
:
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
configuration_file
)
raise
ValueError
(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)):
# Load from a PyTorch checkpoint
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
cls
.
config_name
)
else
:
else
:
config_file
=
hf_bucket_url
(
raise
EnvironmentError
(
pretrained_model_name_or_path
,
filename
=
configuration_file
,
revision
=
revision
,
mirror
=
None
f
"Error no file named
{
cls
.
config_name
}
found in directory
{
pretrained_model_name_or_path
}
."
)
else
:
try
:
# Load from URL or cache if already cached
config_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
cls
.
config_name
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
proxies
=
proxies
,
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
)
try
:
except
RepositoryNotFoundError
:
# Load from URL or cache if already cached
raise
EnvironmentError
(
resolved_config_file
=
cached_path
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed on "
config_file
,
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token having "
cache_dir
=
cache_dir
,
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
force_download
=
force_download
,
"`use_auth_token=True`."
proxies
=
proxies
,
)
resume_download
=
resume_download
,
except
RevisionNotFoundError
:
local_files_only
=
local_files_only
,
raise
EnvironmentError
(
use_auth_token
=
use_auth_token
,
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for this "
user_agent
=
user_agent
,
f
"model name. Check the model page at 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for "
)
"available revisions."
)
except
RepositoryNotFoundError
:
except
EntryNotFoundError
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier listed on "
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
cls
.
config_name
}
."
"'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a token having "
)
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
except
HTTPError
as
err
:
"`use_auth_token=True`."
raise
EnvironmentError
(
)
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
except
RevisionNotFoundError
:
)
raise
EnvironmentError
(
except
ValueError
:
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for this "
raise
EnvironmentError
(
f
"model name. Check the model page at 'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for "
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
"available revisions."
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
)
f
" containing a
{
cls
.
config_name
}
file.
\n
Checkout your internet connection or see how to run the"
except
EntryNotFoundError
:
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
raise
EnvironmentError
(
)
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
configuration_file
}
."
except
EnvironmentError
:
)
raise
EnvironmentError
(
except
HTTPError
as
err
:
f
"Can't load config for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
raise
EnvironmentError
(
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
)
f
"containing a
{
cls
.
config_name
}
file"
except
ValueError
:
)
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it in"
f
" the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a directory"
f
" containing a
{
configuration_file
}
file.
\n
Checkout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load config for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a
{
configuration_file
}
file"
)
try
:
try
:
# Load config dict
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
resolved_
config_file
)
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"It looks like the config file at '
{
resolved_
config_file
}
' is not a valid JSON file."
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
)
)
if
resolved_config_file
==
config_file
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
"
)
else
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
from cache at
{
resolved_config_file
}
"
)
return
config_dict
return
config_dict
@
classmethod
@
classmethod
...
@@ -208,19 +224,6 @@ class ConfigMixin:
...
@@ -208,19 +224,6 @@ class ConfigMixin:
return
init_dict
,
unused_kwargs
return
init_dict
,
unused_kwargs
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(
**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
...
@@ -233,18 +236,9 @@ class ConfigMixin:
...
@@ -233,18 +236,9 @@ class ConfigMixin:
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
@
property
"""
def
config
(
self
)
->
Dict
[
str
,
Any
]:
Serializes this instance to a Python dictionary.
output
=
copy
.
deepcopy
(
self
.
_dict_to_save
)
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output
=
copy
.
deepcopy
(
self
.
__dict__
)
# Diffusion version when serializing the model
output
[
"diffusers_version"
]
=
__version__
return
output
return
output
def
to_json_string
(
self
)
->
str
:
def
to_json_string
(
self
)
->
str
:
...
...
src/diffusers/dynamic_modules_utils.py
View file @
4f761e95
...
@@ -22,16 +22,8 @@ import sys
...
@@ -22,16 +22,8 @@ import sys
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
from
huggingface_hub
import
HfFolder
,
model_info
from
huggingface_hub
import
cached_download
from
.utils
import
HF_MODULES_CACHE
,
DIFFUSERS_DYNAMIC_MODULE_NAME
,
logging
from
transformers.utils
import
(
HF_MODULES_CACHE
,
TRANSFORMERS_DYNAMIC_MODULE_NAME
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
logging
,
)
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
get_logger
(
__name__
)
# pylint: disable=invalid-name
...
@@ -219,7 +211,7 @@ def get_cached_module_file(
...
@@ -219,7 +211,7 @@ def get_cached_module_file(
try
:
try
:
# Load from URL or cache if already cached
# Load from URL or cache if already cached
resolved_module_file
=
cached_
path
(
resolved_module_file
=
cached_
download
(
module_file_or_url
,
module_file_or_url
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
force_download
=
force_download
,
...
@@ -237,7 +229,7 @@ def get_cached_module_file(
...
@@ -237,7 +229,7 @@ def get_cached_module_file(
modules_needed
=
check_imports
(
resolved_module_file
)
modules_needed
=
check_imports
(
resolved_module_file
)
# Now we move the module inside our cached dynamic modules.
# Now we move the module inside our cached dynamic modules.
full_submodule
=
TRANSFORM
ERS_DYNAMIC_MODULE_NAME
+
os
.
path
.
sep
+
submodule
full_submodule
=
DIFFUS
ERS_DYNAMIC_MODULE_NAME
+
os
.
path
.
sep
+
submodule
create_dynamic_module
(
full_submodule
)
create_dynamic_module
(
full_submodule
)
submodule_path
=
Path
(
HF_MODULES_CACHE
)
/
full_submodule
submodule_path
=
Path
(
HF_MODULES_CACHE
)
/
full_submodule
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
# We always copy local files (we could hash the file to see if there was a change, and give them the name of
...
...
src/diffusers/modeling_utils.py
View file @
4f761e95
...
@@ -21,18 +21,15 @@ import torch
...
@@ -21,18 +21,15 @@ import torch
from
torch
import
Tensor
,
device
from
torch
import
Tensor
,
device
from
requests
import
HTTPError
from
requests
import
HTTPError
from
huggingface_hub
import
hf_hub_download
# CHANGE to diffusers.utils
from
.utils
import
(
from
transformers.utils
import
(
CONFIG_NAME
,
CONFIG_NAME
,
DIFFUSERS_CACHE
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
RevisionNotFoundError
,
cached_path
,
hf_bucket_url
,
is_offline_mode
,
is_remote_url
,
logging
,
logging
,
)
)
...
@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -314,7 +311,7 @@ class ModelMixin(torch.nn.Module):
</Tip>
</Tip>
"""
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
None
)
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
ignore_mismatched_sizes
=
kwargs
.
pop
(
"ignore_mismatched_sizes"
,
False
)
ignore_mismatched_sizes
=
kwargs
.
pop
(
"ignore_mismatched_sizes"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
...
@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
...
@@ -323,15 +320,10 @@ class ModelMixin(torch.nn.Module):
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
mirror
=
kwargs
.
pop
(
"mirror"
,
None
)
from_auto_class
=
kwargs
.
pop
(
"_from_auto"
,
False
)
from_auto_class
=
kwargs
.
pop
(
"_from_auto"
,
False
)
user_agent
=
{
"file_type"
:
"model"
,
"framework"
:
"pytorch"
,
"from_auto_class"
:
from_auto_class
}
user_agent
=
{
"file_type"
:
"model"
,
"framework"
:
"pytorch"
,
"from_auto_class"
:
from_auto_class
}
if
is_offline_mode
()
and
not
local_files_only
:
logger
.
info
(
"Offline mode: forcing local_files_only=True"
)
local_files_only
=
True
# Load config if we don't provide a configuration
# Load config if we don't provide a configuration
config_path
=
pretrained_model_name_or_path
config_path
=
pretrained_model_name_or_path
model
,
unused_kwargs
=
cls
.
from_config
(
model
,
unused_kwargs
=
cls
.
from_config
(
...
@@ -353,79 +345,67 @@ class ModelMixin(torch.nn.Module):
...
@@ -353,79 +345,67 @@ class ModelMixin(torch.nn.Module):
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
if
os
.
path
.
isfile
(
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)):
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
archive
_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
model
_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
else
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_model_name_or_path
}
."
f
"Error no file named
{
WEIGHTS_NAME
}
found in directory
{
pretrained_model_name_or_path
}
."
)
)
elif
os
.
path
.
isfile
(
pretrained_model_name_or_path
)
or
is_remote_url
(
pretrained_model_name_or_path
):
archive_file
=
pretrained_model_name_or_path
else
:
else
:
filename
=
WEIGHTS_NAME
try
:
# Load from URL or cache if already cached
archive_file
=
hf_bucket_url
(
model_file
=
hf_hub_download
(
pretrained_model_name_or_path
,
filename
=
filename
,
revision
=
revision
,
mirror
=
mirror
pretrained_model_name_or_path
,
)
filename
=
WEIGHTS_NAME
,
cache_dir
=
cache_dir
,
try
:
force_download
=
force_download
,
# Load from URL or cache if already cached
proxies
=
proxies
,
resolved_archive_file
=
cached_path
(
resume_download
=
resume_download
,
archive_file
,
local_files_only
=
local_files_only
,
cache_dir
=
cache_dir
,
use_auth_token
=
use_auth_token
,
force_download
=
force_download
,
user_agent
=
user_agent
,
proxies
=
proxies
,
)
resume_download
=
resume_download
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
user_agent
=
user_agent
,
)
except
RepositoryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
filename
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" directory containing a file named
{
WEIGHTS_NAME
}
or"
"
\n
Checkout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load the model for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a file named
{
WEIGHTS_NAME
}
"
)
if
resolved_archive_file
==
archive_file
:
except
RepositoryNotFoundError
:
logger
.
info
(
f
"loading weights file
{
archive_file
}
"
)
raise
EnvironmentError
(
else
:
f
"
{
pretrained_model_name_or_path
}
is not a local folder and is not a valid model identifier "
logger
.
info
(
f
"loading weights file
{
archive_file
}
from cache at
{
resolved_archive_file
}
"
)
"listed on 'https://huggingface.co/models'
\n
If this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except
RevisionNotFoundError
:
raise
EnvironmentError
(
f
"
{
revision
}
is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f
"'https://huggingface.co/
{
pretrained_model_name_or_path
}
' for available revisions."
)
except
EntryNotFoundError
:
raise
EnvironmentError
(
f
"
{
pretrained_model_name_or_path
}
does not appear to have a file named
{
model_file
}
."
)
except
HTTPError
as
err
:
raise
EnvironmentError
(
f
"There was a specific connection error when trying to load
{
pretrained_model_name_or_path
}
:
\n
{
err
}
"
)
except
ValueError
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load this model, couldn't find it"
f
" in the cached files and it looks like
{
pretrained_model_name_or_path
}
is not the path to a"
f
" directory containing a file named
{
WEIGHTS_NAME
}
or"
"
\n
Checkout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except
EnvironmentError
:
raise
EnvironmentError
(
f
"Can't load the model for '
{
pretrained_model_name_or_path
}
'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f
"Otherwise, make sure '
{
pretrained_model_name_or_path
}
' is the correct path to a directory "
f
"containing a file named
{
WEIGHTS_NAME
}
"
)
# restore default dtype
# restore default dtype
state_dict
=
load_state_dict
(
resolved_archive
_file
)
state_dict
=
load_state_dict
(
model
_file
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
model
,
state_dict
,
state_dict
,
resolved_archive
_file
,
model
_file
,
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
)
)
...
...
src/diffusers/pipeline_utils.py
View file @
4f761e95
...
@@ -20,8 +20,7 @@ from typing import Optional, Union
...
@@ -20,8 +20,7 @@ from typing import Optional, Union
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
# CHANGE to diffusers.utils
from
.utils
import
logging
,
DIFFUSERS_CACHE
from
transformers.utils
import
logging
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.dynamic_modules_utils
import
get_class_from_dynamic_module
from
.dynamic_modules_utils
import
get_class_from_dynamic_module
...
@@ -80,11 +79,12 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -80,11 +79,12 @@ class DiffusionPipeline(ConfigMixin):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
self
.
save_config
(
save_directory
)
self
.
save_config
(
save_directory
)
model_index_dict
=
self
.
_dict_to_save
model_index_dict
=
self
.
config
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_module"
)
model_index_dict
.
pop
(
"_module"
)
for
name
,
(
library_name
,
class_name
)
in
self
.
_dict_to_save
.
items
():
for
name
,
(
library_name
,
class_name
)
in
model_index_dict
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
# TODO: Suraj
# TODO: Suraj
...
@@ -105,14 +105,36 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -105,14 +105,36 @@ class DiffusionPipeline(ConfigMixin):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
r
"""
Add docstrings
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
output_loading_info
=
kwargs
.
pop
(
"output_loading_info"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
,
output_loading_info
=
output_loading_info
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
)
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
class_name_
=
config_dict
[
"_class_name"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
...
@@ -130,7 +152,6 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -130,7 +152,6 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
init_kwargs
=
{}
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
# if the model is not in diffusers or transformers, we need to load it from the hub
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
# assumes that it's a subclass of ModelMixin
if
library_name
==
module_candidate_name
:
if
library_name
==
module_candidate_name
:
...
...
src/diffusers/utils/__init__.py
0 → 100644
View file @
4f761e95
#!/usr/bin/env python
# coding=utf-8
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
requests.exceptions
import
HTTPError
import
os
hf_cache_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
"HF_HOME"
,
os
.
path
.
join
(
os
.
getenv
(
"XDG_CACHE_HOME"
,
"~/.cache"
),
"huggingface"
))
)
default_cache_path
=
os
.
path
.
join
(
hf_cache_home
,
"diffusers"
)
CONFIG_NAME
=
"config.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT
=
"https://huggingface.co"
DIFFUSERS_CACHE
=
default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME
=
"diffusers_modules"
HF_MODULES_CACHE
=
os
.
getenv
(
"HF_MODULES_CACHE"
,
os
.
path
.
join
(
hf_cache_home
,
"modules"
))
class
RepositoryNotFoundError
(
HTTPError
):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class
EntryNotFoundError
(
HTTPError
):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class
RevisionNotFoundError
(
HTTPError
):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
src/diffusers/utils/logging.py
0 → 100644
View file @
4f761e95
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Logging utilities."""
import
logging
import
os
import
sys
import
threading
from
logging
import
CRITICAL
# NOQA
from
logging
import
DEBUG
# NOQA
from
logging
import
ERROR
# NOQA
from
logging
import
FATAL
# NOQA
from
logging
import
INFO
# NOQA
from
logging
import
NOTSET
# NOQA
from
logging
import
WARN
# NOQA
from
logging
import
WARNING
# NOQA
from
typing
import
Optional
from
tqdm
import
auto
as
tqdm_lib
_lock
=
threading
.
Lock
()
_default_handler
:
Optional
[
logging
.
Handler
]
=
None
log_levels
=
{
"debug"
:
logging
.
DEBUG
,
"info"
:
logging
.
INFO
,
"warning"
:
logging
.
WARNING
,
"error"
:
logging
.
ERROR
,
"critical"
:
logging
.
CRITICAL
,
}
_default_log_level
=
logging
.
WARNING
_tqdm_active
=
True
def
_get_default_logging_level
():
"""
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
"""
env_level_str
=
os
.
getenv
(
"TRANSFORMERS_VERBOSITY"
,
None
)
if
env_level_str
:
if
env_level_str
in
log_levels
:
return
log_levels
[
env_level_str
]
else
:
logging
.
getLogger
().
warning
(
f
"Unknown option TRANSFORMERS_VERBOSITY=
{
env_level_str
}
, "
f
"has to be one of:
{
', '
.
join
(
log_levels
.
keys
())
}
"
)
return
_default_log_level
def
_get_library_name
()
->
str
:
return
__name__
.
split
(
"."
)[
0
]
def
_get_library_root_logger
()
->
logging
.
Logger
:
return
logging
.
getLogger
(
_get_library_name
())
def
_configure_library_root_logger
()
->
None
:
global
_default_handler
with
_lock
:
if
_default_handler
:
# This library has already configured the library root logger.
return
_default_handler
=
logging
.
StreamHandler
()
# Set sys.stderr as stream.
_default_handler
.
flush
=
sys
.
stderr
.
flush
# Apply our default configuration to the library root logger.
library_root_logger
=
_get_library_root_logger
()
library_root_logger
.
addHandler
(
_default_handler
)
library_root_logger
.
setLevel
(
_get_default_logging_level
())
library_root_logger
.
propagate
=
False
def
_reset_library_root_logger
()
->
None
:
global
_default_handler
with
_lock
:
if
not
_default_handler
:
return
library_root_logger
=
_get_library_root_logger
()
library_root_logger
.
removeHandler
(
_default_handler
)
library_root_logger
.
setLevel
(
logging
.
NOTSET
)
_default_handler
=
None
def
get_log_levels_dict
():
return
log_levels
def
get_logger
(
name
:
Optional
[
str
]
=
None
)
->
logging
.
Logger
:
"""
Return a logger with the specified name.
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
"""
if
name
is
None
:
name
=
_get_library_name
()
_configure_library_root_logger
()
return
logging
.
getLogger
(
name
)
def
get_verbosity
()
->
int
:
"""
Return the current level for the 🤗 Transformers's root logger as an int.
Returns:
`int`: The logging level.
<Tip>
🤗 Transformers has following logging levels:
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 40: `diffusers.logging.ERROR`
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- 20: `diffusers.logging.INFO`
- 10: `diffusers.logging.DEBUG`
</Tip>"""
_configure_library_root_logger
()
return
_get_library_root_logger
().
getEffectiveLevel
()
def
set_verbosity
(
verbosity
:
int
)
->
None
:
"""
Set the verbosity level for the 🤗 Transformers's root logger.
Args:
verbosity (`int`):
Logging level, e.g., one of:
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- `diffusers.logging.ERROR`
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
- `diffusers.logging.INFO`
- `diffusers.logging.DEBUG`
"""
_configure_library_root_logger
()
_get_library_root_logger
().
setLevel
(
verbosity
)
def
set_verbosity_info
():
"""Set the verbosity to the `INFO` level."""
return
set_verbosity
(
INFO
)
def
set_verbosity_warning
():
"""Set the verbosity to the `WARNING` level."""
return
set_verbosity
(
WARNING
)
def
set_verbosity_debug
():
"""Set the verbosity to the `DEBUG` level."""
return
set_verbosity
(
DEBUG
)
def
set_verbosity_error
():
"""Set the verbosity to the `ERROR` level."""
return
set_verbosity
(
ERROR
)
def
disable_default_handler
()
->
None
:
"""Disable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
_default_handler
is
not
None
_get_library_root_logger
().
removeHandler
(
_default_handler
)
def
enable_default_handler
()
->
None
:
"""Enable the default handler of the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
_default_handler
is
not
None
_get_library_root_logger
().
addHandler
(
_default_handler
)
def
add_handler
(
handler
:
logging
.
Handler
)
->
None
:
"""adds a handler to the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
handler
is
not
None
_get_library_root_logger
().
addHandler
(
handler
)
def
remove_handler
(
handler
:
logging
.
Handler
)
->
None
:
"""removes given handler from the HuggingFace Transformers's root logger."""
_configure_library_root_logger
()
assert
handler
is
not
None
and
handler
not
in
_get_library_root_logger
().
handlers
_get_library_root_logger
().
removeHandler
(
handler
)
def
disable_propagation
()
->
None
:
"""
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
"""
_configure_library_root_logger
()
_get_library_root_logger
().
propagate
=
False
def
enable_propagation
()
->
None
:
"""
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to
prevent double logging if the root logger has been configured.
"""
_configure_library_root_logger
()
_get_library_root_logger
().
propagate
=
True
def
enable_explicit_format
()
->
None
:
"""
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows:
```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
```
All handlers currently bound to the root logger are affected by this method.
"""
handlers
=
_get_library_root_logger
().
handlers
for
handler
in
handlers
:
formatter
=
logging
.
Formatter
(
"[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
)
handler
.
setFormatter
(
formatter
)
def
reset_format
()
->
None
:
"""
Resets the formatting for HuggingFace Transformers's loggers.
All handlers currently bound to the root logger are affected by this method.
"""
handlers
=
_get_library_root_logger
().
handlers
for
handler
in
handlers
:
handler
.
setFormatter
(
None
)
def
warning_advice
(
self
,
*
args
,
**
kwargs
):
"""
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings
=
os
.
getenv
(
"TRANSFORMERS_NO_ADVISORY_WARNINGS"
,
False
)
if
no_advisory_warnings
:
return
self
.
warning
(
*
args
,
**
kwargs
)
logging
.
Logger
.
warning_advice
=
warning_advice
class
EmptyTqdm
:
"""Dummy tqdm which doesn't do anything."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
self
.
_iterator
=
args
[
0
]
if
args
else
None
def
__iter__
(
self
):
return
iter
(
self
.
_iterator
)
def
__getattr__
(
self
,
_
):
"""Return empty function."""
def
empty_fn
(
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
return
empty_fn
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
type_
,
value
,
traceback
):
return
class
_tqdm_cls
:
def
__call__
(
self
,
*
args
,
**
kwargs
):
if
_tqdm_active
:
return
tqdm_lib
.
tqdm
(
*
args
,
**
kwargs
)
else
:
return
EmptyTqdm
(
*
args
,
**
kwargs
)
def
set_lock
(
self
,
*
args
,
**
kwargs
):
self
.
_lock
=
None
if
_tqdm_active
:
return
tqdm_lib
.
tqdm
.
set_lock
(
*
args
,
**
kwargs
)
def
get_lock
(
self
):
if
_tqdm_active
:
return
tqdm_lib
.
tqdm
.
get_lock
()
tqdm
=
_tqdm_cls
()
def
is_progress_bar_enabled
()
->
bool
:
"""Return a boolean indicating whether tqdm progress bars are enabled."""
global
_tqdm_active
return
bool
(
_tqdm_active
)
def
enable_progress_bar
():
"""Enable tqdm progress bar."""
global
_tqdm_active
_tqdm_active
=
True
def
disable_progress_bar
():
"""Disable tqdm progress bar."""
global
_tqdm_active
_tqdm_active
=
False
tests/test_modeling_utils.py
View file @
4f761e95
...
@@ -24,6 +24,7 @@ import torch
...
@@ -24,6 +24,7 @@ import torch
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.configuration_utils
import
ConfigMixin
from
models.vision.ddpm.modeling_ddpm
import
DDPM
from
models.vision.ddpm.modeling_ddpm
import
DDPM
from
models.vision.ddim.modeling_ddim
import
DDIM
from
models.vision.ddim.modeling_ddim
import
DDIM
...
@@ -78,6 +79,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
...
@@ -78,6 +79,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
float
).
view
(
shape
).
contiguous
()
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
float
).
view
(
shape
).
contiguous
()
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_save_load
(
self
):
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
self
.
register
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
obj
=
SampleObject
()
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_config
=
new_obj
.
config
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
class
ModelTesterMixin
(
unittest
.
TestCase
):
class
ModelTesterMixin
(
unittest
.
TestCase
):
@
property
@
property
def
dummy_input
(
self
):
def
dummy_input
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment