Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
09779cbb
Unverified
Commit
09779cbb
authored
Jan 25, 2023
by
Patrick von Platen
Committed by
GitHub
Jan 25, 2023
Browse files
[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)
* [Bump version] 0.13 * Bump model up * up
parent
b0cc7c20
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
8 additions
and
166 deletions
+8
-166
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
...iffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+1
-1
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
...ble_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+1
-1
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+1
-11
src/diffusers/schedulers/scheduling_ddim_flax.py
src/diffusers/schedulers/scheduling_ddim_flax.py
+0
-11
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+2
-23
src/diffusers/schedulers/scheduling_ddpm_flax.py
src/diffusers/schedulers/scheduling_ddpm_flax.py
+0
-11
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+0
-11
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+0
-11
tests/pipelines/ddpm/test_ddpm.py
tests/pipelines/ddpm/test_ddpm.py
+0
-27
tests/test_config.py
tests/test_config.py
+0
-11
tests/test_scheduler.py
tests/test_scheduler.py
+1
-30
tests/test_scheduler_flax.py
tests/test_scheduler_flax.py
+1
-17
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
View file @
09779cbb
...
...
@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
message
=
"Please use `image` instead of `init_image`."
init_image
=
deprecate
(
"init_image"
,
"0.1
3
.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.1
4
.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
if
isinstance
(
prompt
,
str
):
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
09779cbb
...
...
@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
message
=
"Please use `image` instead of `init_image`."
init_image
=
deprecate
(
"init_image"
,
"0.1
3
.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.1
4
.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
# 1. Check inputs. Raise error if not correct
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
View file @
09779cbb
...
...
@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
message
=
"Please use `image` instead of `init_image`."
init_image
=
deprecate
(
"init_image"
,
"0.1
3
.0"
,
message
,
take_from
=
kwargs
)
init_image
=
deprecate
(
"init_image"
,
"0.1
4
.0"
,
message
,
take_from
=
kwargs
)
image
=
init_image
or
image
# 1. Check inputs
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
09779cbb
...
...
@@ -23,7 +23,7 @@ import numpy as np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
deprecate
,
randn_tensor
from
..utils
import
BaseOutput
,
randn_tensor
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
...
...
@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
order
=
1
@
register_to_config
...
...
@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one
:
bool
=
True
,
steps_offset
:
int
=
0
,
prediction_type
:
str
=
"epsilon"
,
**
kwargs
,
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_ddim_flax.py
View file @
09779cbb
...
...
@@ -22,7 +22,6 @@ import flax
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
(
CommonSchedulerState
,
FlaxKarrasDiffusionSchedulers
,
...
...
@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
dtype
:
jnp
.
dtype
...
...
@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
steps_offset
:
int
=
0
,
prediction_type
:
str
=
"epsilon"
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
**
kwargs
,
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f
"
{
self
.
__class__
.
__name__
}
.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
dtype
=
dtype
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DDIMSchedulerState
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
09779cbb
...
...
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
,
FrozenDict
,
register_to_config
from
..utils
import
BaseOutput
,
deprecate
,
randn_tensor
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
BaseOutput
,
randn_tensor
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
...
...
@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
order
=
1
@
register_to_config
...
...
@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type
:
str
=
"fixed_small"
,
clip_sample
:
bool
=
True
,
prediction_type
:
str
=
"epsilon"
,
**
kwargs
,
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
...
...
@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample
:
torch
.
FloatTensor
,
generator
=
None
,
return_dict
:
bool
=
True
,
**
kwargs
,
)
->
Union
[
DDPMSchedulerOutput
,
Tuple
]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
...
...
@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor.
"""
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
new_config
=
dict
(
self
.
config
)
new_config
[
"prediction_type"
]
=
"epsilon"
if
predict_epsilon
else
"sample"
self
.
_internal_dict
=
FrozenDict
(
new_config
)
t
=
timestep
if
model_output
.
shape
[
1
]
==
sample
.
shape
[
1
]
*
2
and
self
.
variance_type
in
[
"learned"
,
"learned_range"
]:
...
...
src/diffusers/schedulers/scheduling_ddpm_flax.py
View file @
09779cbb
...
...
@@ -22,7 +22,6 @@ import jax
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
(
CommonSchedulerState
,
FlaxKarrasDiffusionSchedulers
,
...
...
@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
dtype
:
jnp
.
dtype
...
...
@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
clip_sample
:
bool
=
True
,
prediction_type
:
str
=
"epsilon"
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
**
kwargs
,
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f
"
{
self
.
__class__
.
__name__
}
.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
dtype
=
dtype
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DDPMSchedulerState
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
09779cbb
...
...
@@ -21,7 +21,6 @@ import numpy as np
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils
import
KarrasDiffusionSchedulers
,
SchedulerMixin
,
SchedulerOutput
...
...
@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles
=
[
e
.
name
for
e
in
KarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
order
=
1
@
register_to_config
...
...
@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
algorithm_type
:
str
=
"dpmsolver++"
,
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
**
kwargs
,
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
09779cbb
...
...
@@ -22,7 +22,6 @@ import jax
import
jax.numpy
as
jnp
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..utils
import
deprecate
from
.scheduling_utils_flax
import
(
CommonSchedulerState
,
FlaxKarrasDiffusionSchedulers
,
...
...
@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles
=
[
e
.
name
for
e
in
FlaxKarrasDiffusionSchedulers
]
_deprecated_kwargs
=
[
"predict_epsilon"
]
dtype
:
jnp
.
dtype
...
...
@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
**
kwargs
,
):
message
=
(
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
f
"
{
self
.
__class__
.
__name__
}
.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon
=
deprecate
(
"predict_epsilon"
,
"0.13.0"
,
message
,
take_from
=
kwargs
)
if
predict_epsilon
is
not
None
:
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
dtype
=
dtype
def
create_state
(
self
,
common
:
Optional
[
CommonSchedulerState
]
=
None
)
->
DPMSolverMultistepSchedulerState
:
...
...
tests/pipelines/ddpm/test_ddpm.py
View file @
09779cbb
...
...
@@ -19,7 +19,6 @@ import numpy as np
import
torch
from
diffusers
import
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
from
diffusers.utils
import
deprecate
from
diffusers.utils.testing_utils
import
require_torch_gpu
,
slow
,
torch_device
...
...
@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
assert
np
.
abs
(
image_from_tuple_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
def
test_inference_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
unet
=
self
.
dummy_uncond_unet
scheduler
=
DDPMScheduler
(
predict_epsilon
=
False
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
.
to
(
torch_device
)
ddpm
.
set_progress_bar_config
(
disable
=
None
)
# Warmup pass when using mps (see #372)
if
torch_device
==
"mps"
:
_
=
ddpm
(
num_inference_steps
=
1
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
,
num_inference_steps
=
2
,
output_type
=
"numpy"
).
images
generator
=
torch
.
manual_seed
(
0
)
image_eps
=
ddpm
(
generator
=
generator
,
num_inference_steps
=
2
,
output_type
=
"numpy"
,
predict_epsilon
=
False
)[
0
]
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
image_eps_slice
=
image_eps
[
0
,
-
3
:,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
tolerance
=
1e-2
if
torch_device
!=
"mps"
else
3e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
image_eps_slice
.
flatten
()).
max
()
<
tolerance
def
test_inference_predict_sample
(
self
):
unet
=
self
.
dummy_uncond_unet
scheduler
=
DDPMScheduler
(
prediction_type
=
"sample"
)
...
...
tests/test_config.py
View file @
09779cbb
...
...
@@ -26,7 +26,6 @@ from diffusers import (
logging
,
)
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.utils
import
deprecate
from
diffusers.utils.testing_utils
import
CaptureLogger
...
...
@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
with
CaptureLogger
(
logger
)
as
cap_logger_2
:
ddpm_2
=
DDPMScheduler
.
from_pretrained
(
"google/ddpm-celebahq-256"
,
beta_start
=
88
)
with
CaptureLogger
(
logger
)
as
cap_logger
:
deprecate
(
"remove this case"
,
"0.13.0"
,
"remove"
)
ddpm_3
=
DDPMScheduler
.
from_pretrained
(
"hf-internal-testing/tiny-stable-diffusion-torch"
,
subfolder
=
"scheduler"
,
predict_epsilon
=
False
,
beta_end
=
8
,
)
assert
ddpm
.
__class__
==
DDPMScheduler
assert
ddpm
.
config
.
prediction_type
==
"sample"
assert
ddpm
.
config
.
beta_end
==
8
assert
ddpm_2
.
config
.
beta_start
==
88
assert
ddpm_3
.
config
.
prediction_type
==
"sample"
# no warning should be thrown
assert
cap_logger
.
out
==
""
...
...
tests/test_scheduler.py
View file @
09779cbb
...
...
@@ -45,7 +45,7 @@ from diffusers import (
)
from
diffusers.configuration_utils
import
ConfigMixin
,
register_to_config
from
diffusers.schedulers.scheduling_utils
import
SchedulerMixin
from
diffusers.utils
import
deprecate
,
torch_device
from
diffusers.utils
import
torch_device
from
diffusers.utils.testing_utils
import
CaptureLogger
...
...
@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for
prediction_type
in
[
"epsilon"
,
"sample"
,
"v_prediction"
]:
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
sample
=
self
.
dummy_sample_deter
residual
=
0.1
*
self
.
dummy_sample_deter
time_step
=
4
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler_eps
=
scheduler_class
(
predict_epsilon
=
False
,
**
scheduler_config
)
kwargs
=
{}
if
"generator"
in
set
(
inspect
.
signature
(
scheduler
.
step
).
parameters
.
keys
()):
kwargs
[
"generator"
]
=
torch
.
manual_seed
(
0
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
predict_epsilon
=
False
,
**
kwargs
).
prev_sample
kwargs
=
{}
if
"generator"
in
set
(
inspect
.
signature
(
scheduler
.
step
).
parameters
.
keys
()):
kwargs
[
"generator"
]
=
torch
.
manual_seed
(
0
)
output_eps
=
scheduler_eps
.
step
(
residual
,
time_step
,
sample
,
predict_epsilon
=
False
,
**
kwargs
).
prev_sample
assert
(
output
-
output_eps
).
abs
().
sum
()
<
1e-5
def
test_time_indices
(
self
):
for
t
in
[
0
,
500
,
999
]:
self
.
check_over_forward
(
time_step
=
t
)
...
...
tests/test_scheduler_flax.py
View file @
09779cbb
...
...
@@ -18,7 +18,7 @@ import unittest
from
typing
import
Dict
,
List
,
Tuple
from
diffusers
import
FlaxDDIMScheduler
,
FlaxDDPMScheduler
,
FlaxPNDMScheduler
from
diffusers.utils
import
deprecate
,
is_flax_available
from
diffusers.utils
import
is_flax_available
from
diffusers.utils.testing_utils
import
require_flax
...
...
@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
for
prediction_type
in
[
"epsilon"
,
"sample"
,
"v_prediction"
]:
self
.
check_over_configs
(
prediction_type
=
prediction_type
)
def
test_deprecated_predict_epsilon
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
for
predict_epsilon
in
[
True
,
False
]:
self
.
check_over_configs
(
predict_epsilon
=
predict_epsilon
)
def
test_deprecated_predict_epsilon_to_prediction_type
(
self
):
deprecate
(
"remove this test"
,
"0.13.0"
,
"remove"
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
True
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
assert
scheduler
.
prediction_type
==
"epsilon"
scheduler_config
=
self
.
get_scheduler_config
(
predict_epsilon
=
False
)
scheduler
=
scheduler_class
.
from_config
(
scheduler_config
)
assert
scheduler
.
prediction_type
==
"sample"
@
require_flax
class
FlaxPNDMSchedulerTest
(
FlaxSchedulerCommonTest
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment