Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
d4197bf4
Unverified
Commit
d4197bf4
authored
May 23, 2023
by
Patrick von Platen
Committed by
GitHub
May 23, 2023
Browse files
Allow custom pipeline loading (#3504)
parent
b134f6a8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
3 deletions
+34
-3
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+7
-3
tests/pipelines/test_pipelines.py
tests/pipelines/test_pipelines.py
+27
-0
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
d4197bf4
...
@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
library
=
module
.
__module__
.
split
(
"."
)[
0
]
library
=
module
.
__module__
.
split
(
"."
)[
0
]
# check if the module is a pipeline module
# check if the module is a pipeline module
pipeline_dir
=
module
.
__module__
.
split
(
"."
)[
-
2
]
if
len
(
module
.
__module__
.
split
(
"."
))
>
2
else
None
module_path_items
=
module
.
__module__
.
split
(
"."
)
pipeline_dir
=
module_path_items
[
-
2
]
if
len
(
module_path_items
)
>
2
else
None
path
=
module
.
__module__
.
split
(
"."
)
path
=
module
.
__module__
.
split
(
"."
)
is_pipeline_module
=
pipeline_dir
in
path
and
hasattr
(
pipelines
,
pipeline_dir
)
is_pipeline_module
=
pipeline_dir
in
path
and
hasattr
(
pipelines
,
pipeline_dir
)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
# folder so we set the library to module name.
if
library
not
in
LOADABLE_CLASSES
or
is_pipeline_module
:
if
is_pipeline_module
:
library
=
pipeline_dir
library
=
pipeline_dir
elif
library
not
in
LOADABLE_CLASSES
:
library
=
module
.
__module__
# retrieve class_name
# retrieve class_name
class_name
=
module
.
__class__
.
__name__
class_name
=
module
.
__class__
.
__name__
...
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
# 6.2 Define all importable classes
# 6.2 Define all importable classes
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
importable_classes
=
ALL_IMPORTABLE_CLASSES
if
is_pipeline_module
else
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
ALL_IMPORTABLE_CLASSES
loaded_sub_model
=
None
loaded_sub_model
=
None
# 6.3 Use passed sub model or load class_name from library_name
# 6.3 Use passed sub model or load class_name from library_name
...
...
tests/pipelines/test_pipelines.py
View file @
d4197bf4
...
@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
...
@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKL
,
AutoencoderKL
,
ConfigMixin
,
DDIMPipeline
,
DDIMPipeline
,
DDIMScheduler
,
DDIMScheduler
,
DDPMPipeline
,
DDPMPipeline
,
...
@@ -44,6 +45,7 @@ from diffusers import (
...
@@ -44,6 +45,7 @@ from diffusers import (
EulerAncestralDiscreteScheduler
,
EulerAncestralDiscreteScheduler
,
EulerDiscreteScheduler
,
EulerDiscreteScheduler
,
LMSDiscreteScheduler
,
LMSDiscreteScheduler
,
ModelMixin
,
PNDMScheduler
,
PNDMScheduler
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionInpaintPipelineLegacy
,
StableDiffusionInpaintPipelineLegacy
,
...
@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
...
@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
enable_full_determinism
()
enable_full_determinism
()
class
CustomEncoder
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
):
super
().
__init__
()
class
CustomPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
encoder
:
CustomEncoder
,
scheduler
:
DDIMScheduler
):
super
().
__init__
()
self
.
register_modules
(
encoder
=
encoder
,
scheduler
=
scheduler
)
class
DownloadTests
(
unittest
.
TestCase
):
class
DownloadTests
(
unittest
.
TestCase
):
def
test_one_request_upon_cached
(
self
):
def
test_one_request_upon_cached
(
self
):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
# TODO: For some reason this test fails on MPS where no HEAD call is made.
...
@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
...
@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
assert
output_str
==
"This is a local test"
assert
output_str
==
"This is a local test"
def
test_custom_model_and_pipeline
(
self
):
pipe
=
CustomPipeline
(
encoder
=
CustomEncoder
(),
scheduler
=
DDIMScheduler
(),
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pipe
.
save_pretrained
(
tmpdirname
)
pipe_new
=
CustomPipeline
.
from_pretrained
(
tmpdirname
)
pipe_new
.
save_pretrained
(
tmpdirname
)
assert
dict
(
pipe_new
.
config
)
==
dict
(
pipe
.
config
)
@
slow
@
slow
@
require_torch_gpu
@
require_torch_gpu
def
test_download_from_git
(
self
):
def
test_download_from_git
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment