Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
99540747
Commit
99540747
authored
Jun 09, 2022
by
anton-l
Browse files
Merge remote-tracking branch 'origin/main'
parents
528b1293
9fc2b6c5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+9
-2
No files found.
src/diffusers/pipeline_utils.py
View file @
99540747
...
@@ -112,6 +112,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -112,6 +112,7 @@ class DiffusionPipeline(ConfigMixin):
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
cached_folder
=
snapshot_download
(
...
@@ -127,11 +128,12 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -127,11 +128,12 @@ class DiffusionPipeline(ConfigMixin):
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
# 2. Get class name and module candidates to load custom models
class_name_
=
config_dict
[
"_class_name"
]
class_name_
=
config_dict
[
"_class_name"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
# if we load from explicit class, let's use it
if
cls
!=
DiffusionPipeline
:
if
cls
!=
DiffusionPipeline
:
pipeline_class
=
cls
pipeline_class
=
cls
...
@@ -145,6 +147,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -145,6 +147,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
init_kwargs
=
{}
# 4. Load each module in the pipeline
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
# if the model is not in diffusers or transformers, we need to load it from the hub
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
# assumes that it's a subclass of ModelMixin
...
@@ -154,6 +157,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -154,6 +157,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes
=
ALL_IMPORTABLE_CLASSES
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
ALL_IMPORTABLE_CLASSES
.
keys
()}
class_candidates
=
{
c
:
class_obj
for
c
in
ALL_IMPORTABLE_CLASSES
.
keys
()}
else
:
else
:
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
...
@@ -166,12 +170,15 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -166,12 +170,15 @@ class DiffusionPipeline(ConfigMixin):
load_method
=
getattr
(
class_obj
,
load_method_name
)
load_method
=
getattr
(
class_obj
,
load_method_name
)
# check if the module is in a subdirectory
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cached_folder
,
name
)):
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cached_folder
,
name
)):
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
else
:
else
:
# else load from the root directory
loaded_sub_model
=
load_method
(
cached_folder
)
loaded_sub_model
=
load_method
(
cached_folder
)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
model
=
pipeline_class
(
**
init_kwargs
)
return
model
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment