Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b93fe085
Unverified
Commit
b93fe085
authored
Nov 09, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 09, 2022
Browse files
[Loading] Make sure loading edge cases work (#1192)
* [Loading] Make edge cases work * up * finish * up
parent
3f7edc5f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
12 deletions
+61
-12
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+8
-6
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+9
-6
tests/test_pipelines.py
tests/test_pipelines.py
+44
-0
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
b93fe085
...
@@ -55,6 +55,8 @@ LOADABLE_CLASSES = {
...
@@ -55,6 +55,8 @@ LOADABLE_CLASSES = {
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxPreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxPreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ProcessorMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ImageProcessingMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
}
}
...
@@ -172,8 +174,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -172,8 +174,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
for
base_class
,
save_load_methods
in
library_classes
.
items
():
for
base_class
,
save_load_methods
in
library_classes
.
items
():
class_candidate
=
getattr
(
library
,
base_class
)
class_candidate
=
getattr
(
library
,
base_class
,
None
)
if
issubclass
(
model_cls
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
model_cls
,
class_candidate
):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name
=
save_load_methods
[
0
]
save_method_name
=
save_load_methods
[
0
]
break
break
...
@@ -387,11 +389,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -387,11 +389,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
expected_class_obj
=
None
expected_class_obj
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
expected_class_obj
=
class_candidate
expected_class_obj
=
class_candidate
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
...
@@ -425,12 +427,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -425,12 +427,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
load_method_name
=
importable_classes
[
class_name
][
1
]
load_method_name
=
importable_classes
[
class_name
][
1
]
load_method
=
getattr
(
class_obj
,
load_method_name
)
load_method
=
getattr
(
class_obj
,
load_method_name
)
...
...
src/diffusers/pipeline_utils.py
View file @
b93fe085
...
@@ -74,6 +74,8 @@ LOADABLE_CLASSES = {
...
@@ -74,6 +74,8 @@ LOADABLE_CLASSES = {
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ProcessorMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ImageProcessingMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
}
}
...
@@ -190,8 +192,8 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -190,8 +192,8 @@ class DiffusionPipeline(ConfigMixin):
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
for
base_class
,
save_load_methods
in
library_classes
.
items
():
for
base_class
,
save_load_methods
in
library_classes
.
items
():
class_candidate
=
getattr
(
library
,
base_class
)
class_candidate
=
getattr
(
library
,
base_class
,
None
)
if
issubclass
(
model_cls
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
model_cls
,
class_candidate
):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name
=
save_load_methods
[
0
]
save_method_name
=
save_load_methods
[
0
]
break
break
...
@@ -543,11 +545,11 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -543,11 +545,11 @@ class DiffusionPipeline(ConfigMixin):
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
expected_class_obj
=
None
expected_class_obj
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
expected_class_obj
=
class_candidate
expected_class_obj
=
class_candidate
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
...
@@ -577,14 +579,15 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -577,14 +579,15 @@ class DiffusionPipeline(ConfigMixin):
else
:
else
:
# else we just import it from the library.
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
load_method_name
=
importable_classes
[
class_name
][
1
]
load_method_name
=
importable_classes
[
class_name
][
1
]
if
load_method_name
is
None
:
if
load_method_name
is
None
:
...
...
tests/test_pipelines.py
View file @
b93fe085
...
@@ -88,6 +88,50 @@ class DownloadTests(unittest.TestCase):
...
@@ -88,6 +88,50 @@ class DownloadTests(unittest.TestCase):
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert
not
any
(
f
.
endswith
(
".msgpack"
)
for
f
in
files
)
assert
not
any
(
f
.
endswith
(
".msgpack"
)
for
f
in
files
)
def
test_download_no_safety_checker
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
safety_checker
=
None
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
generator_2
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
def
test_load_no_safety_checker_explicit_locally
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
safety_checker
=
None
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pipe
.
save_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
,
safety_checker
=
None
)
generator_2
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
def
test_load_no_safety_checker_default_locally
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pipe
.
save_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
)
generator_2
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
class
CustomPipelineTests
(
unittest
.
TestCase
):
class
CustomPipelineTests
(
unittest
.
TestCase
):
def
test_load_custom_pipeline
(
self
):
def
test_load_custom_pipeline
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment