"tests/vscode:/vscode.git/clone" did not exist on "b6f5ba9a809fcd2e5b2c440f538c1ccc965a9e59"
Unverified Commit 170af08e authored by Richard Löwenström's avatar Richard Löwenström Committed by GitHub
Browse files

Easily understandable error if inference steps not set before using scheduler (#263) (#264)



* Helpful exception if inference steps not set in schedulers (#263)

* Apply suggestions from codereview by patrickvonplaten

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 76985bc8
...@@ -117,6 +117,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,6 +117,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
use_clipped_model_output: bool = False, use_clipped_model_output: bool = False,
generator=None, generator=None,
): ):
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding # Ideally, read DDIM paper in-detail understanding
......
...@@ -145,6 +145,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -145,6 +145,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation. solution to the differential equation.
""" """
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
timestep = self.prk_timesteps[self.counter // 4 * 4] timestep = self.prk_timesteps[self.counter // 4 * 4]
...@@ -179,6 +184,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -179,6 +184,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution. times to approximate the solution.
""" """
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if not self.config.skip_prk_steps and len(self.ets) < 3: if not self.config.skip_prk_steps and len(self.ets) < 3:
raise ValueError( raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run " f"{self.__class__} can only be run AFTER scheduler has been run "
......
...@@ -120,6 +120,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -120,6 +120,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.set_seed(seed) self.set_seed(seed)
# TODO(Patrick) non-PyTorch # TODO(Patrick) non-PyTorch
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
timestep = timestep * torch.ones( timestep = timestep * torch.ones(
sample.shape[0], device=sample.device sample.shape[0], device=sample.device
) # torch.repeat_interleave(timestep, sample.shape[0]) ) # torch.repeat_interleave(timestep, sample.shape[0])
...@@ -155,6 +160,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -155,6 +160,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
if seed is not None: if seed is not None:
self.set_seed(seed) self.set_seed(seed)
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction # sample noise for correction
noise = self.randn_like(sample) noise = self.randn_like(sample)
......
...@@ -35,6 +35,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): ...@@ -35,6 +35,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def step_pred(self, score, x, t): def step_pred(self, score, x, t):
if self.timesteps is None:
raise ValueError(
"`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
# TODO(Patrick) better comments + non-PyTorch # TODO(Patrick) better comments + non-PyTorch
# postprocess model score # postprocess model score
log_mean_coeff = ( log_mean_coeff = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment