Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
464374fb
Unverified
Commit
464374fb
authored
Feb 07, 2025
by
hlky
Committed by
GitHub
Feb 07, 2025
Browse files
EDMEulerScheduler accept sigmas, add final_sigmas_type (#10734)
parent
d43ce14e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
10 deletions
+45
-10
src/diffusers/schedulers/scheduling_edm_euler.py
src/diffusers/schedulers/scheduling_edm_euler.py
+45
-10
No files found.
src/diffusers/schedulers/scheduling_edm_euler.py
View file @
464374fb
...
...
@@ -14,7 +14,7 @@
import
math
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
Video](https://imagen.research.google/video/paper.pdf) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles
=
[]
...
...
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps
:
int
=
1000
,
prediction_type
:
str
=
"epsilon"
,
rho
:
float
=
7.0
,
final_sigmas_type
:
str
=
"zero"
,
# can be "zero" or "sigma_min"
):
if
sigma_schedule
not
in
[
"karras"
,
"exponential"
]:
raise
ValueError
(
f
"Wrong value for provided for `
{
sigma_schedule
=
}
`.`"
)
...
...
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
# setable values
self
.
num_inference_steps
=
None
ramp
=
torch
.
linspace
(
0
,
1
,
num_train_timesteps
)
sigmas
=
torch
.
arange
(
num_train_timesteps
+
1
)
/
num_train_timesteps
if
sigma_schedule
==
"karras"
:
sigmas
=
self
.
_compute_karras_sigmas
(
ramp
)
sigmas
=
self
.
_compute_karras_sigmas
(
sigmas
)
elif
sigma_schedule
==
"exponential"
:
sigmas
=
self
.
_compute_exponential_sigmas
(
ramp
)
sigmas
=
self
.
_compute_exponential_sigmas
(
sigmas
)
self
.
timesteps
=
self
.
precondition_noise
(
sigmas
)
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
zeros
(
1
,
device
=
sigmas
.
device
)])
if
self
.
config
.
final_sigmas_type
==
"sigma_min"
:
sigma_last
=
sigmas
[
-
1
]
elif
self
.
config
.
final_sigmas_type
==
"zero"
:
sigma_last
=
0
else
:
raise
ValueError
(
f
"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got
{
self
.
config
.
final_sigmas_type
}
"
)
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
full
((
1
,),
fill_value
=
sigma_last
,
device
=
sigmas
.
device
)])
self
.
is_scale_input_called
=
False
...
...
@@ -197,7 +210,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self
.
is_scale_input_called
=
True
return
sample
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
=
None
,
device
:
Union
[
str
,
torch
.
device
]
=
None
,
sigmas
:
Optional
[
Union
[
torch
.
Tensor
,
List
[
float
]]]
=
None
,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
...
@@ -206,19 +224,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
Custom sigmas to use for the denoising process. If not defined, the default behavior when
`num_inference_steps` is passed will be used.
"""
self
.
num_inference_steps
=
num_inference_steps
ramp
=
torch
.
linspace
(
0
,
1
,
self
.
num_inference_steps
)
if
sigmas
is
None
:
sigmas
=
torch
.
linspace
(
0
,
1
,
self
.
num_inference_steps
)
elif
isinstance
(
sigmas
,
float
):
sigmas
=
torch
.
tensor
(
sigmas
,
dtype
=
torch
.
float32
)
else
:
sigmas
=
sigmas
if
self
.
config
.
sigma_schedule
==
"karras"
:
sigmas
=
self
.
_compute_karras_sigmas
(
ramp
)
sigmas
=
self
.
_compute_karras_sigmas
(
sigmas
)
elif
self
.
config
.
sigma_schedule
==
"exponential"
:
sigmas
=
self
.
_compute_exponential_sigmas
(
ramp
)
sigmas
=
self
.
_compute_exponential_sigmas
(
sigmas
)
sigmas
=
sigmas
.
to
(
dtype
=
torch
.
float32
,
device
=
device
)
self
.
timesteps
=
self
.
precondition_noise
(
sigmas
)
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
zeros
(
1
,
device
=
sigmas
.
device
)])
if
self
.
config
.
final_sigmas_type
==
"sigma_min"
:
sigma_last
=
sigmas
[
-
1
]
elif
self
.
config
.
final_sigmas_type
==
"zero"
:
sigma_last
=
0
else
:
raise
ValueError
(
f
"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got
{
self
.
config
.
final_sigmas_type
}
"
)
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
full
((
1
,),
fill_value
=
sigma_last
,
device
=
sigmas
.
device
)])
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment