Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
decac197
Commit
decac197
authored
Jun 09, 2022
by
anton-l
Browse files
Merge branch 'main' of github.com:huggingface/diffusers
parents
ae73d95e
2fa1d648
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
12 deletions
+6
-12
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+6
-12
No files found.
src/diffusers/pipeline_utils.py
View file @
decac197
...
...
@@ -45,6 +45,10 @@ LOADABLE_CLASSES = {
},
}
ALL_IMPORTABLE_CLASSES
=
{}
for
library
in
LOADABLE_CLASSES
:
ALL_IMPORTABLE_CLASSES
.
update
(
LOADABLE_CLASSES
[
library
])
class
DiffusionPipeline
(
ConfigMixin
):
...
...
@@ -105,10 +109,8 @@ class DiffusionPipeline(ConfigMixin):
Add docstrings
"""
cache_dir
=
kwargs
.
pop
(
"cache_dir"
,
DIFFUSERS_CACHE
)
force_download
=
kwargs
.
pop
(
"force_download"
,
False
)
resume_download
=
kwargs
.
pop
(
"resume_download"
,
False
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
output_loading_info
=
kwargs
.
pop
(
"output_loading_info"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
...
...
@@ -117,10 +119,8 @@ class DiffusionPipeline(ConfigMixin):
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
force_download
=
force_download
,
resume_download
=
resume_download
,
proxies
=
proxies
,
output_loading_info
=
output_loading_info
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
)
...
...
@@ -147,20 +147,14 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
# get all importable classes to get the load method name for custom models/components
# here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers
all_importable_classes
=
{}
for
library
in
LOADABLE_CLASSES
:
all_importable_classes
.
update
(
LOADABLE_CLASSES
[
library
])
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
if
library_name
==
module_candidate_name
:
class_obj
=
get_class_from_dynamic_module
(
cached_folder
,
module
,
class_name
,
cached_folder
)
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes
=
all_importable_classes
class_candidates
=
{
c
:
class_obj
for
c
in
all_importable_classes
}
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
ALL_IMPORTABLE_CLASSES
.
keys
()
}
else
:
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment