Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
397b31c8
"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "c186feed7fb7604db59377e74d48bcc61053832e"
Commit
397b31c8
authored
Jun 09, 2022
by
patil-suraj
Browse files
allow loading modules from hub
parent
46dae846
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
13 deletions
+20
-13
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+20
-13
No files found.
src/diffusers/pipeline_utils.py
View file @
397b31c8
...
@@ -54,6 +54,10 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -54,6 +54,10 @@ class DiffusionPipeline(ConfigMixin):
for
name
,
module
in
kwargs
.
items
():
for
name
,
module
in
kwargs
.
items
():
# retrive library
# retrive library
library
=
module
.
__module__
.
split
(
"."
)[
0
]
library
=
module
.
__module__
.
split
(
"."
)[
0
]
# if library is not in LOADABLE_CLASSES, then it is a custom module
if
library
not
in
LOADABLE_CLASSES
:
library
=
module
.
__module__
.
split
(
"."
)[
-
1
]
# retrive class_name
# retrive class_name
class_name
=
module
.
__class__
.
__name__
class_name
=
module
.
__class__
.
__name__
...
@@ -105,6 +109,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -105,6 +109,7 @@ class DiffusionPipeline(ConfigMixin):
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module_candidate
=
config_dict
[
"_module"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
if
cls
!=
DiffusionPipeline
:
if
cls
!=
DiffusionPipeline
:
...
@@ -120,12 +125,14 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -120,12 +125,14 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
init_kwargs
=
{}
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
if
library_name
==
module_candidate
:
# if the model is not in diffusers or transformers, we need to load it from the hub
# TODO(Suraj)
# assumes that it's a subclass of ModelMixin
# for vq
if
library_name
==
module_candidate_name
:
pass
class_obj
=
get_class_from_dynamic_module
(
cached_folder
,
module
,
class_name
,
cached_folder
)
load_method_name
=
"from_pretrained"
else
:
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment