Unverified Commit a127363d authored by Daniel Hug's avatar Daniel Hug Committed by GitHub
Browse files

Add typing to scheduling_sde_ve: init, set_timesteps, and set_sigmas function definitions (#412)



Add typing to scheduling_sde_ve init, set_timesteps, and set_sigmas functions
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent b8894f18
...@@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
num_train_timesteps=2000, num_train_timesteps: int = 2000,
snr=0.15, snr: float = 0.15,
sigma_min=0.01, sigma_min: float = 0.01,
sigma_max=1348, sigma_max: float = 1348.0,
sampling_eps=1e-5, sampling_eps: float = 1e-5,
correct_steps=1, correct_steps: int = 1,
tensor_format="pt", tensor_format: str = "pt",
): ):
# setable values # setable values
self.timesteps = None self.timesteps = None
...@@ -81,7 +81,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -81,7 +81,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps, sampling_eps=None): def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
""" """
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
...@@ -100,7 +100,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -100,7 +100,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None): def set_sigmas(
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
):
""" """
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment