"...text-generation-inference.git" did not exist on "337afb28422f421353a2ce756260ff01e51fb7d1"
Commit 760dcb1f authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix score sde ve scheduler

parent 919e27d3
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import pdb
from typing import Union from typing import Union
import numpy as np import numpy as np
...@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# self.num_inference_steps = None # self.num_inference_steps = None
self.timesteps = None self.timesteps = None
self.set_sigmas(self.num_train_timesteps) self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps, sampling_eps=None):
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
tensor_format = getattr(self, "tensor_format", "pt") tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np": if tensor_format == "np":
self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps) self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
elif tensor_format == "pt": elif tensor_format == "pt":
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
else: else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps): def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
if self.timesteps is None: if self.timesteps is None:
self.set_timesteps(num_inference_steps) self.set_timesteps(num_inference_steps, sampling_eps)
tensor_format = getattr(self, "tensor_format", "pt") tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np": if tensor_format == "np":
self.discrete_sigmas = np.exp( self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
)
self.sigmas = np.array(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
elif tensor_format == "pt": elif tensor_format == "pt":
self.discrete_sigmas = torch.exp( self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
else: else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment