Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fe7d1363
Commit
fe7d1363
authored
Jun 17, 2022
by
Patrick von Platen
Browse files
correct dict
parent
e660a05f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
11 deletions
+52
-11
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+46
-9
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+1
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+4
-0
No files found.
src/diffusers/configuration_utils.py
View file @
fe7d1363
...
@@ -14,13 +14,11 @@
...
@@ -14,13 +14,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
""" ConfigMixinuration base class and utilities."""
import
copy
import
inspect
import
inspect
import
json
import
json
import
os
import
os
import
re
import
re
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
...
@@ -63,10 +61,14 @@ class ConfigMixin:
...
@@ -63,10 +61,14 @@ class ConfigMixin:
logger
.
error
(
f
"Can't set
{
key
}
with value
{
value
}
for
{
self
}
"
)
logger
.
error
(
f
"Can't set
{
key
}
with value
{
value
}
for
{
self
}
"
)
raise
err
raise
err
if
not
hasattr
(
self
,
"_dict_to_save"
):
if
not
hasattr
(
self
,
"_internal_dict"
):
self
.
_dict_to_save
=
{}
internal_dict
=
kwargs
else
:
previous_dict
=
dict
(
self
.
_internal_dict
)
internal_dict
=
{
**
self
.
_internal_dict
,
**
kwargs
}
logger
.
debug
(
f
"Updating config from
{
previous_dict
}
to
{
internal_dict
}
"
)
self
.
_
dict_to_save
.
update
(
kwargs
)
self
.
_
internal_dict
=
FrozenDict
(
internal_dict
)
def
save_config
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
push_to_hub
:
bool
=
False
,
**
kwargs
):
def
save_config
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
],
push_to_hub
:
bool
=
False
,
**
kwargs
):
"""
"""
...
@@ -230,8 +232,7 @@ class ConfigMixin:
...
@@ -230,8 +232,7 @@ class ConfigMixin:
@
property
@
property
def
config
(
self
)
->
Dict
[
str
,
Any
]:
def
config
(
self
)
->
Dict
[
str
,
Any
]:
output
=
copy
.
deepcopy
(
self
.
_dict_to_save
)
return
self
.
_internal_dict
return
output
def
to_json_string
(
self
)
->
str
:
def
to_json_string
(
self
)
->
str
:
"""
"""
...
@@ -240,7 +241,7 @@ class ConfigMixin:
...
@@ -240,7 +241,7 @@ class ConfigMixin:
Returns:
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
"""
config_dict
=
self
.
_
dict_to_save
config_dict
=
self
.
_
internal_dict
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
...
@@ -253,3 +254,39 @@ class ConfigMixin:
...
@@ -253,3 +254,39 @@ class ConfigMixin:
"""
"""
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
with
open
(
json_file_path
,
"w"
,
encoding
=
"utf-8"
)
as
writer
:
writer
.
write
(
self
.
to_json_string
())
writer
.
write
(
self
.
to_json_string
())
class
FrozenDict
(
OrderedDict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# remove `None`
args
=
(
a
for
a
in
args
if
a
is
not
None
)
kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
if
v
is
not
None
}
super
().
__init__
(
*
args
,
**
kwargs
)
for
key
,
value
in
self
.
items
():
setattr
(
self
,
key
,
value
)
self
.
__frozen
=
True
def
__delitem__
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``__delitem__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
setdefault
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``setdefault`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
pop
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``pop`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
update
(
self
,
*
args
,
**
kwargs
):
raise
Exception
(
f
"You cannot use ``update`` on a
{
self
.
__class__
.
__name__
}
instance."
)
def
__setattr__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setattr__
(
name
,
value
)
def
__setitem__
(
self
,
name
,
value
):
if
hasattr
(
self
,
"__frozen"
)
and
self
.
__frozen
:
raise
Exception
(
f
"You cannot use ``__setattr__`` on a
{
self
.
__class__
.
__name__
}
instance."
)
super
().
__setitem__
(
name
,
value
)
src/diffusers/modeling_utils.py
View file @
fe7d1363
...
@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module):
revision
=
revision
,
revision
=
revision
,
**
kwargs
,
**
kwargs
,
)
)
model
.
register
(
name_or_path
=
pretrained_model_name_or_path
)
model
.
register
_to_config
(
name_or_path
=
pretrained_model_name_or_path
)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
# Load model
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
pretrained_model_name_or_path
=
str
(
pretrained_model_name_or_path
)
...
...
src/diffusers/pipeline_utils.py
View file @
fe7d1363
...
@@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -88,7 +88,7 @@ class DiffusionPipeline(ConfigMixin):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
self
.
save_config
(
save_directory
)
self
.
save_config
(
save_directory
)
model_index_dict
=
self
.
config
model_index_dict
=
dict
(
self
.
config
)
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_module"
)
model_index_dict
.
pop
(
"_module"
)
...
...
tests/test_modeling_utils.py
View file @
fe7d1363
...
@@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase):
...
@@ -73,6 +73,10 @@ class ConfigTester(unittest.TestCase):
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_obj
=
SampleObject
.
from_config
(
tmpdirname
)
new_config
=
new_obj
.
config
new_config
=
new_obj
.
config
# unfreeze configs
config
=
dict
(
config
)
new_config
=
dict
(
new_config
)
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
config
.
pop
(
"c"
)
==
(
2
,
5
)
# instantiated as tuple
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
new_config
.
pop
(
"c"
)
==
[
2
,
5
]
# saved & loaded as list because of json
assert
config
==
new_config
assert
config
==
new_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment