Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d9e7857a
Unverified
Commit
d9e7857a
authored
Sep 26, 2023
by
Pedro Cuenca
Committed by
GitHub
Sep 26, 2023
Browse files
timestep_spacing for FlaxDPMSolverMultistepScheduler (#5189)
* timestep_spacing for FlaxDPMSolverMultistepScheduler * Style
parent
fd1c54ab
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
6 deletions
+27
-6
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+27
-6
No files found.
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
d9e7857a
...
@@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
the `dtype` used for params and computation.
"""
"""
...
@@ -163,6 +166,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -163,6 +166,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
algorithm_type
:
str
=
"dpmsolver++"
,
algorithm_type
:
str
=
"dpmsolver++"
,
solver_type
:
str
=
"midpoint"
,
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
lower_order_final
:
bool
=
True
,
timestep_spacing
:
str
=
"linspace"
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
):
):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
...
@@ -210,12 +214,29 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -210,12 +214,29 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
shape (`Tuple`):
shape (`Tuple`):
the shape of the samples to be generated.
the shape of the samples to be generated.
"""
"""
last_timestep
=
self
.
config
.
num_train_timesteps
timesteps
=
(
if
self
.
config
.
timestep_spacing
==
"linspace"
:
jnp
.
linspace
(
0
,
self
.
config
.
num_train_timesteps
-
1
,
num_inference_steps
+
1
)
timesteps
=
(
.
round
()[::
-
1
][:
-
1
]
jnp
.
linspace
(
0
,
last_timestep
-
1
,
num_inference_steps
+
1
).
round
()[::
-
1
][:
-
1
].
astype
(
jnp
.
int32
)
.
astype
(
jnp
.
int32
)
)
)
elif
self
.
config
.
timestep_spacing
==
"leading"
:
step_ratio
=
last_timestep
//
(
num_inference_steps
+
1
)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps
=
(
(
jnp
.
arange
(
0
,
num_inference_steps
+
1
)
*
step_ratio
).
round
()[::
-
1
][:
-
1
].
copy
().
astype
(
jnp
.
int32
)
)
timesteps
+=
self
.
config
.
steps_offset
elif
self
.
config
.
timestep_spacing
==
"trailing"
:
step_ratio
=
self
.
config
.
num_train_timesteps
/
num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps
=
jnp
.
arange
(
last_timestep
,
0
,
-
step_ratio
).
round
().
copy
().
astype
(
jnp
.
int32
)
timesteps
-=
1
else
:
raise
ValueError
(
f
"
{
self
.
config
.
timestep_spacing
}
is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
# initial running values
# initial running values
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment