Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
b93fe085
Unverified
Commit
b93fe085
authored
Nov 09, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 09, 2022
Browse files
[Loading] Make sure loading edge cases work (#1192)
* [Loading] Make edge cases work * up * finish * up
parent
3f7edc5f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
12 deletions
+61
-12
src/diffusers/pipeline_flax_utils.py
src/diffusers/pipeline_flax_utils.py
+8
-6
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+9
-6
tests/test_pipelines.py
tests/test_pipelines.py
+44
-0
No files found.
src/diffusers/pipeline_flax_utils.py
View file @
b93fe085
...
@@ -55,6 +55,8 @@ LOADABLE_CLASSES = {
...
@@ -55,6 +55,8 @@ LOADABLE_CLASSES = {
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxPreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FlaxPreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ProcessorMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ImageProcessingMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
}
}
...
@@ -172,8 +174,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -172,8 +174,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
for
base_class
,
save_load_methods
in
library_classes
.
items
():
for
base_class
,
save_load_methods
in
library_classes
.
items
():
class_candidate
=
getattr
(
library
,
base_class
)
class_candidate
=
getattr
(
library
,
base_class
,
None
)
if
issubclass
(
model_cls
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
model_cls
,
class_candidate
):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name
=
save_load_methods
[
0
]
save_method_name
=
save_load_methods
[
0
]
break
break
...
@@ -387,11 +389,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -387,11 +389,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
expected_class_obj
=
None
expected_class_obj
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
expected_class_obj
=
class_candidate
expected_class_obj
=
class_candidate
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
...
@@ -425,12 +427,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
...
@@ -425,12 +427,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
class_obj
=
import_flax_or_no_model
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
load_method_name
=
importable_classes
[
class_name
][
1
]
load_method_name
=
importable_classes
[
class_name
][
1
]
load_method
=
getattr
(
class_obj
,
load_method_name
)
load_method
=
getattr
(
class_obj
,
load_method_name
)
...
...
src/diffusers/pipeline_utils.py
View file @
b93fe085
...
@@ -74,6 +74,8 @@ LOADABLE_CLASSES = {
...
@@ -74,6 +74,8 @@ LOADABLE_CLASSES = {
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedTokenizerFast"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"FeatureExtractionMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ProcessorMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ImageProcessingMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
}
}
...
@@ -190,8 +192,8 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -190,8 +192,8 @@ class DiffusionPipeline(ConfigMixin):
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
for
library_name
,
library_classes
in
LOADABLE_CLASSES
.
items
():
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
for
base_class
,
save_load_methods
in
library_classes
.
items
():
for
base_class
,
save_load_methods
in
library_classes
.
items
():
class_candidate
=
getattr
(
library
,
base_class
)
class_candidate
=
getattr
(
library
,
base_class
,
None
)
if
issubclass
(
model_cls
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
model_cls
,
class_candidate
):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name
=
save_load_methods
[
0
]
save_method_name
=
save_load_methods
[
0
]
break
break
...
@@ -543,11 +545,11 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -543,11 +545,11 @@ class DiffusionPipeline(ConfigMixin):
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
expected_class_obj
=
None
expected_class_obj
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
expected_class_obj
=
class_candidate
expected_class_obj
=
class_candidate
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
if
not
issubclass
(
passed_class_obj
[
name
].
__class__
,
expected_class_obj
):
...
@@ -577,14 +579,15 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -577,14 +579,15 @@ class DiffusionPipeline(ConfigMixin):
else
:
else
:
# else we just import it from the library.
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
class_candidates
=
{
c
:
getattr
(
library
,
c
,
None
)
for
c
in
importable_classes
.
keys
()}
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
if
loaded_sub_model
is
None
and
sub_model_should_be_defined
:
load_method_name
=
None
load_method_name
=
None
for
class_name
,
class_candidate
in
class_candidates
.
items
():
for
class_name
,
class_candidate
in
class_candidates
.
items
():
if
issubclass
(
class_obj
,
class_candidate
):
if
class_candidate
is
not
None
and
issubclass
(
class_obj
,
class_candidate
):
load_method_name
=
importable_classes
[
class_name
][
1
]
load_method_name
=
importable_classes
[
class_name
][
1
]
if
load_method_name
is
None
:
if
load_method_name
is
None
:
...
...
tests/test_pipelines.py
View file @
b93fe085
...
@@ -88,6 +88,50 @@ class DownloadTests(unittest.TestCase):
...
@@ -88,6 +88,50 @@ class DownloadTests(unittest.TestCase):
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert
not
any
(
f
.
endswith
(
".msgpack"
)
for
f
in
files
)
assert
not
any
(
f
.
endswith
(
".msgpack"
)
for
f
in
files
)
def
test_download_no_safety_checker
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
safety_checker
=
None
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
generator_2
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
def
test_load_no_safety_checker_explicit_locally
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
safety_checker
=
None
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pipe
.
save_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
,
safety_checker
=
None
)
generator_2
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
def
test_load_no_safety_checker_default_locally
(
self
):
prompt
=
"hello"
pipe
=
StableDiffusionPipeline
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
)
generator
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out
=
pipe
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator
,
output_type
=
"numpy"
).
images
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pipe
.
save_pretrained
(
tmpdirname
)
pipe_2
=
StableDiffusionPipeline
.
from_pretrained
(
tmpdirname
)
generator_2
=
torch
.
Generator
(
device
=
torch_device
).
manual_seed
(
0
)
out_2
=
pipe_2
(
prompt
,
num_inference_steps
=
2
,
generator
=
generator_2
,
output_type
=
"numpy"
).
images
assert
np
.
max
(
np
.
abs
(
out
-
out_2
))
<
1e-3
class
CustomPipelineTests
(
unittest
.
TestCase
):
class
CustomPipelineTests
(
unittest
.
TestCase
):
def
test_load_custom_pipeline
(
self
):
def
test_load_custom_pipeline
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment