Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1a6196e8
"vscode:/vscode.git/clone" did not exist on "16b9a40c1ca9cc3be3587625d3d58271f9fd88c2"
Commit
1a6196e8
authored
Jun 07, 2022
by
Patrick von Platen
Browse files
add more logic for dynamic loading
parent
40dc888f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
2 deletions
+13
-2
models/vision/ddpm/modeling_ddpm.py
models/vision/ddpm/modeling_ddpm.py
+1
-1
models/vision/glide/modeling_vqvae.py.py
models/vision/glide/modeling_vqvae.py.py
+1
-0
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+11
-1
No files found.
models/vision/ddpm/modeling_ddpm.py
View file @
1a6196e8
...
...
@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
modeling_file
=
"modeling_ddpm.py"
def
__init__
(
self
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
unet
,
noise_scheduler
,
vqvae
):
super
().
__init__
()
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
...
...
models/vision/glide/modeling_vqvae.py.py
0 → 100755
View file @
1a6196e8
#!/usr/bin/env python3
src/diffusers/pipeline_utils.py
View file @
1a6196e8
...
...
@@ -71,6 +71,10 @@ class DiffusionPipeline(ConfigMixin):
for
name
,
(
library_name
,
class_name
)
in
self
.
_dict_to_save
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
# TODO: Suraj
if
library_name
==
self
.
__module__
:
library_name
=
self
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
...
...
@@ -91,12 +95,18 @@ class DiffusionPipeline(ConfigMixin):
module
=
pipeline_kwargs
[
"_module"
]
# TODO(Suraj) - make from hub import work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
init_kwargs
=
{}
for
name
,
(
library_name
,
class_name
)
in
config_dict
.
items
():
importable_classes
=
LOADABLE_CLASSES
[
library_name
]
if
library_name
==
module
:
# TODO(Suraj)
pass
library
=
importlib
.
import_module
(
library_name
)
class_obj
=
getattr
(
library
,
class_name
)
class_candidates
=
{
c
:
getattr
(
library
,
c
)
for
c
in
importable_classes
.
keys
()}
...
...
@@ -110,7 +120,7 @@ class DiffusionPipeline(ConfigMixin):
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
init_kwargs
[
name
]
=
loaded_sub_model
init_kwargs
[
name
]
=
loaded_sub_model
# UNet(...), # DiffusionSchedule(...)
model
=
cls
(
**
init_kwargs
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment