Unverified commit 110ffe25 authored by Patrick von Platen, committed by GitHub

Allow saving trained betas (#1468)

parent 0b7225e9
@@ -24,6 +24,8 @@ import re
 from collections import OrderedDict
 from typing import Any, Dict, Tuple, Union
 
+import numpy as np
+
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
@@ -502,6 +504,12 @@ class ConfigMixin:
         config_dict["_class_name"] = self.__class__.__name__
         config_dict["_diffusers_version"] = __version__
 
+        def to_json_saveable(value):
+            if isinstance(value, np.ndarray):
+                value = value.tolist()
+            return value
+
+        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
......
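The `to_json_saveable` hook above is what makes array-valued config entries (such as trained betas) serializable at all: `json.dumps` raises a `TypeError` on a NumPy array, while a plain list round-trips cleanly. A minimal standalone sketch of that failure mode and the fix (the variable names here are illustrative, not part of the diff):

```python
import json

import numpy as np

betas = np.array([0.0, 0.1])

# json.dumps rejects ndarrays outright ...
try:
    json.dumps({"trained_betas": betas})
except TypeError as err:
    print(err)  # "Object of type ndarray is not JSON serializable"

# ... but a plain list serializes fine, which is exactly what
# to_json_saveable produces via ndarray.tolist().
print(json.dumps({"trained_betas": betas.tolist()}))  # {"trained_betas": [0.0, 0.1]}
```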
@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
@@ -139,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
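The companion change on the loading side is the switch from `torch.from_numpy` to `torch.tensor`, repeated in every scheduler below: a config reloaded from JSON hands the constructor a plain Python list (the `List[float]` half of the new annotation), which `torch.from_numpy` rejects. A small sketch of the difference, assuming only NumPy and PyTorch:

```python
import numpy as np
import torch

betas_from_json = [0.0, 0.1]  # what from_pretrained reads back out of the saved config

# torch.tensor copies from either an ndarray or a plain list ...
torch.tensor(np.array(betas_from_json), dtype=torch.float32)
torch.tensor(betas_from_json, dtype=torch.float32)

# ... while torch.from_numpy only accepts an ndarray:
try:
    torch.from_numpy(betas_from_json)
except TypeError as err:
    print(err)  # "expected np.ndarray (got list)"
```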
@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
@@ -130,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
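A second, easy-to-miss effect of the rewrite: `torch.from_numpy` preserves the array's dtype (NumPy defaults to float64), whereas the new call always casts to float32, matching the other schedule branches. A sketch, again assuming only NumPy and PyTorch:

```python
import numpy as np
import torch

betas = np.array([0.0, 0.1])  # NumPy arrays default to float64

print(torch.from_numpy(betas).dtype)                   # torch.float64
print(torch.tensor(betas, dtype=torch.float32).dtype)  # torch.float32
```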
@@ -127,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
         prediction_type: str = "epsilon",
         thresholding: bool = False,
@@ -147,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -78,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -53,10 +53,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.00085,  # sensible defaults
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import math
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     order = 1
 
     @register_to_config
-    def __init__(self, num_train_timesteps: int = 1000):
+    def __init__(
+        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+    ):
         # set `betas`, `alphas`, `timesteps`
         self.set_timesteps(num_train_timesteps)
 
@@ -67,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
         steps = torch.cat([steps, torch.tensor([0.0])])
 
-        self.betas = torch.sin(steps * math.pi / 2) ** 2
+        if self.config.trained_betas is not None:
+            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+        else:
+            self.betas = torch.sin(steps * math.pi / 2) ** 2
+
         self.alphas = (1.0 - self.betas**2) ** 0.5
 
         timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
......
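Unlike the other schedulers, `IPNDMScheduler` builds its betas inside `set_timesteps`, so the new argument is read back through `self.config.trained_betas` rather than from a local variable. A hedged usage sketch (the two-element array merely mirrors the value used in the new test below, not a realistic schedule):

```python
import numpy as np

from diffusers import IPNDMScheduler

# Passing trained betas now overrides the sin^2 schedule that
# set_timesteps would otherwise compute.
scheduler = IPNDMScheduler(num_train_timesteps=1000, trained_betas=np.array([0.0, 0.1]))
print(scheduler.betas)  # tensor([0.0000, 0.1000])
```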
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -99,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
......
@@ -584,6 +584,20 @@ class SchedulerCommonTest(unittest.TestCase):
             " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
         )
 
+    def test_trained_betas(self):
+        for scheduler_class in self.scheduler_classes:
+            if scheduler_class == VQDiffusionScheduler:
+                continue
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+            assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+
 
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
@@ -1423,7 +1437,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1505,7 +1518,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1596,7 +1608,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1905,7 +1916,6 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
......
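The `"trained_betas": None` entries dropped from the per-scheduler test configs above would otherwise collide with the keyword the common test passes explicitly: `scheduler_class(**scheduler_config, trained_betas=...)` would raise a duplicate-keyword `TypeError`. To close, an end-to-end sketch of the behavior the new test locks in, using `DDPMScheduler` as one representative scheduler:

```python
import tempfile

import numpy as np

from diffusers import DDPMScheduler

# Betas passed at construction are written into scheduler_config.json as a
# plain list (via to_json_saveable) and restored as a float32 tensor on load.
scheduler = DDPMScheduler(trained_betas=np.array([0.0, 0.1]))

with tempfile.TemporaryDirectory() as tmpdir:
    scheduler.save_pretrained(tmpdir)
    reloaded = DDPMScheduler.from_pretrained(tmpdir)

assert scheduler.betas.tolist() == reloaded.betas.tolist()
```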