Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
8faa822d
Unverified
Commit
8faa822d
authored
Nov 25, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 25, 2022
Browse files
Allow to set config params directly in init (#1419)
* fix * fix deprecated kwargs logic * add tests * finish
parent
86aa747d
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
103 additions
and
11 deletions
+103
-11
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+12
-8
src/diffusers/models/unet_2d_blocks.py
src/diffusers/models/unet_2d_blocks.py
+0
-2
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+0
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-0
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+1
-0
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+1
-0
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+1
-0
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+1
-0
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+1
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+20
-0
tests/test_modeling_common_flax.py
tests/test_modeling_common_flax.py
+22
-0
tests/test_scheduler.py
tests/test_scheduler.py
+21
-0
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+22
-0
No files found.
src/diffusers/configuration_utils.py
View file @
8faa822d
...
@@ -80,14 +80,18 @@ class ConfigMixin:
...
@@ -80,14 +80,18 @@ class ConfigMixin:
- **config_name** (`str`) -- A filename under which the config should stored when calling
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
class).
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
"""
"""
config_name
=
None
config_name
=
None
ignore_for_config
=
[]
ignore_for_config
=
[]
has_compatibles
=
False
has_compatibles
=
False
_deprecated_kwargs
=
[]
def
register_to_config
(
self
,
**
kwargs
):
def
register_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
...
@@ -195,10 +199,10 @@ class ConfigMixin:
...
@@ -195,10 +199,10 @@ class ConfigMixin:
if
"dtype"
in
unused_kwargs
:
if
"dtype"
in
unused_kwargs
:
init_dict
[
"dtype"
]
=
unused_kwargs
.
pop
(
"dtype"
)
init_dict
[
"dtype"
]
=
unused_kwargs
.
pop
(
"dtype"
)
if
"predict_epsilon"
in
unused_kwargs
and
"prediction_type"
not
in
init_dict
:
# add possible deprecated kwargs
deprecate
(
"remove this"
,
"0.10.0"
,
"remove"
)
for
deprecate
d_kwarg
in
cls
.
_deprecated_kwargs
:
predict_epsilon
=
unused_kwargs
.
pop
(
"predict_epsilon"
)
if
deprecated_kwarg
in
unused_kwargs
:
init_dict
[
"
pre
diction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
init_dict
[
de
pre
cated_kwarg
]
=
unused_kwargs
.
pop
(
deprecated_kwarg
)
# Return model and optionally state and/or unused_kwargs
# Return model and optionally state and/or unused_kwargs
model
=
cls
(
**
init_dict
)
model
=
cls
(
**
init_dict
)
...
@@ -526,7 +530,6 @@ def register_to_config(init):
...
@@ -526,7 +530,6 @@ def register_to_config(init):
# Ignore private kwargs in the init.
# Ignore private kwargs in the init.
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
config_init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
.
startswith
(
"_"
)}
config_init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
.
startswith
(
"_"
)}
init
(
self
,
*
args
,
**
init_kwargs
)
if
not
isinstance
(
self
,
ConfigMixin
):
if
not
isinstance
(
self
,
ConfigMixin
):
raise
RuntimeError
(
raise
RuntimeError
(
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
...
@@ -553,6 +556,7 @@ def register_to_config(init):
...
@@ -553,6 +556,7 @@ def register_to_config(init):
)
)
new_kwargs
=
{
**
config_init_kwargs
,
**
new_kwargs
}
new_kwargs
=
{
**
config_init_kwargs
,
**
new_kwargs
}
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
init
(
self
,
*
args
,
**
init_kwargs
)
return
inner_init
return
inner_init
...
...
src/diffusers/models/unet_2d_blocks.py
View file @
8faa822d
...
@@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
...
@@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
attn_num_head_channels
=
1
,
attn_num_head_channels
=
1
,
attention_type
=
"default"
,
attention_type
=
"default"
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
...
@@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_dim
=
1280
,
cross_attention_dim
=
1280
,
dual_cross_attention
=
False
,
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
use_linear_projection
=
False
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
8faa822d
...
@@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
...
@@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_dim
=
1280
,
cross_attention_dim
=
1280
,
dual_cross_attention
=
False
,
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
use_linear_projection
=
False
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
8faa822d
...
@@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
8faa822d
...
@@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
8faa822d
...
@@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
8faa822d
...
@@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
8faa822d
...
@@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
8faa822d
...
@@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
...
...
tests/test_modeling_common.py
View file @
8faa822d
...
@@ -265,3 +265,23 @@ class ModelTesterMixin:
...
@@ -265,3 +265,23 @@ class ModelTesterMixin:
# check disable works
# check disable works
model
.
disable_gradient_checkpointing
()
model
.
disable_gradient_checkpointing
()
self
.
assertFalse
(
model
.
is_gradient_checkpointing
)
self
.
assertFalse
(
model
.
is_gradient_checkpointing
)
def
test_deprecated_kwargs
(
self
):
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
self
.
model_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
self
.
model_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f
"
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
tests/test_modeling_common_flax.py
View file @
8faa822d
import
inspect
from
diffusers.utils
import
is_flax_available
from
diffusers.utils
import
is_flax_available
from
diffusers.utils.testing_utils
import
require_flax
from
diffusers.utils.testing_utils
import
require_flax
...
@@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
...
@@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
self
.
assertIsNotNone
(
output
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"sample"
].
shape
expected_shape
=
inputs_dict
[
"sample"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_deprecated_kwargs
(
self
):
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
self
.
model_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
self
.
model_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f
"
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
tests/test_scheduler.py
View file @
8faa822d
...
@@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
noised
=
scheduler
.
add_noise
(
scaled_sample
,
noise
,
t
)
noised
=
scheduler
.
add_noise
(
scaled_sample
,
noise
,
t
)
self
.
assertEqual
(
noised
.
shape
,
scaled_sample
.
shape
)
self
.
assertEqual
(
noised
.
shape
,
scaled_sample
.
shape
)
def
test_deprecated_kwargs
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
scheduler_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
scheduler_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f
" argument to
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
DDPMScheduler
,)
scheduler_classes
=
(
DDPMScheduler
,)
...
...
tests/test_scheduler_flax.py
View file @
8faa822d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
inspect
import
tempfile
import
tempfile
import
unittest
import
unittest
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
...
@@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
recursive_check
(
outputs_tuple
[
0
],
outputs_dict
.
prev_sample
)
recursive_check
(
outputs_tuple
[
0
],
outputs_dict
.
prev_sample
)
def
test_deprecated_kwargs
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
scheduler_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
scheduler_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f
" argument to
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@
require_flax
@
require_flax
class
FlaxDDPMSchedulerTest
(
FlaxSchedulerCommonTest
):
class
FlaxDDPMSchedulerTest
(
FlaxSchedulerCommonTest
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment