Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
d9e7857a
Unverified
Commit
d9e7857a
authored
Sep 26, 2023
by
Pedro Cuenca
Committed by
GitHub
Sep 26, 2023
Browse files
timestep_spacing for FlaxDPMSolverMultistepScheduler (#5189)
* timestep_spacing for FlaxDPMSolverMultistepScheduler * Style
parent
fd1c54ab
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
6 deletions
+27
-6
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
...ffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+27
-6
No files found.
src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
View file @
d9e7857a
...
@@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`):
lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
the `dtype` used for params and computation.
"""
"""
...
@@ -163,6 +166,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -163,6 +166,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
algorithm_type
:
str
=
"dpmsolver++"
,
algorithm_type
:
str
=
"dpmsolver++"
,
solver_type
:
str
=
"midpoint"
,
solver_type
:
str
=
"midpoint"
,
lower_order_final
:
bool
=
True
,
lower_order_final
:
bool
=
True
,
timestep_spacing
:
str
=
"linspace"
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
dtype
:
jnp
.
dtype
=
jnp
.
float32
,
):
):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
...
@@ -210,11 +214,28 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
...
@@ -210,11 +214,28 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
shape (`Tuple`):
shape (`Tuple`):
the shape of the samples to be generated.
the shape of the samples to be generated.
"""
"""
last_timestep
=
self
.
config
.
num_train_timesteps
if
self
.
config
.
timestep_spacing
==
"linspace"
:
timesteps
=
(
jnp
.
linspace
(
0
,
last_timestep
-
1
,
num_inference_steps
+
1
).
round
()[::
-
1
][:
-
1
].
astype
(
jnp
.
int32
)
)
elif
self
.
config
.
timestep_spacing
==
"leading"
:
step_ratio
=
last_timestep
//
(
num_inference_steps
+
1
)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps
=
(
timesteps
=
(
jnp
.
linspace
(
0
,
self
.
config
.
num_train_timesteps
-
1
,
num_inference_steps
+
1
)
(
jnp
.
arange
(
0
,
num_inference_steps
+
1
)
*
step_ratio
).
round
()[::
-
1
][:
-
1
].
copy
().
astype
(
jnp
.
int32
)
.
round
()[::
-
1
][:
-
1
]
)
.
astype
(
jnp
.
int32
)
timesteps
+=
self
.
config
.
steps_offset
elif
self
.
config
.
timestep_spacing
==
"trailing"
:
step_ratio
=
self
.
config
.
num_train_timesteps
/
num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps
=
jnp
.
arange
(
last_timestep
,
0
,
-
step_ratio
).
round
().
copy
().
astype
(
jnp
.
int32
)
timesteps
-=
1
else
:
raise
ValueError
(
f
"
{
self
.
config
.
timestep_spacing
}
is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
)
# initial running values
# initial running values
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment