Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
49257b4a
Commit
49257b4a
authored
Jun 09, 2022
by
Patrick von Platen
Browse files
finish transformers removal
parent
09e1b0b4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
104 additions
and
76 deletions
+104
-76
setup.py
setup.py
+19
-32
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+33
-34
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+8
-8
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+3
-2
src/diffusers/utils/__init__.py
src/diffusers/utils/__init__.py
+1
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+40
-0
No files found.
setup.py
View file @
49257b4a
# Copyright 202
1
The HuggingFace Team. All rights reserved.
# Copyright 202
2
The HuggingFace Team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -52,11 +52,11 @@ To create the package for pypi.
...
@@ -52,11 +52,11 @@ To create the package for pypi.
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
Check that you can install it in a virtualenv by running:
Check that you can install it in a virtualenv by running:
pip install -i https://testpypi.python.org/pypi
transform
ers
pip install -i https://testpypi.python.org/pypi
diffus
ers
Check you can run the following commands:
Check you can run the following commands:
python -c "from
transform
ers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from
diffus
ers import pipeline; classifier = pipeline('text-classification'); print(classifier('What a nice release'))"
python -c "from
transform
ers import *"
python -c "from
diffus
ers import *"
9. Upload the final version to actual pypi:
9. Upload the final version to actual pypi:
twine upload dist/* -r pypi
twine upload dist/* -r pypi
...
@@ -77,36 +77,21 @@ from setuptools import find_packages, setup
...
@@ -77,36 +77,21 @@ from setuptools import find_packages, setup
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps
=
[
_deps
=
[
"Pillow"
,
"Pillow"
,
"accelerate>=0.9.0"
,
"black~=22.0,>=22.3"
,
"black~=22.0,>=22.3"
,
"codecarbon==1.2.0"
,
"filelock"
,
"dataclasses"
,
"flake8>=3.8.3"
,
"datasets"
,
"huggingface-hub"
,
"GitPython<3.1.19"
,
"hf-doc-builder>=0.3.0"
,
"huggingface-hub>=0.1.0,<1.0"
,
"importlib_metadata"
,
"isort>=5.5.4"
,
"isort>=5.5.4"
,
"numpy
>=1.17
"
,
"numpy"
,
"pytest"
,
"pytest"
,
"pytest-timeout"
,
"pytest-xdist"
,
"python>=3.7.0"
,
"regex!=2019.12.17"
,
"requests"
,
"requests"
,
"sagemaker>=2.31.0"
,
"tokenizers>=0.11.1,!=0.11.3,<0.13"
,
"torch>=1.4"
,
"torch>=1.4"
,
"torchaudio"
,
"torchvision"
,
"tqdm>=4.27"
,
"unidic>=1.0.2"
,
"unidic_lite>=1.0.7"
,
"uvicorn"
,
]
]
# this is a lookup table with items like:
# this is a lookup table with items like:
#
#
# tokenizers: "
tokenizers
==0.
9.4
"
# tokenizers: "
huggingface-hub
==0.
8.0
"
# packaging: "packaging"
# packaging: "packaging"
#
#
# some of the values are versioned whereas others aren't.
# some of the values are versioned whereas others aren't.
...
@@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
...
@@ -176,15 +161,17 @@ extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"]
extras
[
"docs"
]
=
[]
extras
[
"docs"
]
=
[]
extras
[
"test"
]
=
[
extras
[
"test"
]
=
[
"pytest"
,
"pytest"
,
"pytest-xdist"
,
"pytest-subtests"
,
"datasets"
,
"transformers"
,
]
]
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
extras
[
"dev"
]
=
extras
[
"quality"
]
+
extras
[
"test"
]
extras
[
"sagemaker"
]
=
[
install_requires
=
[
"sagemaker"
,
# boto3 is a required package in sagemaker
deps
[
"filelock"
],
deps
[
"huggingface-hub"
],
deps
[
"numpy"
],
deps
[
"requests"
],
deps
[
"torch"
],
deps
[
"torchvision"
],
deps
[
"Pillow"
],
]
]
setup
(
setup
(
...
@@ -201,7 +188,7 @@ setup(
...
@@ -201,7 +188,7 @@ setup(
package_dir
=
{
""
:
"src"
},
package_dir
=
{
""
:
"src"
},
packages
=
find_packages
(
"src"
),
packages
=
find_packages
(
"src"
),
python_requires
=
">=3.6.0"
,
python_requires
=
">=3.6.0"
,
install_requires
=
[
"numpy>=1.17"
,
"packaging>=20.0"
,
"pyyaml"
,
"torch>=1.4.0"
]
,
install_requires
=
install_requires
,
extras_require
=
extras
,
extras_require
=
extras
,
classifiers
=
[
classifiers
=
[
"Development Status :: 5 - Production/Stable"
,
"Development Status :: 5 - Production/Stable"
,
...
...
src/diffusers/configuration_utils.py
View file @
49257b4a
...
@@ -57,6 +57,8 @@ class ConfigMixin:
...
@@ -57,6 +57,8 @@ class ConfigMixin:
if
self
.
config_name
is
None
:
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_class_name"
]
=
self
.
__class__
.
__name__
kwargs
[
"_diffusers_version"
]
=
__version__
for
key
,
value
in
kwargs
.
items
():
for
key
,
value
in
kwargs
.
items
():
try
:
try
:
setattr
(
self
,
key
,
value
)
setattr
(
self
,
key
,
value
)
...
@@ -91,6 +93,21 @@ class ConfigMixin:
...
@@ -91,6 +93,21 @@ class ConfigMixin:
self
.
to_json_file
(
output_config_file
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(
**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
@
classmethod
def
get_config_dict
(
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
...
@@ -107,6 +124,12 @@ class ConfigMixin:
...
@@ -107,6 +124,12 @@ class ConfigMixin:
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
if
cls
.
config_name
is
None
:
raise
ValueError
(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
if
os
.
path
.
isfile
(
pretrained_model_name_or_path
):
config_file
=
pretrained_model_name_or_path
config_file
=
pretrained_model_name_or_path
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
elif
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
...
@@ -168,13 +191,13 @@ class ConfigMixin:
...
@@ -168,13 +191,13 @@ class ConfigMixin:
f
"containing a
{
cls
.
config_name
}
file"
f
"containing a
{
cls
.
config_name
}
file"
)
)
try
:
try
:
# Load config dict
# Load config dict
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
config_dict
=
cls
.
_dict_from_json_file
(
config_file
)
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
except
(
json
.
JSONDecodeError
,
UnicodeDecodeError
):
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
f
"It looks like the config file at '
{
config_file
}
' is not a valid JSON file."
)
)
return
config_dict
return
config_dict
...
@@ -200,21 +223,6 @@ class ConfigMixin:
...
@@ -200,21 +223,6 @@ class ConfigMixin:
return
init_dict
,
unused_kwargs
return
init_dict
,
unused_kwargs
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(
**
init_dict
)
if
return_unused_kwargs
:
return
model
,
unused_kwargs
else
:
return
model
@
classmethod
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
with
open
(
json_file
,
"r"
,
encoding
=
"utf-8"
)
as
reader
:
...
@@ -227,18 +235,9 @@ class ConfigMixin:
...
@@ -227,18 +235,9 @@ class ConfigMixin:
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
@
property
"""
def
config
(
self
)
->
Dict
[
str
,
Any
]:
Serializes this instance to a Python dictionary.
output
=
copy
.
deepcopy
(
self
.
_dict_to_save
)
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output
=
copy
.
deepcopy
(
self
.
__dict__
)
# Diffusion version when serializing the model
output
[
"diffusers_version"
]
=
__version__
return
output
return
output
def
to_json_string
(
self
)
->
str
:
def
to_json_string
(
self
)
->
str
:
...
...
src/diffusers/modeling_utils.py
View file @
49257b4a
...
@@ -401,14 +401,14 @@ class ModelMixin(torch.nn.Module):
...
@@ -401,14 +401,14 @@ class ModelMixin(torch.nn.Module):
)
)
# restore default dtype
# restore default dtype
state_dict
=
load_state_dict
(
model_file
)
state_dict
=
load_state_dict
(
model_file
)
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
missing_keys
,
unexpected_keys
,
mismatched_keys
,
error_msgs
=
cls
.
_load_pretrained_model
(
model
,
model
,
state_dict
,
state_dict
,
model_file
,
model_file
,
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
ignore_mismatched_sizes
=
ignore_mismatched_sizes
,
)
)
# Set model in evaluation mode to deactivate DropOut modules by default
# Set model in evaluation mode to deactivate DropOut modules by default
model
.
eval
()
model
.
eval
()
...
...
src/diffusers/pipeline_utils.py
View file @
49257b4a
...
@@ -67,11 +67,12 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -67,11 +67,12 @@ class DiffusionPipeline(ConfigMixin):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
self
.
save_config
(
save_directory
)
self
.
save_config
(
save_directory
)
model_index_dict
=
self
.
_dict_to_save
model_index_dict
=
self
.
config
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_module"
)
model_index_dict
.
pop
(
"_module"
)
for
name
,
(
library_name
,
class_name
)
in
self
.
_dict_to_save
.
items
():
for
name
,
(
library_name
,
class_name
)
in
model_index_dict
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
# TODO: Suraj
# TODO: Suraj
...
...
src/diffusers/utils/__init__.py
View file @
49257b4a
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
import
os
hf_cache_home
=
os
.
path
.
expanduser
(
hf_cache_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
"HF_HOME"
,
os
.
path
.
join
(
os
.
getenv
(
"XDG_CACHE_HOME"
,
"~/.cache"
),
"huggingface"
))
os
.
getenv
(
"HF_HOME"
,
os
.
path
.
join
(
os
.
getenv
(
"XDG_CACHE_HOME"
,
"~/.cache"
),
"huggingface"
))
...
...
tests/test_modeling_utils.py
View file @
49257b4a
...
@@ -24,6 +24,7 @@ import torch
...
@@ -24,6 +24,7 @@ import torch
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.configuration_utils
import
ConfigMixin
from
models.vision.ddpm.modeling_ddpm
import
DDPM
from
models.vision.ddpm.modeling_ddpm
import
DDPM
...
@@ -77,6 +78,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
...
@@ -77,6 +78,45 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
float
).
view
(
shape
).
contiguous
()
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
float
).
view
(
shape
).
contiguous
()
class
ConfigTester
(
unittest
.
TestCase
):
def
test_load_not_from_mixin
(
self
):
with
self
.
assertRaises
(
ValueError
):
ConfigMixin
.
from_config
(
"dummy_path"
)
def
test_save_load
(
self
):
class
SampleObject
(
ConfigMixin
):
config_name
=
"config.json"
def
__init__
(
self
,
a
=
2
,
b
=
5
,
c
=
(
2
,
5
),
d
=
"for diffusion"
,
e
=
[
1
,
3
],
):
self
.
register
(
a
=
a
,
b
=
b
,
c
=
c
,
d
=
d
,
e
=
e
)
obj
=
SampleObject
()
config
=
obj
.
config
assert
config
[
"a"
]
==
2
assert
config
[
"b"
]
==
5
assert
config
[
"c"
]
==
(
2
,
5
)
assert
config
[
"d"
]
==
"for diffusion"
assert
config
[
"e"
]
==
[
1
,
3
]
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
obj
.
save_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_config
=
new_obj
.
config
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
class
ModelTesterMixin
(
unittest
.
TestCase
):
class
ModelTesterMixin
(
unittest
.
TestCase
):
@
property
@
property
def
dummy_input
(
self
):
def
dummy_input
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment