Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
758f9d22
Commit
758f9d22
authored
Jun 09, 2022
by
patil-suraj
Browse files
add some comments
parent
2234877e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+9
-2
No files found.
src/diffusers/pipeline_utils.py
View file @
758f9d22
...
...
@@ -114,6 +114,7 @@ class DiffusionPipeline(ConfigMixin):
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
...
...
@@ -129,11 +130,12 @@ class DiffusionPipeline(ConfigMixin):
config_dict
=
cls
.
get_config_dict
(
cached_folder
)
module
=
config_dict
[
"_module"
]
# 2. Get class name and module candidates to load custom models
class_name_
=
config_dict
[
"_class_name"
]
module_candidate
=
config_dict
[
"_module"
]
module_candidate_name
=
module_candidate
.
replace
(
".py"
,
""
)
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if
cls
!=
DiffusionPipeline
:
pipeline_class
=
cls
...
...
@@ -147,6 +149,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs
=
{}
# 4. Load each module in the pipeline
for
name
,
(
library_name
,
class_name
)
in
init_dict
.
items
():
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
...
...
@@ -156,6 +159,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes
=
ALL_IMPORTABLE_CLASSES
class_candidates
=
{
c
:
class_obj
for
c
in
ALL_IMPORTABLE_CLASSES
.
keys
()}
else
:
# else we just import it from the library.
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
...
...
@@ -168,12 +172,15 @@ class DiffusionPipeline(ConfigMixin):
load_method
=
getattr
(
class_obj
,
load_method_name
)
# check if the module is in a subdirectory
if
os
.
path
.
isdir
(
os
.
path
.
join
(
cached_folder
,
name
)):
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
else
:
# else load from the root directory
loaded_sub_model
=
load_method
(
cached_folder
)
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment