Unverified commit adcbe674, authored by YiYi Xu, committed by GitHub

[refactor]Scheduler.set_begin_index (#6728)

parent ec9840a5
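
In short: instead of re-deriving the starting position by searching `self.timesteps` on every `step`/`add_noise` call, a pipeline can now tell the scheduler its first schedule index once via `set_begin_index`. A minimal usage sketch, where the checkpoint name, `strength` arithmetic, and tensor shapes are illustrative assumptions rather than part of this commit:

```python
# Hedged sketch of the new API; everything except set_begin_index/add_noise
# themselves is an illustrative assumption.
import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"  # illustrative checkpoint
)
scheduler.set_timesteps(num_inference_steps=50)

# An img2img-style pipeline with strength=0.8 skips the first 20% of steps.
strength = 0.8
begin_index = int(50 * (1 - strength))

# New in this PR: announce the start index once, up front ...
scheduler.set_begin_index(begin_index)

# ... so add_noise (and the first step call) can use it directly instead of
# searching self.timesteps for a matching timestep value.
sample = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(sample)
t_start = scheduler.timesteps[begin_index].unsqueeze(0)
noisy_sample = scheduler.add_noise(sample, noise, t_start)
```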
src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 import math
-from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -148,8 +147,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.use_karras_sigmas = use_karras_sigmas

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps

@@ -160,11 +161,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(self._index_counter) == 0:
-            pos = 1 if len(indices) > 1 else 0
-        else:
-            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-            pos = self._index_counter[timestep_int]
+        pos = 1 if len(indices) > 1 else 0

         return indices[pos].item()

@@ -183,6 +180,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,

@@ -270,13 +285,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.dt = None

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-        # (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
-        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
-        # we need an index counter
-        self._index_counter = defaultdict(int)
-
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma

@@ -333,21 +344,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
     def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
         else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+            self._step_index = self._begin_index

     def step(
         self,

@@ -378,11 +380,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.step_index is None:
             self._init_step_index(timestep)

-        # (YiYi notes: keep this for now since we are keeping the add_noise method)
-        # advance index counter by 1
-        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-        self._index_counter[timestep_int] += 1
-
         if self.state_in_first_order:
             sigma = self.sigmas[self.step_index]
             sigma_next = self.sigmas[self.step_index + 1]

@@ -453,6 +450,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,

@@ -469,7 +467,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
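
Context for the `pos = 1 if len(indices) > 1 else 0` line kept above: Heun is a second-order scheduler, so interior timesteps appear twice in `self.timesteps`, and the equality lookup can return two indices; taking the second one avoids skipping a sigma when denoising starts mid-schedule. A self-contained sketch with made-up schedule values:

```python
import torch

# Illustrative schedule only: Heun-style schedules repeat interior timesteps.
schedule_timesteps = torch.tensor([801, 601, 601, 401, 401, 201, 201, 1])

def index_for_timestep(timestep, schedule_timesteps):
    indices = (schedule_timesteps == timestep).nonzero()
    # With two matches, the second index is the safe entry point; with one
    # match (the first timestep here), index 0 is the only option.
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()

print(index_for_timestep(801, schedule_timesteps))  # 0: single match
print(index_for_timestep(401, schedule_timesteps))  # 4: second of the two matches
```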
src/diffusers/schedulers/scheduling_ipndm.py

@@ -56,6 +56,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         # running values
         self.ets = []
         self._step_index = None
+        self._begin_index = None

     @property
     def step_index(self):

@@ -64,6 +65,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -90,24 +109,31 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         self.ets = []

         self._step_index = None
+        self._begin_index = None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 import math
-from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -140,27 +139,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(self._index_counter) == 0:
-            pos = 1 if len(indices) > 1 else 0
-        else:
-            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-            pos = self._index_counter[timestep_int]
-
-        return indices[pos].item()
-
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution

@@ -176,6 +157,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,

@@ -295,11 +294,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sample = None

-        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
-        # we need an index counter
-        self._index_counter = defaultdict(int)
-
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t

@@ -356,23 +352,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def state_in_first_order(self):
         return self.sample is None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,

@@ -406,10 +408,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.step_index is None:
             self._init_step_index(timestep)

-        # advance index counter by 1
-        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-        self._index_counter[timestep_int] += 1
-
         if self.state_in_first_order:
             sigma = self.sigmas[self.step_index]
             sigma_interpol = self.sigmas_interpol[self.step_index]

@@ -478,7 +476,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,

@@ -495,7 +493,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
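
The `add_noise` rewrite repeated across these files is the payoff of the refactor: once `set_begin_index` has been called, every sample in a batch starts at the same schedule position, so the per-timestep search collapses to a constant list. A standalone sketch of both branches (sigma and timestep values are made up for illustration):

```python
import torch

sigmas = torch.tensor([14.6, 9.9, 9.9, 6.1, 6.1, 3.5, 3.5, 0.0])  # illustrative
schedule_timesteps = torch.tensor([801, 601, 601, 401, 401, 201, 201, 1])
timesteps = torch.tensor([601, 601])  # same start timestep for a batch of 2
begin_index = None  # or an int previously set via scheduler.set_begin_index(...)

def index_for_timestep(t, schedule_timesteps):
    indices = (schedule_timesteps == t).nonzero()
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()

# begin_index is None when the scheduler is used for training, or when the
# pipeline does not implement set_begin_index.
if begin_index is None:
    step_indices = [index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
    step_indices = [begin_index] * timesteps.shape[0]

sigma = sigmas[step_indices].flatten()
print(step_indices, sigma)  # [2, 2] tensor([9.9000, 9.9000])
```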
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 import math
-from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -140,27 +139,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(self._index_counter) == 0:
-            pos = 1 if len(indices) > 1 else 0
-        else:
-            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-            pos = self._index_counter[timestep_int]
-
-        return indices[pos].item()
-
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution

@@ -176,6 +157,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,

@@ -280,34 +279,37 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sample = None

-        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
-        # we need an index counter
-        self._index_counter = defaultdict(int)
-
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
     def state_in_first_order(self):
         return self.sample is None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):

@@ -388,10 +390,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.step_index is None:
             self._init_step_index(timestep)

-        # advance index counter by 1
-        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-        self._index_counter[timestep_int] += 1
-
         if self.state_in_first_order:
             sigma = self.sigmas[self.step_index]
             sigma_interpol = self.sigmas_interpol[self.step_index + 1]

@@ -453,7 +451,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,

@@ -470,7 +468,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
src/diffusers/schedulers/scheduling_lcm.py

@@ -250,29 +250,54 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self.custom_timesteps = False

         self._step_index = None
+        self._begin_index = None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     @property
     def step_index(self):
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

@@ -462,6 +487,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

         self._step_index = None
+        self._begin_index = None

     def get_scalings_for_boundary_condition_discrete(self, timestep):
         self.sigma_data = 0.5  # Default: 0.5
src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = False

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property

@@ -185,6 +186,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
     ) -> torch.FloatTensor:

@@ -280,27 +299,34 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         self.timesteps = torch.from_numpy(timesteps).to(device=device)
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

         self.derivatives = []

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):

@@ -434,7 +460,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
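
One detail in the LMS hunk above: the old lookup `(schedule_timesteps == t).nonzero().item()` assumes exactly one match, since `.item()` raises on a multi-element tensor. LMS schedules have unique timesteps, so it worked there, but the shared `index_for_timestep` helper also handles duplicated timesteps and enables the `begin_index` fast path. A small illustration of the failure mode (tensors are made up):

```python
import torch

# A schedule with a duplicated timestep, as second-order schedulers produce.
schedule_timesteps = torch.tensor([801, 601, 601, 401])

matches = (schedule_timesteps == 601).nonzero()
try:
    matches.item()  # .item() requires exactly one element
except RuntimeError as err:
    print("old-style lookup fails:", err)

# The shared helper picks the second match instead of crashing.
pos = 1 if len(matches) > 1 else 0
print("index_for_timestep-style lookup:", matches[pos].item())  # 2
```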
src/diffusers/schedulers/scheduling_sasolver.py

@@ -212,6 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         self.lower_order_nums = 0
         self.last_sample = None
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property

@@ -221,6 +222,24 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -283,6 +302,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample

@@ -925,11 +945,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         x_t = x_t.to(x.dtype)
         return x_t

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1

@@ -942,7 +963,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
src/diffusers/schedulers/scheduling_unipc_multistep.py

@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.solver_p = solver_p
         self.last_sample = None
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property

@@ -207,6 +208,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -269,6 +288,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample

@@ -698,11 +718,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         x_t = x_t.to(x.dtype)
         return x_t

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1

@@ -715,7 +736,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,

@@ -830,16 +864,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = []
-        for timestep in timesteps:
-            index_candidates = (schedule_timesteps == timestep).nonzero()
-            if len(index_candidates) == 0:
-                step_index = len(schedule_timesteps) - 1
-            elif len(index_candidates) > 1:
-                step_index = index_candidates[1].item()
-            else:
-                step_index = index_candidates[0].item()
-            step_indices.append(step_index)
+        # begin_index is None when the scheduler is used for training
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
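
Unlike the Euler-style helper used by the discrete schedulers above, the DPMSolverMultistep-style `index_for_timestep` adopted by SASolverScheduler and UniPCMultistepScheduler also tolerates timesteps that are absent from the schedule by clamping to the last index, which is the same three-case logic the old inline `add_noise` loop implemented. A standalone sketch of the three cases (schedule values are illustrative):

```python
import torch

def index_for_timestep(timestep, schedule_timesteps):
    # Sketch mirroring the DPMSolverMultistep-style helper adopted above.
    index_candidates = (schedule_timesteps == timestep).nonzero()
    if len(index_candidates) == 0:
        step_index = len(schedule_timesteps) - 1  # not in schedule: clamp to last
    elif len(index_candidates) > 1:
        step_index = index_candidates[1].item()   # duplicated timestep: take the 2nd
    else:
        step_index = index_candidates[0].item()   # unique match
    return step_index

schedule = torch.tensor([951, 701, 701, 451, 201])
print(index_for_timestep(951, schedule))  # 0: unique match
print(index_for_timestep(701, schedule))  # 2: second of the duplicates
print(index_for_timestep(333, schedule))  # 4: absent, clamped to the last index
```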