Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d81b56ba
Commit
d81b56ba
authored
Jun 14, 2022
by
patil-suraj
Browse files
allow loading model from pipeline module
parent
ca72c1f8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
6 deletions
+25
-6
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+25
-6
No files found.
src/diffusers/pipeline_utils.py
View file @
d81b56ba
...
...
@@ -55,11 +55,20 @@ class DiffusionPipeline(ConfigMixin):
config_name
=
"model_index.json"
def
register_modules
(
self
,
**
kwargs
):
# import it here to avoid circular import
from
diffusers
import
pipelines
for
name
,
module
in
kwargs
.
items
():
# check if the module is a pipeline module
is_pipeline_module
=
hasattr
(
pipelines
,
module
.
__module__
.
split
(
"."
)[
-
1
])
# retrive library
library
=
module
.
__module__
.
split
(
"."
)[
0
]
# if library is not in LOADABLE_CLASSES, then it is a custom module
if
library
not
in
LOADABLE_CLASSES
:
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# so we set the library to module name.
if
library
not
in
LOADABLE_CLASSES
or
is_pipeline_module
:
library
=
module
.
__module__
.
split
(
"."
)[
-
1
]
# retrive class_name
...
...
@@ -152,11 +161,21 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
# import it here to avoid circular import
from
diffusers
import
pipelines
# 4. Load each module in the pipeline
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
is_pipeline_module
=
hasattr
(
pipelines
,
library_name
)
# if the model is in a pipeline module, then we load it from the pipeline
if
is_pipeline_module
:
pipeline_module
=
getattr
(
pipelines
,
library_name
)
class_obj
=
getattr
(
pipeline_module
,
class_name
)
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
ALL_IMPORTABLE_CLASSES
.
keys
()}
elif
library_name
==
module_candidate_name
:
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
if
library_name
==
module_candidate_name
:
class_obj
=
get_class_from_dynamic_module
(
cached_folder
,
module_candidate
,
class_name
,
cached_folder
)
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes
=
ALL_IMPORTABLE_CLASSES
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment