Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8faa822d
Unverified
Commit
8faa822d
authored
Nov 25, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 25, 2022
Browse files
Allow to set config params directly in init (#1419)
* fix * fix deprecated kwargs logic * add tests * finish
parent
86aa747d
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
103 additions
and
11 deletions
+103
-11
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+12
-8
src/diffusers/models/unet_2d_blocks.py
src/diffusers/models/unet_2d_blocks.py
+0
-2
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
...users/pipelines/versatile_diffusion/modeling_text_unet.py
+0
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-0
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+1
-0
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+1
-0
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+1
-0
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+1
-0
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+1
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+20
-0
tests/test_modeling_common_flax.py
tests/test_modeling_common_flax.py
+22
-0
tests/test_scheduler.py
tests/test_scheduler.py
+21
-0
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+22
-0
No files found.
src/diffusers/configuration_utils.py
View file @
8faa822d
...
@@ -80,14 +80,18 @@ class ConfigMixin:
...
@@ -80,14 +80,18 @@ class ConfigMixin:
- **config_name** (`str`) -- A filename under which the config should stored when calling
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
class).
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
"""
"""
config_name
=
None
config_name
=
None
ignore_for_config
=
[]
ignore_for_config
=
[]
has_compatibles
=
False
has_compatibles
=
False
_deprecated_kwargs
=
[]
def
register_to_config
(
self
,
**
kwargs
):
def
register_to_config
(
self
,
**
kwargs
):
if
self
.
config_name
is
None
:
if
self
.
config_name
is
None
:
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
raise
NotImplementedError
(
f
"Make sure that
{
self
.
__class__
}
has defined a class name `config_name`"
)
...
@@ -195,10 +199,10 @@ class ConfigMixin:
...
@@ -195,10 +199,10 @@ class ConfigMixin:
if
"dtype"
in
unused_kwargs
:
if
"dtype"
in
unused_kwargs
:
init_dict
[
"dtype"
]
=
unused_kwargs
.
pop
(
"dtype"
)
init_dict
[
"dtype"
]
=
unused_kwargs
.
pop
(
"dtype"
)
if
"predict_epsilon"
in
unused_kwargs
and
"prediction_type"
not
in
init_dict
:
# add possible deprecated kwargs
deprecate
(
"remove this"
,
"0.10.0"
,
"remove"
)
for
deprecate
d_kwarg
in
cls
.
_deprecated_kwargs
:
predict_epsilon
=
unused_kwargs
.
pop
(
"predict_epsilon"
)
if
deprecated_kwarg
in
unused_kwargs
:
init_dict
[
"
pre
diction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
init_dict
[
de
pre
cated_kwarg
]
=
unused_kwargs
.
pop
(
deprecated_kwarg
)
# Return model and optionally state and/or unused_kwargs
# Return model and optionally state and/or unused_kwargs
model
=
cls
(
**
init_dict
)
model
=
cls
(
**
init_dict
)
...
@@ -526,7 +530,6 @@ def register_to_config(init):
...
@@ -526,7 +530,6 @@ def register_to_config(init):
# Ignore private kwargs in the init.
# Ignore private kwargs in the init.
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
not
k
.
startswith
(
"_"
)}
config_init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
.
startswith
(
"_"
)}
config_init_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
.
startswith
(
"_"
)}
init
(
self
,
*
args
,
**
init_kwargs
)
if
not
isinstance
(
self
,
ConfigMixin
):
if
not
isinstance
(
self
,
ConfigMixin
):
raise
RuntimeError
(
raise
RuntimeError
(
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
f
"`@register_for_config` was applied to
{
self
.
__class__
.
__name__
}
init method, but this class does "
...
@@ -553,6 +556,7 @@ def register_to_config(init):
...
@@ -553,6 +556,7 @@ def register_to_config(init):
)
)
new_kwargs
=
{
**
config_init_kwargs
,
**
new_kwargs
}
new_kwargs
=
{
**
config_init_kwargs
,
**
new_kwargs
}
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
getattr
(
self
,
"register_to_config"
)(
**
new_kwargs
)
init
(
self
,
*
args
,
**
init_kwargs
)
return
inner_init
return
inner_init
...
...
src/diffusers/models/unet_2d_blocks.py
View file @
8faa822d
...
@@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
...
@@ -254,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
attn_num_head_channels
=
1
,
attn_num_head_channels
=
1
,
attention_type
=
"default"
,
attention_type
=
"default"
,
output_scale_factor
=
1.0
,
output_scale_factor
=
1.0
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
...
@@ -336,7 +335,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_dim
=
1280
,
cross_attention_dim
=
1280
,
dual_cross_attention
=
False
,
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
use_linear_projection
=
False
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
View file @
8faa822d
...
@@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
...
@@ -1039,7 +1039,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
cross_attention_dim
=
1280
,
cross_attention_dim
=
1280
,
dual_cross_attention
=
False
,
dual_cross_attention
=
False
,
use_linear_projection
=
False
,
use_linear_projection
=
False
,
**
kwargs
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
8faa822d
...
@@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -113,6 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
8faa822d
...
@@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -116,6 +116,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
8faa822d
...
@@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -105,6 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
8faa822d
...
@@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -109,6 +109,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
8faa822d
...
@@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -117,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
register_to_config
@
register_to_config
def
__init__
(
def
__init__
(
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
8faa822d
...
@@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_compatibles
=
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
.
copy
()
_deprecated_kwargs
=
[
"predict_epsilon"
]
@
property
@
property
def
has_state
(
self
):
def
has_state
(
self
):
...
...
tests/test_modeling_common.py
View file @
8faa822d
...
@@ -265,3 +265,23 @@ class ModelTesterMixin:
...
@@ -265,3 +265,23 @@ class ModelTesterMixin:
# check disable works
# check disable works
model
.
disable_gradient_checkpointing
()
model
.
disable_gradient_checkpointing
()
self
.
assertFalse
(
model
.
is_gradient_checkpointing
)
self
.
assertFalse
(
model
.
is_gradient_checkpointing
)
def
test_deprecated_kwargs
(
self
):
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
self
.
model_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
self
.
model_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f
"
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
tests/test_modeling_common_flax.py
View file @
8faa822d
import
inspect
from
diffusers.utils
import
is_flax_available
from
diffusers.utils
import
is_flax_available
from
diffusers.utils.testing_utils
import
require_flax
from
diffusers.utils.testing_utils
import
require_flax
...
@@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
...
@@ -42,3 +44,23 @@ class FlaxModelTesterMixin:
self
.
assertIsNotNone
(
output
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"sample"
].
shape
expected_shape
=
inputs_dict
[
"sample"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_deprecated_kwargs
(
self
):
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
self
.
model_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
self
.
model_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
self
.
model_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f
"
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
tests/test_scheduler.py
View file @
8faa822d
...
@@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -562,6 +562,27 @@ class SchedulerCommonTest(unittest.TestCase):
noised
=
scheduler
.
add_noise
(
scaled_sample
,
noise
,
t
)
noised
=
scheduler
.
add_noise
(
scaled_sample
,
noise
,
t
)
self
.
assertEqual
(
noised
.
shape
,
scaled_sample
.
shape
)
self
.
assertEqual
(
noised
.
shape
,
scaled_sample
.
shape
)
def
test_deprecated_kwargs
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
scheduler_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
scheduler_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f
" argument to
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
DDPMScheduler
,)
scheduler_classes
=
(
DDPMScheduler
,)
...
...
tests/test_scheduler_flax.py
View file @
8faa822d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
inspect
import
tempfile
import
tempfile
import
unittest
import
unittest
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
...
@@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
...
@@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
recursive_check
(
outputs_tuple
[
0
],
outputs_dict
.
prev_sample
)
recursive_check
(
outputs_tuple
[
0
],
outputs_dict
.
prev_sample
)
def
test_deprecated_kwargs
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
has_kwarg_in_model_class
=
"kwargs"
in
inspect
.
signature
(
scheduler_class
.
__init__
).
parameters
has_deprecated_kwarg
=
len
(
scheduler_class
.
_deprecated_kwargs
)
>
0
if
has_kwarg_in_model_class
and
not
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
has `**kwargs` in its __init__ method but has not defined any deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
" there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if
not
has_kwarg_in_model_class
and
has_deprecated_kwarg
:
raise
ValueError
(
f
"
{
scheduler_class
}
doesn't have `**kwargs` in its __init__ method but has defined deprecated"
" kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
f
" argument to
{
self
.
model_class
}
.__init__ if there are deprecated arguments or remove the"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@
require_flax
@
require_flax
class
FlaxDDPMSchedulerTest
(
FlaxSchedulerCommonTest
):
class
FlaxDDPMSchedulerTest
(
FlaxSchedulerCommonTest
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment