Unverified Commit cd21b965 authored by YiYi Xu, committed by GitHub

add a step_index counter (#4347)



add self.step_index

---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d185b5ed
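In short, every scheduler touched by this commit gains the same counter: `_step_index` starts as `None`, is cleared again by `set_timesteps`, is lazily resolved from the current timestep by `_init_step_index` on the first `scale_model_input`/`step` call, and is incremented once per completed `step`, so sigmas are indexed by the counter instead of searching `self.timesteps` on every call. A minimal, self-contained sketch of that pattern, assuming a toy schedule and a placeholder update rule (the `SketchScheduler` class and its values are illustrative only, not code from this commit):

import torch

class SketchScheduler:
    def __init__(self):
        # toy schedule values, purely illustrative
        self.timesteps = torch.tensor([999.0, 749.0, 499.0, 249.0])
        self.sigmas = torch.tensor([14.6, 6.1, 2.3, 0.7, 0.0])
        self._step_index = None

    @property
    def step_index(self):
        # The index counter for the current timestep; increases by 1 after each scheduler step.
        return self._step_index

    def _init_step_index(self, timestep):
        # Locate the timestep in the schedule; when a timestep appears twice
        # (interleaved second-order schedules), take the second match so an
        # image-to-image run that starts mid-schedule does not skip a sigma.
        index_candidates = (self.timesteps == timestep).nonzero()
        step_index = index_candidates[1] if len(index_candidates) > 1 else index_candidates[0]
        self._step_index = step_index.item()

    def step(self, model_output, timestep, sample):
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]          # indexed by the counter, no timestep search
        prev_sample = sample - sigma * model_output   # placeholder update rule
        self._step_index += 1                         # advance once per completed step
        return prev_sample

scheduler = SketchScheduler()
sample = torch.zeros(1)
for t in scheduler.timesteps:
    sample = scheduler.step(torch.ones(1), t, sample)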
......@@ -981,6 +981,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
prompt_embeds_edit[1:2] += edit_direction
# 10. Second denoising loop to generate the edited image.
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
latents = latents_init
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
......
......@@ -96,6 +96,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps)
self.custom_timesteps = False
self.is_scale_input_called = False
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
......@@ -104,6 +105,13 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
indices = (schedule_timesteps == timestep).nonzero()
return indices.item()
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
......@@ -121,10 +129,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
A scaled input sample.
"""
# Get sigma corresponding to timestep
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_idx = self.index_for_timestep(timestep)
sigma = self.sigmas[step_idx]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
......@@ -220,6 +228,8 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
# Modified _convert_to_karras implementation that takes in ramp as argument
def _convert_to_karras(self, ramp):
"""Constructs the noise schedule of Karras et al. (2022)."""
......@@ -267,6 +277,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
return c_skip, c_out
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
model_output: torch.FloatTensor,
......@@ -318,18 +346,16 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
"See `StableDiffusionPipeline` for a usage example."
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
sigma_min = self.config.sigma_min
sigma_max = self.config.sigma_max
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
# sigma_next corresponds to next_t in original implementation
sigma = self.sigmas[step_index]
if step_index + 1 < self.config.num_train_timesteps:
sigma_next = self.sigmas[step_index + 1]
sigma = self.sigmas[self.step_index]
if self.step_index + 1 < self.config.num_train_timesteps:
sigma_next = self.sigmas[self.step_index + 1]
else:
# Set sigma_next to sigma_min
sigma_next = self.sigmas[-1]
......@@ -358,6 +384,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
# tau = sigma_hat, eps = sigma_min
prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -166,6 +166,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps)
self.is_scale_input_called = False
self._step_index = None
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
......@@ -174,6 +176,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
......@@ -191,10 +200,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample
......@@ -213,20 +223,20 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
......@@ -237,11 +247,27 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
......@@ -295,11 +321,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"See `StableDiffusionPipeline` for a usage example."
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
if self.step_index is None:
self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sigma = self.sigmas[self.step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
......@@ -314,8 +339,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
sigma_from = self.sigmas[step_index]
sigma_to = self.sigmas[step_index + 1]
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
......@@ -331,6 +356,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + noise * sigma_up
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -175,6 +175,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
......@@ -183,6 +185,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
......@@ -200,11 +209,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
......@@ -224,20 +232,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
......@@ -263,11 +271,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
......@@ -306,6 +312,23 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
model_output: torch.FloatTensor,
......@@ -365,11 +388,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"See `StableDiffusionPipeline` for a usage example."
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
if self.step_index is None:
self._init_step_index(timestep)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sigma = self.sigmas[self.step_index]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
......@@ -401,10 +423,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
dt = self.sigmas[step_index + 1] - sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -149,6 +149,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self.use_karras_sigmas = use_karras_sigmas
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
......@@ -175,6 +177,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self,
sample: torch.FloatTensor,
......@@ -194,9 +203,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[step_index]
sigma = self.sigmas[self.step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
......@@ -221,18 +231,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
......@@ -254,16 +264,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(device, dtype=torch.float32)
else:
self.timesteps = timesteps.to(device=device)
self.timesteps = timesteps.to(device=device)
# empty dt and derivative
self.prev_derivative = None
self.dt = None
self._step_index = None
# (YiYi Notes: keep this for now since we are keeping the add_noise function which uses index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
......@@ -310,6 +319,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
def state_in_first_order(self):
return self.dt is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
......@@ -336,19 +363,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
# (YiYi notes: keep this for now since we are keeping the add_noise method)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
else:
# 2nd order / Heun's method
sigma = self.sigmas[step_index - 1]
sigma_next = self.sigmas[step_index]
sigma = self.sigmas[self.step_index - 1]
sigma_next = self.sigmas[self.step_index]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
......@@ -404,6 +433,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -55,6 +55,14 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
self.ets = []
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
......@@ -81,6 +89,25 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = timesteps.to(device)
self.ets = []
self._step_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
......@@ -112,9 +139,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
timestep_index = (self.timesteps == timestep).nonzero().item()
prev_timestep_index = timestep_index + 1
timestep_index = self.step_index
prev_timestep_index = self.step_index + 1
ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
self.ets.append(ets)
......@@ -130,6 +159,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -137,6 +137,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
......@@ -165,6 +166,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self,
sample: torch.FloatTensor,
......@@ -184,12 +192,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma = self.sigmas[self.step_index]
else:
sigma = self.sigmas_interpol[step_index - 1]
sigma = self.sigmas_interpol[self.step_index - 1]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
......@@ -215,18 +224,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
......@@ -259,12 +268,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)
timesteps = torch.from_numpy(timesteps).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
......@@ -276,6 +280,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# we need an index counter
self._index_counter = defaultdict(int)
self._step_index = None
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
......@@ -303,6 +309,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
......@@ -332,23 +356,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index]
sigma_up = self.sigmas_up[step_index]
sigma_down = self.sigmas_down[step_index - 1]
sigma = self.sigmas[self.step_index]
sigma_interpol = self.sigmas_interpol[self.step_index]
sigma_up = self.sigmas_up[self.step_index]
sigma_down = self.sigmas_down[self.step_index - 1]
else:
# 2nd order / KDPM2's method
sigma = self.sigmas[step_index - 1]
sigma_interpol = self.sigmas_interpol[step_index - 1]
sigma_up = self.sigmas_up[step_index - 1]
sigma_down = self.sigmas_down[step_index - 1]
sigma = self.sigmas[self.step_index - 1]
sigma_interpol = self.sigmas_interpol[self.step_index - 1]
sigma_up = self.sigmas_up[self.step_index - 1]
sigma_down = self.sigmas_down[self.step_index - 1]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
......@@ -398,6 +423,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt
prev_sample = prev_sample + noise * sigma_up
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -137,6 +137,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self._step_index = None
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
......@@ -164,6 +166,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self,
sample: torch.FloatTensor,
......@@ -183,12 +192,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma = self.sigmas[self.step_index]
else:
sigma = self.sigmas_interpol[step_index]
sigma = self.sigmas_interpol[self.step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
......@@ -214,18 +224,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
......@@ -247,11 +257,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
[sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
)
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
timesteps = torch.from_numpy(timesteps).to(device)
timesteps = torch.from_numpy(timesteps).to(device)
# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
......@@ -265,6 +271,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# we need an index counter
self._index_counter = defaultdict(int)
self._step_index = None
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
......@@ -292,6 +300,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
......@@ -318,21 +344,22 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if self.step_index is None:
self._init_step_index(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index + 1]
sigma_next = self.sigmas[step_index + 1]
sigma = self.sigmas[self.step_index]
sigma_interpol = self.sigmas_interpol[self.step_index + 1]
sigma_next = self.sigmas[self.step_index + 1]
else:
# 2nd order / KDPM2's method
sigma = self.sigmas[step_index - 1]
sigma_interpol = self.sigmas_interpol[step_index]
sigma_next = self.sigmas[step_index]
sigma = self.sigmas[self.step_index - 1]
sigma_interpol = self.sigmas_interpol[self.step_index]
sigma_next = self.sigmas[self.step_index]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
......@@ -375,6 +402,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = self.sample
self.sample = None
# upon completion increase step index by one
self._step_index += 1
prev_sample = sample + derivative * dt
if not return_dict:
......
......@@ -169,6 +169,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.derivatives = []
self.is_scale_input_called = False
self._step_index = None
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
......@@ -177,6 +179,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
......@@ -194,10 +203,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
A scaled input sample.
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
self.is_scale_input_called = True
return sample
......@@ -238,20 +248,20 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
timesteps -= 1
else:
raise ValueError(
......@@ -269,14 +279,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device)
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
self.derivatives = []
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
......@@ -351,10 +376,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"See `StableDiffusionPipeline` for a usage example."
)
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
......@@ -376,14 +401,17 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.derivatives.pop(0)
# 3. Compute linear multistep coefficients
order = min(step_index + 1, order)
lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]
order = min(self.step_index + 1, order)
lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]
# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
......
......@@ -104,6 +104,8 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
scheduler._step_index = None
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
......
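Because the counter now persists on the scheduler instance, the test above has to clear it before iterating over `scheduler.timesteps` a second time; otherwise the second pass would keep indexing past the end of `self.sigmas`. A hedged sketch of the two ways to reset it (the `num_inference_steps` value is whatever the surrounding test already uses):

# option 1: re-running set_timesteps clears the counter as part of this commit
scheduler.set_timesteps(num_inference_steps)
# option 2: what the test does, reset the private counter directly
scheduler._step_index = None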
......@@ -485,8 +485,8 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
timestep_0 = 0
timestep_1 = 1
timestep_0 = 1
timestep_1 = 0
for scheduler_class in self.scheduler_classes:
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
......