Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
dd4cd081
Commit
dd4cd081
authored
Jun 07, 2022
by
Patrick von Platen
Browse files
fix naming
parent
ab8e5364
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
4 deletions
+8
-4
models/vision/ddpm/modeling_ddpm.py
models/vision/ddpm/modeling_ddpm.py
+1
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+7
-3
No files found.
models/vision/ddpm/modeling_ddpm.py
View file @
dd4cd081
...
@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
...
@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
modeling_file
=
"modeling_ddpm.py"
modeling_file
=
"modeling_ddpm.py"
def
__init__
(
self
,
unet
,
noise_scheduler
,
vqvae
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
...
...
src/diffusers/pipeline_utils.py
View file @
dd4cd081
...
@@ -90,10 +90,14 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -90,10 +90,14 @@ class DiffusionPipeline(ConfigMixin):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
Optional
[
Union
[
str
,
os
.
PathLike
]],
**
kwargs
):
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
if
not
os
.
path
.
isdir
(
pretrained_model_name_or_path
):
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
)
else
:
cached_folder
=
pretrained_model_name_or_path
config_dict
,
pipeline_kwargs
=
cls
.
get_config_dict
(
cached_folder
)
config_dict
,
pipeline_kwargs
=
cls
.
get_config_dict
(
cached_folder
)
module
=
pipeline_kwargs
[
"_module"
]
module
=
pipeline_kwargs
.
pop
(
"_module"
,
None
)
# TODO(Suraj) - make from hub import work
# TODO(Suraj) - make from hub import work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
# Add Sylvains code from transformers
# Add Sylvains code from transformers
...
@@ -118,7 +122,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -118,7 +122,7 @@ class DiffusionPipeline(ConfigMixin):
load_method
=
getattr
(
class_obj
,
load_method_name
)
load_method
=
getattr
(
class_obj
,
load_method_name
)
if
os
.
path
.
dir
(
os
.
path
.
join
(
cached_folder
,
name
)):
if
os
.
path
.
is
dir
(
os
.
path
.
join
(
cached_folder
,
name
)):
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
loaded_sub_model
=
load_method
(
os
.
path
.
join
(
cached_folder
,
name
))
else
:
else
:
loaded_sub_model
=
load_method
(
cached_folder
)
loaded_sub_model
=
load_method
(
cached_folder
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment