Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
936cd084
"src/vscode:/vscode.git/clone" did not exist on "ed759f0aee721f8520c5bf94d4b7bd7c0ae3dcbb"
Commit
936cd084
authored
Jul 19, 2022
by
Patrick von Platen
Browse files
improve loading a bit
parent
3a32b8c9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
0 deletions
+22
-0
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+1
-0
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-0
src/diffusers/models/unet_conditional.py
src/diffusers/models/unet_conditional.py
+10
-0
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+10
-0
No files found.
src/diffusers/configuration_utils.py
View file @
936cd084
...
@@ -208,6 +208,7 @@ class ConfigMixin:
...
@@ -208,6 +208,7 @@ class ConfigMixin:
def
extract_init_dict
(
cls
,
config_dict
,
**
kwargs
):
def
extract_init_dict
(
cls
,
config_dict
,
**
kwargs
):
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
.
remove
(
"self"
)
expected_keys
.
remove
(
"self"
)
expected_keys
.
remove
(
"kwargs"
)
init_dict
=
{}
init_dict
=
{}
for
key
in
expected_keys
:
for
key
in
expected_keys
:
if
key
in
kwargs
:
if
key
in
kwargs
:
...
...
src/diffusers/modeling_utils.py
View file @
936cd084
...
@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
models, `pixel_values` for vision models and `input_values` for speech models).
models, `pixel_values` for vision models and `input_values` for speech models).
"""
"""
config_name
=
CONFIG_NAME
config_name
=
CONFIG_NAME
_automatically_saved_args
=
[
"_diffusers_version"
,
"_class_name"
,
"name_or_path"
]
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/models/unet_conditional.py
View file @
936cd084
...
@@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
...
@@ -63,8 +63,18 @@ class UNetConditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
center_input_sample
=
False
,
resnet_num_groups
=
30
,
resnet_num_groups
=
30
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
self
.
register_to_config
(
...
...
src/diffusers/models/unet_unconditional.py
View file @
936cd084
...
@@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
...
@@ -59,8 +59,18 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
mid_block_scale_factor
=
1
,
mid_block_scale_factor
=
1
,
center_input_sample
=
False
,
center_input_sample
=
False
,
resnet_num_groups
=
32
,
resnet_num_groups
=
32
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
# remove automatically added kwargs
for
arg
in
self
.
_automatically_saved_args
:
kwargs
.
pop
(
arg
,
None
)
if
len
(
kwargs
)
>
0
:
raise
ValueError
(
f
"The following keyword arguments do not exist for
{
self
.
__class__
}
:
{
','
.
join
(
kwargs
.
keys
())
}
"
)
# register all __init__ params to be accessible via `self.config.<...>`
# register all __init__ params to be accessible via `self.config.<...>`
# should probably be automated down the road as this is pure boiler plate code
# should probably be automated down the road as this is pure boiler plate code
self
.
register_to_config
(
self
.
register_to_config
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment