renzhc / diffusers_dcu / Commits / 09779cbb

Unverified commit 09779cbb, authored Jan 25, 2023 by Patrick von Platen, committed via GitHub on Jan 25, 2023.

[Bump version] 0.13.0dev0 & Deprecate `predict_epsilon` (#2109)

* [Bump version] 0.13
* Bump model up
* up

Parent: b0cc7c20
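Every scheduler touched below emitted the same guidance before this removal: instantiate with `prediction_type` rather than the old `predict_epsilon` flag. A minimal before/after sketch of that migration, using the `google/ddpm-celebahq-256` checkpoint that also appears in `tests/test_config.py` (any scheduler checkpoint works the same way):

```python
from diffusers import DDPMScheduler

# Before this commit the old flag still worked, with a deprecation warning:
#   scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", predict_epsilon=True)

# After: pass `prediction_type`, as the removed warning text suggested.
# predict_epsilon=True maps to "epsilon", predict_epsilon=False to "sample".
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", prediction_type="epsilon")
```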
Showing 13 changed files with 8 additions and 166 deletions (+8 −166).
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py (+1 −1)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py (+1 −1)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py (+1 −1)
src/diffusers/schedulers/scheduling_ddim.py (+1 −11)
src/diffusers/schedulers/scheduling_ddim_flax.py (+0 −11)
src/diffusers/schedulers/scheduling_ddpm.py (+2 −23)
src/diffusers/schedulers/scheduling_ddpm_flax.py (+0 −11)
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py (+0 −11)
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py (+0 −11)
tests/pipelines/ddpm/test_ddpm.py (+0 −27)
tests/test_config.py (+0 −11)
tests/test_scheduler.py (+1 −30)
tests/test_scheduler_flax.py (+1 −17)
src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py

```diff
@@ -303,7 +303,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image

         if isinstance(prompt, str):
```
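This hunk, and the two identical ones below it, only push the `init_image` removal deadline from 0.13.0 to 0.14.0 now that the 0.13 dev cycle has started. For context on the pattern, here is a minimal sketch of what a `deprecate(..., take_from=kwargs)`-style shim does; `deprecate_kwarg` is a hypothetical stand-in for illustration, not the actual `diffusers.utils.deprecate` implementation:

```python
import warnings

def deprecate_kwarg(name, removal_version, message, take_from=None):
    # Hypothetical stand-in: pop the old kwarg if the caller supplied it,
    # warn, and hand its value back so the pipeline can fall back to it.
    value = take_from.pop(name, None) if take_from is not None else None
    if value is not None:
        warnings.warn(
            f"`{name}` is deprecated and will be removed in {removal_version}. {message}",
            FutureWarning,
        )
    return value

# Usage mirroring the hunk above:
kwargs = {"init_image": "<legacy image>"}
image = None  # caller used the old kwarg, not the new `image` argument
init_image = deprecate_kwarg(
    "init_image", "0.14.0", "Please use `image` instead of `init_image`.", take_from=kwargs
)
image = init_image or image  # the unchanged line that follows in the diff
assert image == "<legacy image>" and "init_image" not in kwargs
```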
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

```diff
@@ -616,7 +616,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image

         # 1. Check inputs. Raise error if not correct
```
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

```diff
@@ -556,7 +556,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             (nsfw) content, according to the `safety_checker`.
         """
         message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs)
+        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
         image = init_image or image

         # 1. Check inputs
```
src/diffusers/schedulers/scheduling_ddim.py

```diff
@@ -23,7 +23,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -113,7 +113,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     order = 1

     @register_to_config
@@ -128,16 +127,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
-
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
```
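Besides the constructor shim, this diff also drops the `_deprecated_kwargs = ["predict_epsilon"]` class attribute. The ConfigMixin machinery that consumed it is not part of this diff; the sketch below is only a plausible illustration, under the assumption that such a list whitelists deprecated names so `from_pretrained` forwards them to `__init__` (where `deprecate()` handled them) instead of dropping them as unknown config keys:

```python
class ConfigLoaderSketch:
    """Hypothetical loader illustrating one use for a _deprecated_kwargs list."""

    _deprecated_kwargs = ["predict_epsilon"]

    @classmethod
    def split_init_kwargs(cls, expected_params, **kwargs):
        # Forward expected parameters and whitelisted deprecated names to
        # __init__; everything else is treated as an unknown config key.
        forwarded, unknown = {}, {}
        for key, value in kwargs.items():
            if key in expected_params or key in cls._deprecated_kwargs:
                forwarded[key] = value
            else:
                unknown[key] = value
        return forwarded, unknown


forwarded, unknown = ConfigLoaderSketch.split_init_kwargs(
    {"prediction_type", "beta_end"}, predict_epsilon=False, foo=1
)
assert forwarded == {"predict_epsilon": False} and unknown == {"foo": 1}
```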
src/diffusers/schedulers/scheduling_ddim_flax.py

```diff
@@ -22,7 +22,6 @@ import flax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -103,7 +102,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

     _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]

     dtype: jnp.dtype
@@ -123,16 +121,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
-
         self.dtype = dtype

     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
```
src/diffusers/schedulers/scheduling_ddpm.py

```diff
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch

-from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -106,7 +106,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     order = 1

     @register_to_config
@@ -120,16 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
-
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -239,7 +229,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         sample: torch.FloatTensor,
         generator=None,
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
@@ -259,16 +248,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.

         """
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            new_config = dict(self.config)
-            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
-            self._internal_dict = FrozenDict(new_config)
-
         t = timestep

         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
```
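The removed shims mapped `predict_epsilon=True` to `prediction_type="epsilon"` and `predict_epsilon=False` to `prediction_type="sample"`, so old configs should be translated accordingly. What the two modes mean inside `step` is standard DDPM algebra; the function below is an illustrative sketch of that interpretation, not code from this file:

```python
import torch

def predicted_original_sample(model_output, sample, alpha_prod_t, prediction_type):
    # "epsilon": the network predicted the added noise; invert the forward
    # process x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps.
    if prediction_type == "epsilon":
        return (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
    # "sample": the network predicted the denoised sample x_0 directly.
    if prediction_type == "sample":
        return model_output
    raise ValueError(f"unsupported prediction_type: {prediction_type}")

# Toy usage with made-up tensors:
x_t = torch.randn(1, 3, 8, 8)
eps_pred = torch.randn(1, 3, 8, 8)
x0 = predicted_original_sample(eps_pred, x_t, alpha_prod_t=0.9, prediction_type="epsilon")
```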
src/diffusers/schedulers/scheduling_ddpm_flax.py

```diff
@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -86,7 +85,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

     _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]

     dtype: jnp.dtype
@@ -106,16 +104,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
-
         self.dtype = dtype

     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
```
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

```diff
@@ -21,7 +21,6 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -118,7 +117,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]
     order = 1

     @register_to_config
@@ -137,16 +135,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
-
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
```
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py

```diff
@@ -22,7 +22,6 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
 from .scheduling_utils_flax import (
     CommonSchedulerState,
     FlaxKarrasDiffusionSchedulers,
@@ -141,7 +140,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

     _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
-    _deprecated_kwargs = ["predict_epsilon"]

     dtype: jnp.dtype
@@ -166,16 +164,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
         dtype: jnp.dtype = jnp.float32,
-        **kwargs,
     ):
-        message = (
-            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
-        )
-        predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
-        if predict_epsilon is not None:
-            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
-
         self.dtype = dtype

     def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState:
```
tests/pipelines/ddpm/test_ddpm.py

```diff
@@ -19,7 +19,6 @@ import numpy as np
 import torch

 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
@@ -67,32 +66,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

-    def test_inference_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        unet = self.dummy_uncond_unet
-        scheduler = DDPMScheduler(predict_epsilon=False)
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
-
-        generator = torch.manual_seed(0)
-        image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_eps_slice = image_eps[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 32, 32, 3)
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
-
     def test_inference_predict_sample(self):
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler(prediction_type="sample")
```
tests/test_config.py

```diff
@@ -26,7 +26,6 @@ from diffusers import (
     logging,
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import deprecate
 from diffusers.utils.testing_utils import CaptureLogger
@@ -202,20 +201,10 @@ class ConfigTester(unittest.TestCase):
         with CaptureLogger(logger) as cap_logger_2:
             ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)

-        with CaptureLogger(logger) as cap_logger:
-            deprecate("remove this case", "0.13.0", "remove")
-            ddpm_3 = DDPMScheduler.from_pretrained(
-                "hf-internal-testing/tiny-stable-diffusion-torch",
-                subfolder="scheduler",
-                predict_epsilon=False,
-                beta_end=8,
-            )
-
         assert ddpm.__class__ == DDPMScheduler
         assert ddpm.config.prediction_type == "sample"
         assert ddpm.config.beta_end == 8
         assert ddpm_2.config.beta_start == 88
-        assert ddpm_3.config.prediction_type == "sample"

         # no warning should be thrown
         assert cap_logger.out == ""
```
tests/test_scheduler.py

```diff
@@ -45,7 +45,7 @@ from diffusers import (
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from diffusers.utils import deprecate, torch_device
+from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import CaptureLogger
@@ -645,35 +645,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for prediction_type in ["epsilon", "sample", "v_prediction"]:
             self.check_over_configs(prediction_type=prediction_type)

-    def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
-
-    def test_deprecated_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config()
-
-        sample = self.dummy_sample_deter
-        residual = 0.1 * self.dummy_sample_deter
-        time_step = 4
-
-        scheduler = scheduler_class(**scheduler_config)
-        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
-
-        kwargs = {}
-        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-            kwargs["generator"] = torch.manual_seed(0)
-        output = scheduler.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
-
-        kwargs = {}
-        if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
-            kwargs["generator"] = torch.manual_seed(0)
-        output_eps = scheduler_eps.step(residual, time_step, sample, predict_epsilon=False, **kwargs).prev_sample
-
-        assert (output - output_eps).abs().sum() < 1e-5
-
     def test_time_indices(self):
         for t in [0, 500, 999]:
             self.check_over_forward(time_step=t)
```
tests/test_scheduler_flax.py

```diff
@@ -18,7 +18,7 @@ import unittest
 from typing import Dict, List, Tuple

 from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
-from diffusers.utils import deprecate, is_flax_available
+from diffusers.utils import is_flax_available
 from diffusers.utils.testing_utils import require_flax
@@ -626,22 +626,6 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
         for prediction_type in ["epsilon", "sample", "v_prediction"]:
             self.check_over_configs(prediction_type=prediction_type)

-    def test_deprecated_predict_epsilon(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
-
-    def test_deprecated_predict_epsilon_to_prediction_type(self):
-        deprecate("remove this test", "0.13.0", "remove")
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config(predict_epsilon=True)
-            scheduler = scheduler_class.from_config(scheduler_config)
-            assert scheduler.prediction_type == "epsilon"
-
-            scheduler_config = self.get_scheduler_config(predict_epsilon=False)
-            scheduler = scheduler_class.from_config(scheduler_config)
-            assert scheduler.prediction_type == "sample"
-

 @require_flax
 class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
```