"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ab03dc4370acaaef05810465a077691472624b2b"
Unverified Commit fe4837a9 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

add step_index and clear noise_sampler at begining of each loop (#5024)


Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 342c5c02
...@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas self.use_karras_sigmas = use_karras_sigmas
self.noise_sampler = None self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed self.noise_sampler_seed = noise_sampler_seed
self._step_index = None
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None): def index_for_timestep(self, timestep, schedule_timesteps=None):
...@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return indices[pos].item() return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item()
@property @property
def init_noise_sigma(self): def init_noise_sigma(self):
# standard deviation of the initial noise distribution # standard deviation of the initial noise distribution
...@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5 return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
def scale_model_input( def scale_model_input(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
...@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
A scaled input sample. A scaled input sample.
""" """
step_index = self.index_for_timestep(timestep) if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[step_index] sigma = self.sigmas[self.step_index]
sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
sample = sample / ((sigma_input**2 + 1) ** 0.5) sample = sample / ((sigma_input**2 + 1) ** 0.5)
return sample return sample
...@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.sample = None self.sample = None
self.mid_point_sigma = None self.mid_point_sigma = None
self._step_index = None
self.noise_sampler = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py` # for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter # we need an index counter
self._index_counter = defaultdict(int) self._index_counter = defaultdict(int)
...@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor. tuple is returned where the first element is the sample tensor.
""" """
step_index = self.index_for_timestep(timestep) if self.step_index is None:
self._init_step_index(timestep)
# advance index counter by 1 # advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
...@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return _sigma.log().neg() return _sigma.log().neg()
if self.state_in_first_order: if self.state_in_first_order:
sigma = self.sigmas[step_index] sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[step_index + 1] sigma_next = self.sigmas[self.step_index + 1]
else: else:
# 2nd order # 2nd order
sigma = self.sigmas[step_index - 1] sigma = self.sigmas[self.step_index - 1]
sigma_next = self.sigmas[step_index] sigma_next = self.sigmas[self.step_index]
# Set the midpoint and step size for the current step # Set the midpoint and step size for the current step
midpoint_ratio = 0.5 midpoint_ratio = 0.5
...@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.sample = None self.sample = None
self.mid_point_sigma = None self.mid_point_sigma = None
# upon completion increase step index by one
self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment