Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fe7d1363
Commit
fe7d1363
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
correct dict
parent
e660a05f
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
11 deletions
+52
-11
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+46
-9
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+4
-0
No files found.
src/diffusers/configuration_utils.py
View file @
fe7d1363
...
...
@@ -14,13 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import
copy
import
inspect
import
json
import
os
import
re
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
huggingface_hub
import
hf_hub_download
...
...
@@ -63,10 +61,14 @@ class ConfigMixin:
logger
.
error
(
f
"Can't set
{
key
}
with value
{
value
}
for
{
self
}
"
)
raise
err
if
not
hasattr
(
self
,
"_dict_to_save"
):
self
.
_dict_to_save
=
{}
if
not
hasattr
(
self
,
"_internal_dict"
):
internal_dict
=
kwargs
else
:
previous_dict
=
dict
(
self
.
_internal_dict
)
internal_dict
=
{
**
self
.
_internal_dict
,
**
kwargs
}
logger
.
debug
(
f
"Updating config from
{
previous_dict
}
to
{
internal_dict
}
"
)
self
.
_
dict_to_save
.
update
(
kwargs
)
self
.
_
internal_dict
=
FrozenDict
(
internal_dict
)
def
save_config
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
push_to_hub
:
bool
=
False
,
**
kwargs
):
"""
...
...
@@ -230,8 +232,7 @@ class ConfigMixin:
@
property
def
config
(
self
)
->
Dict
[
str
,
Any
]:
output
=
copy
.
deepcopy
(
self
.
_dict_to_save
)
return
output
return
self
.
_internal_dict
def
to_json_string
(
self
)
->
str
:
"""
...
...
@@ -240,7 +241,7 @@ class ConfigMixin:
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict
=
self
.
_
dict_to_save
config_dict
=
self
.
_
internal_dict
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
...
...
@@ -253,3 +254,39 @@ class ConfigMixin:
"""
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
self
.
to_json_string
())
class
FrozenDict
(
OrderedDict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# remove `None`
args
=
(
a
for
a
in
args
if
a
is
not
None
)
kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
if
v
is
not
None
}
super
().
__init__
(
*
args
,
**
kwargs
)
for
key
,
value
in
self
.
items
():
setattr
(
self
,
key
,
value
)
self
.
__frozen
=
True
def
__delitem__
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``__delitem__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
setdefault
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``setdefault`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
pop
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``pop`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
update
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``update`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
__setattr__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setattr__
(
name
,
value
)
def
__setitem__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setitem__
(
name
,
value
)
src/diffusers/modeling_utils.py
View file @
fe7d1363
...
...
@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module):
revision
=
revision
,
**
kwargs
,
)
model
.
register
(
name_or_path
=
pretrained_model_name_or_path
)
model
.
register
_to_config
(
name_or_path
=
pretrained_model_name_or_path
)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
...
...
src/diffusers/pipeline_utils.py
View file @
fe7d1363
...
...
@@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
self
.
save_config
(
save_directory
)
model_index_dict
=
self
.
config
model_index_dict
=
dict
(
self
.
config
)
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_module"
)
...
...
tests/test_modeling_utils.py
View file @
fe7d1363
...
...
@@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase):
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_config
=
new_obj
.
config
# unfreeze configs
config
=
dict
(
config
)
new_config
=
dict
(
new_config
)
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment