Unverified Commit b811964a authored by Will Berman, committed by GitHub

ddpm custom timesteps (#3007)

add custom timesteps test

add custom timesteps descending order check

docs

timesteps -> custom_timesteps

can only pass one of num_inference_steps and timesteps
parent 1bd4c9e9
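
For context, a minimal usage sketch of the API this commit introduces (an illustration, not part of the commit; it assumes the `diffusers` `DDPMScheduler` modified below, and the printed values follow from `num_train_timesteps=1000`):

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Default strategy: equally spaced timesteps derived from num_inference_steps.
scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.timesteps)  # tensor([900, 800, 700, 600, 500, 400, 300, 200, 100, 0])

# New strategy: arbitrary spacing via an explicit, strictly descending list.
# Passing both num_inference_steps and timesteps raises a ValueError.
scheduler.set_timesteps(timesteps=[100, 87, 50, 1, 0])
print(scheduler.timesteps)  # tensor([100, 87, 50, 1, 0])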
@@ -162,6 +162,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0

         # setable values
+        self.custom_timesteps = False
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
@@ -191,31 +192,62 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

         Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
+            num_inference_steps (`Optional[int]`):
+                the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+                `timesteps` must be `None`.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps are moved to.
+            timesteps (`List[int]`, optional):
+                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
+                must be `None`.
         """
-        if num_inference_steps > self.config.num_train_timesteps:
-            raise ValueError(
-                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                f" maximal {self.config.num_train_timesteps} timesteps."
-            )
-
-        self.num_inference_steps = num_inference_steps
-
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+
+        if timesteps is not None:
+            for i in range(1, len(timesteps)):
+                if timesteps[i] >= timesteps[i - 1]:
+                    raise ValueError("`custom_timesteps` must be in descending order.")
+
+            if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+            timesteps = np.array(timesteps, dtype=np.int64)
+            self.custom_timesteps = True
+        else:
+            if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
+
+            self.num_inference_steps = num_inference_steps
+
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            self.custom_timesteps = False

         self.timesteps = torch.from_numpy(timesteps).to(device)

     def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = t - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)

         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
         current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
@@ -314,8 +346,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         """
         t = timestep

-        num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+        prev_t = self.previous_timestep(t)

         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
@@ -428,3 +460,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

     def __len__(self):
         return self.config.num_train_timesteps
+
+    def previous_timestep(self, timestep):
+        if self.custom_timesteps:
+            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+            if index == self.timesteps.shape[0] - 1:
+                prev_t = torch.tensor(-1)
+            else:
+                prev_t = self.timesteps[index + 1]
+        else:
+            num_inference_steps = (
+                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
+            )
+            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+
+        return prev_t
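
A short sketch of what the new `previous_timestep` helper computes (an illustration under the same assumptions as above): with custom timesteps the previous timestep is the next entry in the list, falling back to -1 past the final entry; with the default strategy it remains `t - num_train_timesteps // num_inference_steps`.

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

scheduler.set_timesteps(timesteps=[100, 87, 50, 1, 0])
print(scheduler.previous_timestep(87))  # tensor(50) -- the next entry in the list
print(scheduler.previous_timestep(0))   # tensor(-1) -- past the final entry

scheduler.set_timesteps(num_inference_steps=10)
print(scheduler.previous_timestep(800))  # 700, i.e. 800 - 1000 // 10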
@@ -129,3 +129,59 @@ class DDPMSchedulerTest(SchedulerCommonTest):

         assert abs(result_sum.item() - 202.0296) < 1e-2
         assert abs(result_mean.item() - 0.2631) < 1e-3
+
+    def test_custom_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+
+        scheduler.set_timesteps(timesteps=timesteps)
+
+        scheduler_timesteps = scheduler.timesteps
+
+        for i, timestep in enumerate(scheduler_timesteps):
+            if i == len(timesteps) - 1:
+                expected_prev_t = -1
+            else:
+                expected_prev_t = timesteps[i + 1]
+
+            prev_t = scheduler.previous_timestep(timestep)
+            prev_t = prev_t.item()
+
+            self.assertEqual(prev_t, expected_prev_t)
+
+    def test_custom_timesteps_increasing_order(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 51, 0]
+
+        with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."):
+            scheduler.set_timesteps(timesteps=timesteps)
+
+    def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [100, 87, 50, 1, 0]
+        num_inference_steps = len(timesteps)
+
+        with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."):
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
+
+    def test_custom_timesteps_too_large(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        timesteps = [scheduler.config.num_train_timesteps]
+
+        with self.assertRaises(
+            ValueError,
+            msg=f"`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}",
+        ):
+            scheduler.set_timesteps(timesteps=timesteps)