Unverified Commit 170af08e authored by Richard Löwenström's avatar Richard Löwenström Committed by GitHub
Browse files

Easily understandable error if inference steps not set before using scheduler (#263) (#264)



* Helpful exception if inference steps not set in schedulers (#263)

* Apply suggestions from codereview by patrickvonplaten

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 76985bc8
...@@ -117,6 +117,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,6 +117,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
use_clipped_model_output: bool = False, use_clipped_model_output: bool = False,
generator=None, generator=None,
): ):
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding # Ideally, read DDIM paper in-detail understanding
......
...@@ -145,6 +145,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -145,6 +145,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation. solution to the differential equation.
""" """
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
timestep = self.prk_timesteps[self.counter // 4 * 4] timestep = self.prk_timesteps[self.counter // 4 * 4]
...@@ -179,6 +184,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -179,6 +184,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution. times to approximate the solution.
""" """
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if not self.config.skip_prk_steps and len(self.ets) < 3: if not self.config.skip_prk_steps and len(self.ets) < 3:
raise ValueError( raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run " f"{self.__class__} can only be run AFTER scheduler has been run "
......
...@@ -120,6 +120,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -120,6 +120,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.set_seed(seed) self.set_seed(seed)
# TODO(Patrick) non-PyTorch # TODO(Patrick) non-PyTorch
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
timestep = timestep * torch.ones( timestep = timestep * torch.ones(
sample.shape[0], device=sample.device sample.shape[0], device=sample.device
) # torch.repeat_interleave(timestep, sample.shape[0]) ) # torch.repeat_interleave(timestep, sample.shape[0])
...@@ -155,6 +160,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -155,6 +160,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
if seed is not None: if seed is not None:
self.set_seed(seed) self.set_seed(seed)
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction # sample noise for correction
noise = self.randn_like(sample) noise = self.randn_like(sample)
......
...@@ -35,6 +35,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): ...@@ -35,6 +35,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def step_pred(self, score, x, t): def step_pred(self, score, x, t):
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
# TODO(Patrick) better comments + non-PyTorch # TODO(Patrick) better comments + non-PyTorch
# postprocess model score # postprocess model score
log_mean_coeff = ( log_mean_coeff = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment