You need to sign in or sign up before continuing.
Commit 760dcb1f authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix score sde ve scheduler

parent 919e27d3
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import pdb
from typing import Union from typing import Union
import numpy as np import numpy as np
...@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# self.num_inference_steps = None # self.num_inference_steps = None
self.timesteps = None self.timesteps = None
self.set_sigmas(self.num_train_timesteps) self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps, sampling_eps=None):
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
tensor_format = getattr(self, "tensor_format", "pt") tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np": if tensor_format == "np":
self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps) self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
elif tensor_format == "pt": elif tensor_format == "pt":
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
else: else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps): def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
if self.timesteps is None: if self.timesteps is None:
self.set_timesteps(num_inference_steps) self.set_timesteps(num_inference_steps, sampling_eps)
tensor_format = getattr(self, "tensor_format", "pt") tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np": if tensor_format == "np":
self.discrete_sigmas = np.exp( self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
)
self.sigmas = np.array(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
elif tensor_format == "pt": elif tensor_format == "pt":
self.discrete_sigmas = torch.exp( self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
else: else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment