Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
760dcb1f
You need to sign in or sign up before continuing.
Commit
760dcb1f
authored
Jul 20, 2022
by
Patrick von Platen
Browse files
fix score sde ve scheduler
parent
919e27d3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
19 deletions
+14
-19
src/diffusers/schedulers/scheduling_sde_ve.py
src/diffusers/schedulers/scheduling_sde_ve.py
+14
-19
No files found.
src/diffusers/schedulers/scheduling_sde_ve.py
View file @
760dcb1f
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import
pdb
from
typing
import
Union
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# self.num_inference_steps = None
# self.num_inference_steps = None
self
.
timesteps
=
None
self
.
timesteps
=
None
self
.
set_sigmas
(
self
.
num_train_timesteps
)
self
.
set_sigmas
(
num_train_timesteps
,
sigma_min
,
sigma_max
,
sampling_eps
)
self
.
tensor_format
=
tensor_format
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
def
set_timesteps
(
self
,
num_inference_steps
):
def
set_timesteps
(
self
,
num_inference_steps
,
sampling_eps
=
None
):
sampling_eps
=
sampling_eps
if
sampling_eps
is
not
None
else
self
.
config
.
sampling_eps
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
if
tensor_format
==
"np"
:
if
tensor_format
==
"np"
:
self
.
timesteps
=
np
.
linspace
(
1
,
self
.
config
.
sampling_eps
,
num_inference_steps
)
self
.
timesteps
=
np
.
linspace
(
1
,
sampling_eps
,
num_inference_steps
)
elif
tensor_format
==
"pt"
:
elif
tensor_format
==
"pt"
:
self
.
timesteps
=
torch
.
linspace
(
1
,
self
.
config
.
sampling_eps
,
num_inference_steps
)
self
.
timesteps
=
torch
.
linspace
(
1
,
sampling_eps
,
num_inference_steps
)
else
:
else
:
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
def
set_sigmas
(
self
,
num_inference_steps
):
def
set_sigmas
(
self
,
num_inference_steps
,
sigma_min
=
None
,
sigma_max
=
None
,
sampling_eps
=
None
):
sigma_min
=
sigma_min
if
sigma_min
is
not
None
else
self
.
config
.
sigma_min
sigma_max
=
sigma_max
if
sigma_max
is
not
None
else
self
.
config
.
sigma_max
sampling_eps
=
sampling_eps
if
sampling_eps
is
not
None
else
self
.
config
.
sampling_eps
if
self
.
timesteps
is
None
:
if
self
.
timesteps
is
None
:
self
.
set_timesteps
(
num_inference_steps
)
self
.
set_timesteps
(
num_inference_steps
,
sampling_eps
)
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
tensor_format
=
getattr
(
self
,
"tensor_format"
,
"pt"
)
if
tensor_format
==
"np"
:
if
tensor_format
==
"np"
:
self
.
discrete_sigmas
=
np
.
exp
(
self
.
discrete_sigmas
=
np
.
exp
(
np
.
linspace
(
np
.
log
(
sigma_min
),
np
.
log
(
sigma_max
),
num_inference_steps
))
np
.
linspace
(
np
.
log
(
self
.
config
.
sigma_min
),
np
.
log
(
self
.
config
.
sigma_max
),
num_inference_steps
)
self
.
sigmas
=
np
.
array
([
sigma_min
*
(
sigma_max
/
sigma_min
)
**
t
for
t
in
self
.
timesteps
])
)
self
.
sigmas
=
np
.
array
(
[
self
.
config
.
sigma_min
*
(
self
.
config
.
sigma_max
/
self
.
sigma_min
)
**
t
for
t
in
self
.
timesteps
]
)
elif
tensor_format
==
"pt"
:
elif
tensor_format
==
"pt"
:
self
.
discrete_sigmas
=
torch
.
exp
(
self
.
discrete_sigmas
=
torch
.
exp
(
torch
.
linspace
(
np
.
log
(
sigma_min
),
np
.
log
(
sigma_max
),
num_inference_steps
))
torch
.
linspace
(
np
.
log
(
self
.
config
.
sigma_min
),
np
.
log
(
self
.
config
.
sigma_max
),
num_inference_steps
)
self
.
sigmas
=
torch
.
tensor
([
sigma_min
*
(
sigma_max
/
sigma_min
)
**
t
for
t
in
self
.
timesteps
])
)
self
.
sigmas
=
torch
.
tensor
(
[
self
.
config
.
sigma_min
*
(
self
.
config
.
sigma_max
/
self
.
sigma_min
)
**
t
for
t
in
self
.
timesteps
]
)
else
:
else
:
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
raise
ValueError
(
f
"`self.tensor_format`:
{
self
.
tensor_format
}
is not valid."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment