Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fe99460b
Commit
fe99460b
authored
Jun 07, 2022
by
patil-suraj
Browse files
update config dict logic
parent
a61a9613
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
17 deletions
+27
-17
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+19
-10
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+8
-7
No files found.
src/diffusers/configuration_utils.py
View file @
fe99460b
...
@@ -89,6 +89,7 @@ class ConfigMixin:
...
@@ -89,6 +89,7 @@ class ConfigMixin:
self
.
to_json_file
(
output_config_file
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
@
classmethod
@
classmethod
def
get_config_dict
(
def
get_config_dict
(
...
@@ -182,35 +183,43 @@ class ConfigMixin:
...
@@ -182,35 +183,43 @@ class ConfigMixin:
logger
.
info
(
f
"loading configuration file
{
config_file
}
"
)
logger
.
info
(
f
"loading configuration file
{
config_file
}
"
)
else
:
else
:
logger
.
info
(
f
"loading configuration file
{
config_file
}
from cache at
{
resolved_config_file
}
"
)
logger
.
info
(
f
"loading configuration file
{
config_file
}
from cache at
{
resolved_config_file
}
"
)
return
config_dict
@
classmethod
def
extract_init_dict
(
cls
,
config_dict
,
**
kwargs
):
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
.
remove
(
"self"
)
expected_keys
.
remove
(
"self"
)
import
ipdb
;
ipdb
.
set_trace
()
init_dict
=
{}
for
key
in
expected_keys
:
for
key
in
expected_keys
:
if
key
in
kwargs
:
if
key
in
kwargs
:
# overwrite key
# overwrite key
config_dict
[
key
]
=
kwargs
.
pop
(
key
)
init_dict
[
key
]
=
kwargs
.
pop
(
key
)
elif
key
in
config_dict
:
# use value from config dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
passed_keys
=
set
(
config_dict
.
keys
())
unused_kwargs
=
kwargs
for
key
in
passed_keys
-
expected_keys
:
unused_kwargs
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
if
len
(
expected_keys
-
passed_keys
)
>
0
:
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warn
(
logger
.
warn
(
f
"
{
expected_keys
-
passed_keys
}
was not found in config. Values will be initialized to default values."
f
"
{
expected_keys
-
passed_keys
}
was not found in config. Values will be initialized to default values."
)
)
return
config
_dict
,
unused_kwargs
return
init
_dict
,
unused_kwargs
@
classmethod
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
,
unused_kwargs
=
cls
.
get_config_dict
(
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
)
model
=
cls
(
**
config_dict
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
model
=
cls
(
**
init_dict
)
if
return_unused_kwargs
:
if
return_unused_kwargs
:
return
model
,
unused_kwargs
return
model
,
unused_kwargs
...
...
src/diffusers/pipeline_utils.py
View file @
fe99460b
...
@@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -97,16 +97,17 @@ class DiffusionPipeline(ConfigMixin):
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
config_dict
,
pipeline_kwargs
=
cls
.
get_config_dict
(
cached_folder
)
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
class_name_
=
config_dict
[
"_class_name"
]
class_obj
=
get_class_from_dynamic_module
(
cached_folder
,
module
,
class_name_
,
cached_folder
)
module
=
pipeline_kwargs
.
pop
(
"_module"
,
None
)
init_dict
,
unused
=
class_obj
.
extract_init_dict
(
config_dict
,
**
kwargs
)
# TODO(Suraj) - make from hub import work
import
ipdb
;
ipdb
.
set_trace
()
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
init_kwargs
=
{}
init_kwargs
=
{}
for
name
,
(
library_name
,
class_name
)
in
config
_dict
.
items
():
for
name
,
(
library_name
,
class_name
)
in
init
_dict
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
if
library_name
==
module
:
if
library_name
==
module
:
...
@@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -131,6 +132,6 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
class_obj
=
get_class_from_dynamic_module
(
cached_folder
,
module
,
class_name_
,
cached_folder
)
model
=
class_obj
(
**
init_kwargs
)
model
=
class_obj
(
**
init_kwargs
)
return
model
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment