Unverified Commit cd21b965 authored by YiYi Xu, committed by GitHub

add a step_index counter (#4347)



add self.step_index

---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d185b5ed
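For orientation (not part of the diff itself), a minimal sketch of how the new counter is meant to behave with one of the schedulers touched here, e.g. `EulerDiscreteScheduler`: `step_index` starts out as `None`, is initialized lazily from the first timestep passed to `scale_model_input()` or `step()`, advances by one after every `step()` call, and is reset to `None` by `set_timesteps()`. The toy loop below uses a zero tensor in place of a real UNet prediction.

    import torch
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler()
    scheduler.set_timesteps(4)                            # resets the counter: scheduler.step_index is None

    sample = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        scaled = scheduler.scale_model_input(sample, t)   # lazily sets step_index on the first call
        model_output = torch.zeros_like(scaled)           # stand-in for the UNet prediction on `scaled`
        sample = scheduler.step(model_output, t, sample).prev_sample
        print(scheduler.step_index)                       # 1, 2, 3, 4 -- incremented after each step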
@@ -981,6 +981,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
        prompt_embeds_edit[1:2] += edit_direction

        # 10. Second denoising loop to generate the edited image.
+       self.scheduler.set_timesteps(num_inference_steps, device=device)
+       timesteps = self.scheduler.timesteps
+
        latents = latents_init
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
......
@@ -96,6 +96,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
        self.timesteps = torch.from_numpy(timesteps)
        self.custom_timesteps = False
        self.is_scale_input_called = False
+       self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
@@ -104,6 +105,13 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
        indices = (schedule_timesteps == timestep).nonzero()
        return indices.item()

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
@@ -121,10 +129,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                A scaled input sample.
        """
        # Get sigma corresponding to timestep
-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_idx = self.index_for_timestep(timestep)
-       sigma = self.sigmas[step_idx]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]

        sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)

@@ -220,6 +228,8 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

+       self._step_index = None
+
    # Modified _convert_to_karras implementation that takes in ramp as argument
    def _convert_to_karras(self, ramp):
        """Constructs the noise schedule of Karras et al. (2022)."""

@@ -267,6 +277,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
        c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()
+
    def step(
        self,
        model_output: torch.FloatTensor,

@@ -318,18 +346,16 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                "See `StableDiffusionPipeline` for a usage example."
            )

-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-
        sigma_min = self.config.sigma_min
        sigma_max = self.config.sigma_max

-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

        # sigma_next corresponds to next_t in original implementation
-       sigma = self.sigmas[step_index]
-       if step_index + 1 < self.config.num_train_timesteps:
-           sigma_next = self.sigmas[step_index + 1]
+       sigma = self.sigmas[self.step_index]
+       if self.step_index + 1 < self.config.num_train_timesteps:
+           sigma_next = self.sigmas[self.step_index + 1]
        else:
            # Set sigma_next to sigma_min
            sigma_next = self.sigmas[-1]

@@ -358,6 +384,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
        # tau = sigma_hat, eps = sigma_min
        prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
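A side note on the `_init_step_index` helper introduced above (illustration only, schedule values made up): when a schedule contains a timestep more than once, as in the interleaved Heun/KDPM2 schedules further down this diff, `(self.timesteps == timestep).nonzero()` returns several candidate indices, and the helper deliberately picks the second one so that a run starting mid-schedule (e.g. image-to-image) does not skip a sigma.

    import torch

    # Interleaved, Heun-style schedule with made-up values; every timestep after the first appears twice.
    timesteps = torch.tensor([801.0, 601.0, 601.0, 401.0, 401.0, 201.0, 201.0, 1.0])
    t = torch.tensor(601.0)

    index_candidates = (timesteps == t).nonzero()
    print(index_candidates.flatten().tolist())   # [1, 2]

    # Same selection rule as the helper above: prefer the second match when there is more than one.
    step_index = index_candidates[1] if len(index_candidates) > 1 else index_candidates[0]
    print(step_index.item())                     # 2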
@@ -166,6 +166,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False

+       self._step_index = None
+
    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
@@ -174,6 +176,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return (self.sigmas.max() ** 2 + 1) ** 0.5

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
@@ -191,10 +200,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
            `torch.FloatTensor`:
                A scaled input sample.
        """
-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_index = (self.timesteps == timestep).nonzero().item()
-       sigma = self.sigmas[step_index]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

@@ -213,20 +223,20 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
-           timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+           timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+           timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(

@@ -237,11 +247,27 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
-       if str(device).startswith("mps"):
-           # mps does not support float64
-           self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
-       else:
-           self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+       self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+       self._step_index = None
+
+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()

    def step(
        self,

@@ -295,11 +321,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
                "See `StableDiffusionPipeline` for a usage example."
            )

-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_index = (self.timesteps == timestep).nonzero().item()
-       sigma = self.sigmas[step_index]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":

@@ -314,8 +339,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

-       sigma_from = self.sigmas[step_index]
-       sigma_to = self.sigmas[step_index + 1]
+       sigma_from = self.sigmas[self.step_index]
+       sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

@@ -331,6 +356,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        prev_sample = prev_sample + noise * sigma_up

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
@@ -175,6 +175,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self.is_scale_input_called = False
        self.use_karras_sigmas = use_karras_sigmas

+       self._step_index = None
+
    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
@@ -183,6 +185,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return (self.sigmas.max() ** 2 + 1) ** 0.5

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
@@ -200,11 +209,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
            `torch.FloatTensor`:
                A scaled input sample.
        """
-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_index = (self.timesteps == timestep).nonzero().item()
-       sigma = self.sigmas[step_index]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]

        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True

@@ -224,20 +232,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
-           timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+           timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+           timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(

@@ -263,11 +271,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
-       if str(device).startswith("mps"):
-           # mps does not support float64
-           self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
-       else:
-           self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+       self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+       self._step_index = None

    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma

@@ -306,6 +312,23 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()
+
    def step(
        self,
        model_output: torch.FloatTensor,

@@ -365,11 +388,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                "See `StableDiffusionPipeline` for a usage example."
            )

-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_index = (self.timesteps == timestep).nonzero().item()
-       sigma = self.sigmas[step_index]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

@@ -401,10 +423,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

-       dt = self.sigmas[step_index + 1] - sigma_hat
+       dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
@@ -149,6 +149,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas

+       self._step_index = None
+
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
@@ -175,6 +177,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return (self.sigmas.max() ** 2 + 1) ** 0.5

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self,
        sample: torch.FloatTensor,
@@ -194,9 +203,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
            `torch.FloatTensor`:
                A scaled input sample.
        """
-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

-       sigma = self.sigmas[step_index]
+       sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

@@ -221,18 +231,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
-           timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+           timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+           timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(

@@ -254,16 +264,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        timesteps = torch.from_numpy(timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

-       if str(device).startswith("mps"):
-           # mps does not support float64
-           self.timesteps = timesteps.to(device, dtype=torch.float32)
-       else:
-           self.timesteps = timesteps.to(device=device)
+       self.timesteps = timesteps.to(device=device)

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

+       self._step_index = None
+
+       # (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
        # we need an index counter
        self._index_counter = defaultdict(int)

@@ -310,6 +319,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    def state_in_first_order(self):
        return self.dt is None

+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()
+
    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],

@@ -336,19 +363,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

+       # (YiYi notes: keep this for now since we are keeping the add_noise method)
        # advance index counter by 1
        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        self._index_counter[timestep_int] += 1

        if self.state_in_first_order:
-           sigma = self.sigmas[step_index]
-           sigma_next = self.sigmas[step_index + 1]
+           sigma = self.sigmas[self.step_index]
+           sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
-           sigma = self.sigmas[step_index - 1]
-           sigma_next = self.sigmas[step_index]
+           sigma = self.sigmas[self.step_index - 1]
+           sigma_next = self.sigmas[self.step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before

@@ -404,6 +433,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
        prev_sample = sample + derivative * dt

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
@@ -55,6 +55,14 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
        # running values
        self.ets = []

+       self._step_index = None
+
+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
@@ -81,6 +89,25 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
        self.timesteps = timesteps.to(device)

        self.ets = []
+       self._step_index = None
+
+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()

    def step(
        self,

@@ -112,9 +139,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

-       timestep_index = (self.timesteps == timestep).nonzero().item()
-       prev_timestep_index = timestep_index + 1
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       timestep_index = self.step_index
+       prev_timestep_index = self.step_index + 1

        ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
        self.ets.append(ets)

@@ -130,6 +159,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
        prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
@@ -137,6 +137,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
+       self._step_index = None

    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -165,6 +166,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return (self.sigmas.max() ** 2 + 1) ** 0.5

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self,
        sample: torch.FloatTensor,
@@ -184,12 +192,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
            `torch.FloatTensor`:
                A scaled input sample.
        """
-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

        if self.state_in_first_order:
-           sigma = self.sigmas[step_index]
+           sigma = self.sigmas[self.step_index]
        else:
-           sigma = self.sigmas_interpol[step_index - 1]
+           sigma = self.sigmas_interpol[self.step_index - 1]

        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

@@ -215,18 +224,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
-           timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+           timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+           timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(

@@ -259,12 +268,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
        self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

-       if str(device).startswith("mps"):
-           # mps does not support float64
-           timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
-       else:
-           timesteps = torch.from_numpy(timesteps).to(device)
+       timesteps = torch.from_numpy(timesteps).to(device)

        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
        interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()

@@ -276,6 +280,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # we need an index counter
        self._index_counter = defaultdict(int)

+       self._step_index = None
+
    def sigma_to_t(self, sigma):
        # get log sigma
        log_sigma = sigma.log()

@@ -303,6 +309,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    def state_in_first_order(self):
        return self.sample is None

+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()
+
    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],

@@ -332,23 +356,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
                If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

        # advance index counter by 1
        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        self._index_counter[timestep_int] += 1

        if self.state_in_first_order:
-           sigma = self.sigmas[step_index]
-           sigma_interpol = self.sigmas_interpol[step_index]
-           sigma_up = self.sigmas_up[step_index]
-           sigma_down = self.sigmas_down[step_index - 1]
+           sigma = self.sigmas[self.step_index]
+           sigma_interpol = self.sigmas_interpol[self.step_index]
+           sigma_up = self.sigmas_up[self.step_index]
+           sigma_down = self.sigmas_down[self.step_index - 1]
        else:
            # 2nd order / KPDM2's method
-           sigma = self.sigmas[step_index - 1]
-           sigma_interpol = self.sigmas_interpol[step_index - 1]
-           sigma_up = self.sigmas_up[step_index - 1]
-           sigma_down = self.sigmas_down[step_index - 1]
+           sigma = self.sigmas[self.step_index - 1]
+           sigma_interpol = self.sigmas_interpol[self.step_index - 1]
+           sigma_up = self.sigmas_up[self.step_index - 1]
+           sigma_down = self.sigmas_down[self.step_index - 1]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before

@@ -398,6 +423,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
        prev_sample = sample + derivative * dt
        prev_sample = prev_sample + noise * sigma_up

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
@@ -137,6 +137,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

+       self._step_index = None
+
    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
@@ -164,6 +166,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        return (self.sigmas.max() ** 2 + 1) ** 0.5

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self,
        sample: torch.FloatTensor,
@@ -183,12 +192,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
            `torch.FloatTensor`:
                A scaled input sample.
        """
-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

        if self.state_in_first_order:
-           sigma = self.sigmas[step_index]
+           sigma = self.sigmas[self.step_index]
        else:
-           sigma = self.sigmas_interpol[step_index]
+           sigma = self.sigmas_interpol[self.step_index]

        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

@@ -214,18 +224,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
-           timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+           timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+           timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(

@@ -247,11 +257,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
            [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
        )

-       if str(device).startswith("mps"):
-           # mps does not support float64
-           timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
-       else:
-           timesteps = torch.from_numpy(timesteps).to(device)
+       timesteps = torch.from_numpy(timesteps).to(device)

        # interpolate timesteps
        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)

@@ -265,6 +271,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
        # we need an index counter
        self._index_counter = defaultdict(int)

+       self._step_index = None
+
    def sigma_to_t(self, sigma):
        # get log sigma
        log_sigma = sigma.log()

@@ -292,6 +300,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
    def state_in_first_order(self):
        return self.sample is None

+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()
+
    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],

@@ -318,21 +344,22 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
-       step_index = self.index_for_timestep(timestep)
+       if self.step_index is None:
+           self._init_step_index(timestep)

        # advance index counter by 1
        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        self._index_counter[timestep_int] += 1

        if self.state_in_first_order:
-           sigma = self.sigmas[step_index]
-           sigma_interpol = self.sigmas_interpol[step_index + 1]
-           sigma_next = self.sigmas[step_index + 1]
+           sigma = self.sigmas[self.step_index]
+           sigma_interpol = self.sigmas_interpol[self.step_index + 1]
+           sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / KDPM2's method
-           sigma = self.sigmas[step_index - 1]
-           sigma_interpol = self.sigmas_interpol[step_index]
-           sigma_next = self.sigmas[step_index]
+           sigma = self.sigmas[self.step_index - 1]
+           sigma_interpol = self.sigmas_interpol[self.step_index]
+           sigma_next = self.sigmas[self.step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before

@@ -375,6 +402,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
            sample = self.sample
            self.sample = None

+       # upon completion increase step index by one
+       self._step_index += 1
+
        prev_sample = sample + derivative * dt

        if not return_dict:
......
@@ -169,6 +169,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        self.derivatives = []
        self.is_scale_input_called = False

+       self._step_index = None
+
    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
@@ -177,6 +179,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        return (self.sigmas.max() ** 2 + 1) ** 0.5

+   @property
+   def step_index(self):
+       """
+       The index counter for the current timestep. It increases by 1 after each scheduler step.
+       """
+       return self._step_index
+
    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
@@ -194,10 +203,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
            `torch.FloatTensor`:
                A scaled input sample.
        """
-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_index = (self.timesteps == timestep).nonzero().item()
-       sigma = self.sigmas[step_index]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

@@ -238,20 +248,20 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
-           timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+           timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
+           timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
-           timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
+           timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(

@@ -269,14 +279,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

-       if str(device).startswith("mps"):
-           # mps does not support float64
-           self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
-       else:
-           self.timesteps = torch.from_numpy(timesteps).to(device=device)
+       self.timesteps = torch.from_numpy(timesteps).to(device=device)
+       self._step_index = None

        self.derivatives = []

+   # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+   def _init_step_index(self, timestep):
+       if isinstance(timestep, torch.Tensor):
+           timestep = timestep.to(self.timesteps.device)
+
+       index_candidates = (self.timesteps == timestep).nonzero()
+
+       # The sigma index that is taken for the **very** first `step`
+       # is always the second index (or the last index if there is only 1)
+       # This way we can ensure we don't accidentally skip a sigma in
+       # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+       if len(index_candidates) > 1:
+           step_index = index_candidates[1]
+       else:
+           step_index = index_candidates[0]
+
+       self._step_index = step_index.item()
+
    # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma

@@ -351,10 +376,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
                "See `StableDiffusionPipeline` for a usage example."
            )

-       if isinstance(timestep, torch.Tensor):
-           timestep = timestep.to(self.timesteps.device)
-       step_index = (self.timesteps == timestep).nonzero().item()
-       sigma = self.sigmas[step_index]
+       if self.step_index is None:
+           self._init_step_index(timestep)
+
+       sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":

@@ -376,14 +401,17 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
-       order = min(step_index + 1, order)
-       lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
+       order = min(self.step_index + 1, order)
+       lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

+       # upon completion increase step index by one
+       self._step_index += 1
+
        if not return_dict:
            return (prev_sample,)
......
@@ -104,6 +104,8 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

+       scheduler._step_index = None
+
        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample
......
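The IPNDM test above clears the private counter before running a second pass over `scheduler.timesteps`; without that (or a fresh `set_timesteps()` call, which now also resets `_step_index` to `None`), the counter would keep advancing past the end of the schedule. A self-contained sketch of the same pattern (not from the PR), with a zero tensor standing in for the model residual:

    import torch
    from diffusers import IPNDMScheduler

    scheduler = IPNDMScheduler()
    scheduler.set_timesteps(10)                       # also sets scheduler._step_index = None

    sample = torch.randn(1, 3, 8, 8)
    for t in scheduler.timesteps:
        residual = torch.zeros_like(sample)           # stand-in for model(sample, t)
        sample = scheduler.step(residual, t, sample).prev_sample

    # Second pass over the same timesteps: reset the counter first, as the test does.
    scheduler._step_index = None
    for t in scheduler.timesteps:
        residual = torch.zeros_like(sample)
        sample = scheduler.step(residual, t, sample).prev_sample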
@@ -485,8 +485,8 @@ class SchedulerCommonTest(unittest.TestCase):
        num_inference_steps = kwargs.pop("num_inference_steps", None)

-       timestep_0 = 0
-       timestep_1 = 1
+       timestep_0 = 1
+       timestep_1 = 0

        for scheduler_class in self.scheduler_classes:
            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
......