Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
110ffe25
Unverified
Commit
110ffe25
authored
Nov 30, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 30, 2022
Browse files
Allow saving trained betas (#1468)
parent
0b7225e9
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
55 additions
and
30 deletions
+55
-30
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+8
-0
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+3
-3
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+3
-3
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+2
-2
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+3
-3
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+3
-3
src/diffusers/schedulers/scheduling_heun.py
src/diffusers/schedulers/scheduling_heun.py
+3
-3
src/diffusers/schedulers/scheduling_ipndm.py
src/diffusers/schedulers/scheduling_ipndm.py
+10
-3
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+3
-3
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+3
-3
tests/test_scheduler.py
tests/test_scheduler.py
+14
-4
No files found.
src/diffusers/configuration_utils.py
View file @
110ffe25
...
@@ -24,6 +24,8 @@ import re
...
@@ -24,6 +24,8 @@ import re
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Tuple
,
Union
import
numpy
as
np
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub.utils
import
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
from
huggingface_hub.utils
import
EntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
from
requests
import
HTTPError
from
requests
import
HTTPError
...
@@ -502,6 +504,12 @@ class ConfigMixin:
...
@@ -502,6 +504,12 @@ class ConfigMixin:
config_dict
[
"_class_name"
]
=
self
.
__class__
.
__name__
config_dict
[
"_class_name"
]
=
self
.
__class__
.
__name__
config_dict
[
"_diffusers_version"
]
=
__version__
config_dict
[
"_diffusers_version"
]
=
__version__
def
to_json_saveable
(
value
):
if
isinstance
(
value
,
np
.
ndarray
):
value
=
value
.
tolist
()
return
value
config_dict
=
{
k
:
to_json_saveable
(
v
)
for
k
,
v
in
config_dict
.
items
()}
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
]):
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
110ffe25
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
math
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
clip_sample
:
bool
=
True
,
clip_sample
:
bool
=
True
,
set_alpha_to_one
:
bool
=
True
,
set_alpha_to_one
:
bool
=
True
,
steps_offset
:
int
=
0
,
steps_offset
:
int
=
0
,
...
@@ -139,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -139,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
110ffe25
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
import
math
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
variance_type
:
str
=
"fixed_small"
,
variance_type
:
str
=
"fixed_small"
,
clip_sample
:
bool
=
True
,
clip_sample
:
bool
=
True
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
...
@@ -130,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -130,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
110ffe25
...
@@ -127,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -127,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
solver_order
:
int
=
2
,
solver_order
:
int
=
2
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
thresholding
:
bool
=
False
,
thresholding
:
bool
=
False
,
...
@@ -147,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -147,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
self
.
register_to_config
(
prediction_type
=
"epsilon"
if
predict_epsilon
else
"sample"
)
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
110ffe25
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -77,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -77,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
):
):
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
110ffe25
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -78,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -78,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
prediction_type
:
str
=
"epsilon"
,
prediction_type
:
str
=
"epsilon"
,
):
):
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_heun.py
View file @
110ffe25
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -53,10 +53,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -53,10 +53,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.00085
,
# sensible defaults
beta_start
:
float
=
0.00085
,
# sensible defaults
beta_end
:
float
=
0.012
,
beta_end
:
float
=
0.012
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
):
):
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_ipndm.py
View file @
110ffe25
...
@@ -13,8 +13,9 @@
...
@@ -13,8 +13,9 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
typing
import
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch
from
..configuration_utils
import
ConfigMixin
,
register_to_config
from
..configuration_utils
import
ConfigMixin
,
register_to_config
...
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
order
=
1
order
=
1
@
register_to_config
@
register_to_config
def
__init__
(
self
,
num_train_timesteps
:
int
=
1000
):
def
__init__
(
self
,
num_train_timesteps
:
int
=
1000
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]]
=
None
):
# set `betas`, `alphas`, `timesteps`
# set `betas`, `alphas`, `timesteps`
self
.
set_timesteps
(
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
)
...
@@ -67,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -67,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
steps
=
torch
.
linspace
(
1
,
0
,
num_inference_steps
+
1
)[:
-
1
]
steps
=
torch
.
linspace
(
1
,
0
,
num_inference_steps
+
1
)[:
-
1
]
steps
=
torch
.
cat
([
steps
,
torch
.
tensor
([
0.0
])])
steps
=
torch
.
cat
([
steps
,
torch
.
tensor
([
0.0
])])
self
.
betas
=
torch
.
sin
(
steps
*
math
.
pi
/
2
)
**
2
if
self
.
config
.
trained_betas
is
not
None
:
self
.
betas
=
torch
.
tensor
(
self
.
config
.
trained_betas
,
dtype
=
torch
.
float32
)
else
:
self
.
betas
=
torch
.
sin
(
steps
*
math
.
pi
/
2
)
**
2
self
.
alphas
=
(
1.0
-
self
.
betas
**
2
)
**
0.5
self
.
alphas
=
(
1.0
-
self
.
betas
**
2
)
**
0.5
timesteps
=
(
torch
.
atan2
(
self
.
betas
,
self
.
alphas
)
/
math
.
pi
*
2
)[:
-
1
]
timesteps
=
(
torch
.
atan2
(
self
.
betas
,
self
.
alphas
)
/
math
.
pi
*
2
)[:
-
1
]
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
110ffe25
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
warnings
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -77,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -77,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
):
):
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
110ffe25
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
math
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -99,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -99,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_start
:
float
=
0.0001
,
beta_start
:
float
=
0.0001
,
beta_end
:
float
=
0.02
,
beta_end
:
float
=
0.02
,
beta_schedule
:
str
=
"linear"
,
beta_schedule
:
str
=
"linear"
,
trained_betas
:
Optional
[
np
.
ndarray
]
=
None
,
trained_betas
:
Optional
[
Union
[
np
.
ndarray
,
List
[
float
]]
]
=
None
,
skip_prk_steps
:
bool
=
False
,
skip_prk_steps
:
bool
=
False
,
set_alpha_to_one
:
bool
=
False
,
set_alpha_to_one
:
bool
=
False
,
steps_offset
:
int
=
0
,
steps_offset
:
int
=
0
,
):
):
if
trained_betas
is
not
None
:
if
trained_betas
is
not
None
:
self
.
betas
=
torch
.
from_numpy
(
trained_betas
)
self
.
betas
=
torch
.
tensor
(
trained_betas
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"linear"
:
elif
beta_schedule
==
"linear"
:
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
self
.
betas
=
torch
.
linspace
(
beta_start
,
beta_end
,
num_train_timesteps
,
dtype
=
torch
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
elif
beta_schedule
==
"scaled_linear"
:
...
...
tests/test_scheduler.py
View file @
110ffe25
...
@@ -584,6 +584,20 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -584,6 +584,20 @@ class SchedulerCommonTest(unittest.TestCase):
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)
)
def
test_trained_betas
(
self
):
for
scheduler_class
in
self
.
scheduler_classes
:
if
scheduler_class
==
VQDiffusionScheduler
:
continue
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
,
trained_betas
=
np
.
array
([
0.0
,
0.1
]))
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_pretrained
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_pretrained
(
tmpdirname
)
assert
scheduler
.
betas
.
tolist
()
==
new_scheduler
.
betas
.
tolist
()
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
DDPMScheduler
,)
scheduler_classes
=
(
DDPMScheduler
,)
...
@@ -1423,7 +1437,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
...
@@ -1423,7 +1437,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
"beta_start"
:
0.0001
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"beta_schedule"
:
"linear"
,
"trained_betas"
:
None
,
}
}
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
...
@@ -1505,7 +1518,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
...
@@ -1505,7 +1518,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
"beta_start"
:
0.0001
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"beta_schedule"
:
"linear"
,
"trained_betas"
:
None
,
}
}
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
...
@@ -1596,7 +1608,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
...
@@ -1596,7 +1608,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
"beta_start"
:
0.0001
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"beta_schedule"
:
"linear"
,
"trained_betas"
:
None
,
}
}
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
...
@@ -1905,7 +1916,6 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
...
@@ -1905,7 +1916,6 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
"beta_start"
:
0.0001
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"beta_schedule"
:
"linear"
,
"trained_betas"
:
None
,
}
}
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment