renzhc / diffusers_dcu · Commits · 110ffe25

Unverified commit 110ffe25, authored Nov 30, 2022 by Patrick von Platen, committed by GitHub Nov 30, 2022

Allow saving trained betas (#1468)
Parent: 0b7225e9

Showing 11 changed files with 55 additions and 30 deletions (+55 −30)
src/diffusers/configuration_utils.py                              +8  −0
src/diffusers/schedulers/scheduling_ddim.py                       +3  −3
src/diffusers/schedulers/scheduling_ddpm.py                       +3  −3
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py        +2  −2
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py   +3  −3
src/diffusers/schedulers/scheduling_euler_discrete.py             +3  −3
src/diffusers/schedulers/scheduling_heun.py                       +3  −3
src/diffusers/schedulers/scheduling_ipndm.py                      +10 −3
src/diffusers/schedulers/scheduling_lms_discrete.py               +3  −3
src/diffusers/schedulers/scheduling_pndm.py                       +3  −3
tests/test_scheduler.py                                           +14 −4
src/diffusers/configuration_utils.py

@@ -24,6 +24,8 @@ import re
 from collections import OrderedDict
 from typing import Any, Dict, Tuple, Union
 
+import numpy as np
+
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError

@@ -502,6 +504,12 @@ class ConfigMixin:
         config_dict["_class_name"] = self.__class__.__name__
         config_dict["_diffusers_version"] = __version__
 
+        def to_json_saveable(value):
+            if isinstance(value, np.ndarray):
+                value = value.tolist()
+            return value
+
+        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
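The new `to_json_saveable` helper is what makes numpy-valued config entries serializable: `json.dumps` raises a `TypeError` on a raw `numpy.ndarray`, while the equivalent plain list serializes fine. A minimal standalone sketch of the behavior (the example config dict is illustrative, not taken from the library):

    import json

    import numpy as np

    def to_json_saveable(value):
        # numpy arrays are not JSON-serializable; plain lists of floats are
        if isinstance(value, np.ndarray):
            value = value.tolist()
        return value

    config = {"beta_schedule": "linear", "trained_betas": np.array([0.0, 0.1])}
    # json.dumps(config) would raise: TypeError: Object of type ndarray is not JSON serializable
    print(json.dumps({k: to_json_saveable(v) for k, v in config.items()}, indent=2, sort_keys=True))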
src/diffusers/schedulers/scheduling_ddim.py

@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,

@@ -139,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
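Replacing `torch.from_numpy` with `torch.tensor` is what allows `trained_betas` to arrive as a plain `List[float]`, which is exactly what comes back once the config has been round-tripped through JSON: `torch.from_numpy` accepts only an ndarray, whereas `torch.tensor` copies from either. A quick illustration:

    import numpy as np
    import torch

    betas_list = [0.0, 0.1]             # what a reloaded JSON config contains
    betas_array = np.array(betas_list)

    torch.tensor(betas_list, dtype=torch.float32)   # works
    torch.tensor(betas_array, dtype=torch.float32)  # works
    # torch.from_numpy(betas_list) would raise: TypeError: expected np.ndarray (got list)

The same two-line change (widened annotation plus `torch.tensor`) is applied to each of the remaining schedulers below.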
src/diffusers/schedulers/scheduling_ddpm.py

@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",

@@ -130,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -127,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
         prediction_type: str = "epsilon",
         thresholding: bool = False,

@@ -147,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -77,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -78,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
src/diffusers/schedulers/scheduling_heun.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -53,10 +53,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.00085,  # sensible defaults
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
src/diffusers/schedulers/scheduling_ipndm.py

@@ -13,8 +13,9 @@
 # limitations under the License.
 import math
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config

@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     order = 1
 
     @register_to_config
-    def __init__(self, num_train_timesteps: int = 1000):
+    def __init__(
+        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+    ):
         # set `betas`, `alphas`, `timesteps`
         self.set_timesteps(num_train_timesteps)

@@ -67,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
         steps = torch.cat([steps, torch.tensor([0.0])])
 
-        self.betas = torch.sin(steps * math.pi / 2) ** 2
+        if self.config.trained_betas is not None:
+            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+        else:
+            self.betas = torch.sin(steps * math.pi / 2) ** 2
+
         self.alphas = (1.0 - self.betas**2) ** 0.5
 
         timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
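Note that IPNDM consumes the betas differently from the other schedulers: `__init__` only forwards to `set_timesteps`, which later reads the value back through `self.config.trained_betas`, relying on the `@register_to_config` decorator to record every `__init__` argument on the config object. A stripped-down sketch of that mechanism (the `TinyScheduler` class and the `SimpleNamespace` config are a hypothetical stand-in, not the real decorator):

    from types import SimpleNamespace

    class TinyScheduler:
        def __init__(self, num_train_timesteps=1000, trained_betas=None):
            # register_to_config, simplified: stash the init kwargs on .config
            self.config = SimpleNamespace(
                num_train_timesteps=num_train_timesteps, trained_betas=trained_betas
            )
            self.set_timesteps(num_train_timesteps)

        def set_timesteps(self, num_inference_steps):
            # later calls still see what the scheduler was constructed with
            if self.config.trained_betas is not None:
                print("using trained betas:", self.config.trained_betas)

    TinyScheduler(trained_betas=[0.0, 0.1])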
src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -77,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
src/diffusers/schedulers/scheduling_pndm.py

@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch

@@ -99,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
tests/test_scheduler.py

@@ -584,6 +584,20 @@ class SchedulerCommonTest(unittest.TestCase):
                 " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
             )
 
+    def test_trained_betas(self):
+        for scheduler_class in self.scheduler_classes:
+            if scheduler_class == VQDiffusionScheduler:
+                continue
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+            assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+
 
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)

@@ -1423,7 +1437,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)

@@ -1505,7 +1518,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)

@@ -1596,7 +1608,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)

@@ -1905,7 +1916,6 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
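Taken together, the changes make trained betas survive a save/load cycle: `save_pretrained` writes the numpy array into the scheduler's config JSON as a plain list (via `to_json_saveable`), and the widened `trained_betas` annotation plus `torch.tensor` let the reloaded list rebuild the same float32 tensor. A usage sketch mirroring the new test:

    import tempfile

    import numpy as np
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(trained_betas=np.array([0.0, 0.1]))

    with tempfile.TemporaryDirectory() as tmpdirname:
        scheduler.save_pretrained(tmpdirname)  # trained_betas saved as [0.0, 0.1]
        reloaded = DDPMScheduler.from_pretrained(tmpdirname)

    assert scheduler.betas.tolist() == reloaded.betas.tolist()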