Unverified Commit 675ef1ff authored by Joqsan's avatar Joqsan Committed by GitHub
Browse files

fix: DDPMScheduler.set_timesteps() (#1912)

parent d67c3051
...@@ -201,6 +201,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -201,6 +201,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
""" """
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio # creates integer timesteps by multiplying by ratio
......
...@@ -184,11 +184,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -184,11 +184,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
""" """
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy() timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None): def _get_variance(self, t, predicted_variance=None, variance_type=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment