Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
09779cbb
Unverified
Commit
09779cbb
authored
Jan 25, 2023
by
Patrick von Platen
Committed by
GitHub
Jan 25, 2023
Browse files
[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)
* [Bump version] 0.13 * Bump model up * up
parent
b0cc7c20
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
8 additions
and
166 deletions
+8
-166
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
...iffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
...ble_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+1
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-11
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+0
-11
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-23
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+0
-11
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+0
-11
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+0
-11
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+0
-27
tests/test_config.py
tests/test_config.py
+0
-11
tests/test_scheduler.py
tests/test_scheduler.py
+1
-30
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+1
-17
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
View file @
09779cbb
...
@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
...
@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
(nsfw) content, according to the `safety_checker`.
"""
"""
message
=
"Please use `image` instead of `init_image`."
message
=
"Please use `image` instead of `init_image`."
init_image
=
deprecate
(
"init_image"
,
"0.1
3
.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.1
4
.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
image
=
init_image
or
image
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
09779cbb
...
@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
...
@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
(nsfw) content, according to the `safety_checker`.
"""
"""
message
=
"Please use `image` instead of `init_image`."
message
=
"Please use `image` instead of `init_image`."
init_image
=
deprecate
(
"init_image"
,
"0.1
3
.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.1
4
.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
image
=
init_image
or
image
# 1. Check inputs. Raise error if not correct
# 1. Check inputs. Raise error if not correct
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
View file @
09779cbb
...
@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
...
@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
(nsfw) content, according to the `safety_checker`.
"""
"""
message
=
"Please use `image` instead of `init_image`."
message
=
"Please use `image` instead of `init_image`."
init_image
=
deprecate
(
"init_image"
,
"0.1
3
.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.1
4
.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
image
=
init_image
or
image
# 1. Check inputs
# 1. Check inputs
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
09779cbb
...
@@ -23,7 +23,7 @@ import numpy as np
...
@@ -23,7 +23,7 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
deprecate
,
randn_tensor
from
..utils
import
BaseOutput
,
randn_tensor
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
...
@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
order
=
1
order
=
1
@
register_to_config
@
register_to_config
...
@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one
:
bool
=
True
,
set_alpha_to_one
:
bool
=
True
,
steps_offset
:
int
=
0
,
steps_offset
:
int
=
0
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
**
kwargs
,
):
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
09779cbb
...
@@ -22,7 +22,6 @@ import flax
...
@@ -22,7 +22,6 @@ import flax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
(
from
.scheduling_utils_flax
import
(
CommonSchedulerState
,
CommonSchedulerState
,
FlaxKarrasDiffusionSchedulers
,
FlaxKarrasDiffusionSchedulers
,
...
@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
dtype
:
jnp
.
dtype
dtype
:
jnp
.
dtype
...
@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
steps_offset
:
int
=
0
,
steps_offset
:
int
=
0
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
**
kwargs
,
):
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f
"
{
self
.
__class__
.
__name__
}
.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
dtype
=
dtype
self
.
dtype
=
dtype
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DDIMSchedulerState
:
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DDIMSchedulerState
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
09779cbb
...
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
...
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
deprecate
,
randn_tensor
from
..utils
import
BaseOutput
,
randn_tensor
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
...
@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
order
=
1
order
=
1
@
register_to_config
@
register_to_config
...
@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type
:
str
=
"fixed_small"
,
variance_type
:
str
=
"fixed_small"
,
clip_sample
:
bool
=
True
,
clip_sample
:
bool
=
True
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
**
kwargs
,
):
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
...
@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
generator
=
None
,
generator
=
None
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
**
kwargs
,
)
->
Union
[
DDPMSchedulerOutput
,
Tuple
]:
)
->
Union
[
DDPMSchedulerOutput
,
Tuple
]:
"""
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
...
@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor.
returning a tuple, the first element is the sample tensor.
"""
"""
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
config
)
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
self
.
_internal_dict
=
FrozenDict
(
new_config
)
t
=
timestep
t
=
timestep
if
model_output
.
shape
[
1
]
==
sample
.
shape
[
1
]
*
2
and
self
.
variance_type
in
[
"learned"
,
"learned_range"
]:
if
model_output
.
shape
[
1
]
==
sample
.
shape
[
1
]
*
2
and
self
.
variance_type
in
[
"learned"
,
"learned_range"
]:
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
09779cbb
...
@@ -22,7 +22,6 @@ import jax
...
@@ -22,7 +22,6 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
(
from
.scheduling_utils_flax
import
(
CommonSchedulerState
,
CommonSchedulerState
,
FlaxKarrasDiffusionSchedulers
,
FlaxKarrasDiffusionSchedulers
,
...
@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
dtype
:
jnp
.
dtype
dtype
:
jnp
.
dtype
...
@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
clip_sample
:
bool
=
True
,
clip_sample
:
bool
=
True
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
**
kwargs
,
):
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f
"
{
self
.
__class__
.
__name__
}
.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
dtype
=
dtype
self
.
dtype
=
dtype
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DDPMSchedulerState
:
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DDPMSchedulerState
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
09779cbb
...
@@ -21,7 +21,6 @@ import numpy as np
...
@@ -21,7 +21,6 @@ import numpy as np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
,
SchedulerOutput
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
,
SchedulerOutput
...
@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
order
=
1
order
=
1
@
register_to_config
@
register_to_config
...
@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type
:
str
=
"dpmsolver++"
,
algorithm_type
:
str
=
"dpmsolver++"
,
solver_type
:
str
=
"midpoint"
,
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
lower_order_final
:
bool
=
True
,
**
kwargs
,
):
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
09779cbb
...
@@ -22,7 +22,6 @@ import jax
...
@@ -22,7 +22,6 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
(
from
.scheduling_utils_flax
import
(
CommonSchedulerState
,
CommonSchedulerState
,
FlaxKarrasDiffusionSchedulers
,
FlaxKarrasDiffusionSchedulers
,
...
@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
"""
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
dtype
:
jnp
.
dtype
dtype
:
jnp
.
dtype
...
@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_type
:
str
=
"midpoint"
,
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
lower_order_final
:
bool
=
True
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
**
kwargs
,
):
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f
"
{
self
.
__class__
.
__name__
}
.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
dtype
=
dtype
self
.
dtype
=
dtype
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DPMSolverMultistepSchedulerState
:
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DPMSolverMultistepSchedulerState
:
...
...
tests/pipelines/ddpm/test_ddpm.py
View file @
09779cbb
...
@@ -19,7 +19,6 @@ import numpy as np
...
@@ -19,7 +19,6 @@ import numpy as np
import
torch
import
torch
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
from
diffusers.utils
import
deprecate
from
diffusers.utils.testing_utils
import
require_torch_gpu
,
slow
,
torch_device
from
diffusers.utils.testing_utils
import
require_torch_gpu
,
slow
,
torch_device
...
@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
...
@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_inference_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
unet
=
self
.
dummy_uncond_unet
scheduler
=
DDPMScheduler
(
predict_epsilon
=
False
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
ddpm
.
set_progress_bar_config
(
disable
=
None
)
# Warmup pass when using mps (see #372)
if
torch_device
==
"mps"
:
_
=
ddpm
(
num_inference_steps
=
1
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
,
num_inference_steps
=
2
,
output_type
=
"numpy"
).
images
generator
=
torch
.
manual_seed
(
0
)
image_eps
=
ddpm
(
generator
=
generator
,
num_inference_steps
=
2
,
output_type
=
"numpy"
,
predict_epsilon
=
False
)[
0
]
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
image_eps_slice
=
image_eps
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
tolerance
=
1e-2
if
torch_device
!=
"mps"
else
3e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
image_eps_slice
.
flatten
()).
max
()
<
tolerance
def
test_inference_predict_sample
(
self
):
def
test_inference_predict_sample
(
self
):
unet
=
self
.
dummy_uncond_unet
unet
=
self
.
dummy_uncond_unet
scheduler
=
DDPMScheduler
(
prediction_type
=
"sample"
)
scheduler
=
DDPMScheduler
(
prediction_type
=
"sample"
)
...
...
tests/test_config.py
View file @
09779cbb
...
@@ -26,7 +26,6 @@ from diffusers import (
...
@@ -26,7 +26,6 @@ from diffusers import (
logging
,
logging
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.utils
import
deprecate
from
diffusers.utils.testing_utils
import
CaptureLogger
from
diffusers.utils.testing_utils
import
CaptureLogger
...
@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
...
@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
deprecate
(
"remove this case"
,
"0.13.0"
,
"remove"
)
ddpm_3
=
DDPMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
,
predict_epsilon
=
False
,
beta_end
=
8
,
)
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
config
.
prediction_type
==
"sample"
assert
ddpm
.
config
.
prediction_type
==
"sample"
assert
ddpm
.
config
.
beta_end
==
8
assert
ddpm
.
config
.
beta_end
==
8
assert
ddpm_2
.
config
.
beta_start
==
88
assert
ddpm_2
.
config
.
beta_start
==
88
assert
ddpm_3
.
config
.
prediction_type
==
"sample"
# no warning should be thrown
# no warning should be thrown
assert
cap_logger
.
out
==
""
assert
cap_logger
.
out
==
""
...
...
tests/test_scheduler.py
View file @
09779cbb
...
@@ -45,7 +45,7 @@ from diffusers import (
...
@@ -45,7 +45,7 @@ from diffusers import (
)
)
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.schedulers.scheduling_utils
import
SchedulerMixin
from
diffusers.schedulers.scheduling_utils
import
SchedulerMixin
from
diffusers.utils
import
deprecate
,
torch_device
from
diffusers.utils
import
torch_device
from
diffusers.utils.testing_utils
import
CaptureLogger
from
diffusers.utils.testing_utils
import
CaptureLogger
...
@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for
prediction_type
in
[
"epsilon"
,
"sample"
,
"v_prediction"
]:
for
prediction_type
in
[
"epsilon"
,
"sample"
,
"v_prediction"
]:
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
sample
=
self
.
dummy_sample_deter
residual
=
0.1
*
self
.
dummy_sample_deter
time_step
=
4
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler_eps
=
scheduler_class
(
predict_epsilon
=
False
,
**
scheduler_config
)
kwargs
=
{}
if
"generator"
in
set
(
inspect
.
signature
(
scheduler
.
step
).
parameters
.
keys
()):
kwargs
[
"generator"
]
=
torch
.
manual_seed
(
0
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
predict_epsilon
=
False
,
**
kwargs
).
prev_sample
kwargs
=
{}
if
"generator"
in
set
(
inspect
.
signature
(
scheduler
.
step
).
parameters
.
keys
()):
kwargs
[
"generator"
]
=
torch
.
manual_seed
(
0
)
output_eps
=
scheduler_eps
.
step
(
residual
,
time_step
,
sample
,
predict_epsilon
=
False
,
**
kwargs
).
prev_sample
assert
(
output
-
output_eps
).
abs
().
sum
()
<
1e-5
def
test_time_indices
(
self
):
def
test_time_indices
(
self
):
for
t
in
[
0
,
500
,
999
]:
for
t
in
[
0
,
500
,
999
]:
self
.
check_over_forward
(
time_step
=
t
)
self
.
check_over_forward
(
time_step
=
t
)
...
...
tests/test_scheduler_flax.py
View file @
09779cbb
...
@@ -18,7 +18,7 @@ import unittest
...
@@ -18,7 +18,7 @@ import unittest
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
from
diffusers
import
FlaxDDIMScheduler
,
FlaxDDPMScheduler
,
FlaxPNDMScheduler
from
diffusers
import
FlaxDDIMScheduler
,
FlaxDDPMScheduler
,
FlaxPNDMScheduler
from
diffusers.utils
import
deprecate
,
is_flax_available
from
diffusers.utils
import
is_flax_available
from
diffusers.utils.testing_utils
import
require_flax
from
diffusers.utils.testing_utils
import
require_flax
...
@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
...
@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
for
prediction_type
in
[
"epsilon"
,
"sample"
,
"v_prediction"
]:
for
prediction_type
in
[
"epsilon"
,
"sample"
,
"v_prediction"
]:
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_predict_epsilon_to_prediction_type
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
True
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
assert
scheduler
.
prediction_type
==
"epsilon"
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
False
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
assert
scheduler
.
prediction_type
==
"sample"
@
require_flax
@
require_flax
class
FlaxPNDMSchedulerTest
(
FlaxSchedulerCommonTest
):
class
FlaxPNDMSchedulerTest
(
FlaxSchedulerCommonTest
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment