Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
46c52f9b
Unverified
Commit
46c52f9b
authored
Apr 13, 2023
by
Patrick von Platen
Committed by
GitHub
Apr 13, 2023
Browse files
[Pipelines] Make sure that None functions are correctly not saved (#3080)
parent
d06e0694
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
6 deletions
+19
-6
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+19
-6
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
46c52f9b
...
...
@@ -19,6 +19,7 @@ import importlib
import
inspect
import
os
import
re
import
sys
import
warnings
from
dataclasses
import
dataclass
from
pathlib
import
Path
...
...
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin.
"""
self
.
save_config
(
save_directory
)
model_index_dict
=
dict
(
self
.
config
)
model_index_dict
.
pop
(
"_class_name"
)
model_index_dict
.
pop
(
"_diffusers_version"
)
model_index_dict
.
pop
(
"_class_name"
,
None
)
model_index_dict
.
pop
(
"_diffusers_version"
,
None
)
model_index_dict
.
pop
(
"_module"
,
None
)
expected_modules
,
optional_kwargs
=
self
.
_get_signature_keys
(
self
)
...
...
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
return
True
model_index_dict
=
{
k
:
v
for
k
,
v
in
model_index_dict
.
items
()
if
is_saveable_module
(
k
,
v
)}
for
pipeline_component_name
in
model_index_dict
.
keys
():
sub_model
=
getattr
(
self
,
pipeline_component_name
)
model_cls
=
sub_model
.
__class__
...
...
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
save_method_name
=
None
# search for the model's base class in LOADABLE_CLASSES
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
library
=
importlib
.
import_module
(
library_name
)
if
library_name
in
sys
.
modules
:
library
=
importlib
.
import_module
(
library_name
)
else
:
logger
.
info
(
f
"
{
library_name
}
is not installed. Cannot save
{
pipeline_component_name
}
as
{
library_classes
}
from
{
library_name
}
"
)
for
base_class
,
save_load_methods
in
library_classes
.
items
():
class_candidate
=
getattr
(
library
,
base_class
,
None
)
if
class_candidate
is
not
None
and
issubclass
(
model_cls
,
class_candidate
):
...
...
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
if
save_method_name
is
not
None
:
break
if
save_method_name
is
None
:
logger
.
warn
(
f
"self.
{
pipeline_component_name
}
=
{
sub_model
}
of type
{
type
(
sub_model
)
}
cannot be saved."
)
# make sure that unsaveable components are not tried to be loaded afterward
self
.
register_to_config
(
**
{
pipeline_component_name
:
(
None
,
None
)})
continue
save_method
=
getattr
(
sub_model
,
save_method_name
)
# Call the save method with the argument safe_serialization only if it's supported
...
...
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
save_method
(
os
.
path
.
join
(
save_directory
,
pipeline_component_name
),
**
save_kwargs
)
# finally save the config
self
.
save_config
(
save_directory
)
def
to
(
self
,
torch_device
:
Optional
[
Union
[
str
,
torch
.
device
]]
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment