Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d4197bf4
"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "b2b554395ddd04393475c97cbccbd0cd3df8b362"
Unverified
Commit
d4197bf4
authored
May 23, 2023
by
Patrick von Platen
Committed by
GitHub
May 23, 2023
Browse files
Allow custom pipeline loading (#3504)
parent
b134f6a8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
3 deletions
+34
-3
src/diffusers/pipelines/pipeline_utils.py
src/diffusers/pipelines/pipeline_utils.py
+7
-3
tests/pipelines/test_pipelines.py
tests/pipelines/test_pipelines.py
+27
-0
No files found.
src/diffusers/pipelines/pipeline_utils.py
View file @
d4197bf4
...
...
@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
library
=
module
.
__module__
.
split
(
"."
)[
0
]
# check if the module is a pipeline module
pipeline_dir
=
module
.
__module__
.
split
(
"."
)[
-
2
]
if
len
(
module
.
__module__
.
split
(
"."
))
>
2
else
None
module_path_items
=
module
.
__module__
.
split
(
"."
)
pipeline_dir
=
module_path_items
[
-
2
]
if
len
(
module_path_items
)
>
2
else
None
path
=
module
.
__module__
.
split
(
"."
)
is_pipeline_module
=
pipeline_dir
in
path
and
hasattr
(
pipelines
,
pipeline_dir
)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if
library
not
in
LOADABLE_CLASSES
or
is_pipeline_module
:
if
is_pipeline_module
:
library
=
pipeline_dir
elif
library
not
in
LOADABLE_CLASSES
:
library
=
module
.
__module__
# retrieve class_name
class_name
=
module
.
__class__
.
__name__
...
...
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
# 6.2 Define all importable classes
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
importable_classes
=
ALL_IMPORTABLE_CLASSES
if
is_pipeline_module
else
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
ALL_IMPORTABLE_CLASSES
loaded_sub_model
=
None
# 6.3 Use passed sub model or load class_name from library_name
...
...
tests/pipelines/test_pipelines.py
View file @
d4197bf4
...
...
@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
from
diffusers
import
(
AutoencoderKL
,
ConfigMixin
,
DDIMPipeline
,
DDIMScheduler
,
DDPMPipeline
,
...
...
@@ -44,6 +45,7 @@ from diffusers import (
EulerAncestralDiscreteScheduler
,
EulerDiscreteScheduler
,
LMSDiscreteScheduler
,
ModelMixin
,
PNDMScheduler
,
StableDiffusionImg2ImgPipeline
,
StableDiffusionInpaintPipelineLegacy
,
...
...
@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
enable_full_determinism
()
class
CustomEncoder
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
):
super
().
__init__
()
class
CustomPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
encoder
:
CustomEncoder
,
scheduler
:
DDIMScheduler
):
super
().
__init__
()
self
.
register_modules
(
encoder
=
encoder
,
scheduler
=
scheduler
)
class
DownloadTests
(
unittest
.
TestCase
):
def
test_one_request_upon_cached
(
self
):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
...
...
@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
assert
output_str
==
"This is a local test"
def
test_custom_model_and_pipeline
(
self
):
pipe
=
CustomPipeline
(
encoder
=
CustomEncoder
(),
scheduler
=
DDIMScheduler
(),
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pipe
.
save_pretrained
(
tmpdirname
)
pipe_new
=
CustomPipeline
.
from_pretrained
(
tmpdirname
)
pipe_new
.
save_pretrained
(
tmpdirname
)
assert
dict
(
pipe_new
.
config
)
==
dict
(
pipe
.
config
)
@
slow
@
require_torch_gpu
def
test_download_from_git
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment