Unverified Commit adcbe674 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[refactor]Scheduler.set_begin_index (#6728)

parent ec9840a5
...@@ -789,6 +789,8 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -789,6 +789,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -705,6 +705,8 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -705,6 +705,8 @@ class StableDiffusionControlNetInpaintPipeline(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -871,6 +871,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -871,6 +871,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -566,6 +566,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -566,6 +566,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -536,6 +536,8 @@ class StableDiffusionInpaintPipelineLegacy( ...@@ -536,6 +536,8 @@ class StableDiffusionInpaintPipelineLegacy(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -634,6 +634,8 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -634,6 +634,8 @@ class LatentConsistencyModelImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -906,6 +906,8 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin ...@@ -906,6 +906,8 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -467,6 +467,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader ...@@ -467,6 +467,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -659,6 +659,8 @@ class StableDiffusionImg2ImgPipeline( ...@@ -659,6 +659,8 @@ class StableDiffusionImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -859,6 +859,8 @@ class StableDiffusionInpaintPipeline( ...@@ -859,6 +859,8 @@ class StableDiffusionInpaintPipeline(
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -754,6 +754,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -754,6 +754,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -554,6 +554,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -554,6 +554,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t_start = max(num_inference_steps - init_timestep, 0) t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start return timesteps, num_inference_steps - t_start
......
...@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): ...@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.custom_timesteps = False self.custom_timesteps = False
self.is_scale_input_called = False self.is_scale_input_called = False
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
return indices.item()
@property @property
def step_index(self): def step_index(self):
""" """
...@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): ...@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input( def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): ...@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument # Modified _convert_to_karras implementation that takes in ramp as argument
...@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): ...@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
return c_skip, c_out return c_skip, c_out
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def _init_step_index(self, timestep): def index_for_timestep(self, timestep, schedule_timesteps=None):
if isinstance(timestep, torch.Tensor): if schedule_timesteps is None:
timestep = timestep.to(self.timesteps.device) schedule_timesteps = self.timesteps
index_candidates = (self.timesteps == timestep).nonzero() indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step` # The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1) # is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in # This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image) # case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1: pos = 1 if len(indices) > 1 else 0
step_index = index_candidates[1]
else: return indices[pos].item()
step_index = index_candidates[0]
self._step_index = step_index.item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step( def step(
self, self,
...@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): ...@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
...@@ -196,6 +197,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -196,6 +197,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -255,6 +274,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -255,6 +274,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps # add an index counter for schedulers that allow duplicated timesteps
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...@@ -620,11 +640,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -620,11 +640,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise NotImplementedError("only support log-rho multistep deis now") raise NotImplementedError("only support log-rho multistep deis now")
def _init_step_index(self, timestep): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
if isinstance(timestep, torch.Tensor): def index_for_timestep(self, timestep, schedule_timesteps=None):
timestep = timestep.to(self.timesteps.device) if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (self.timesteps == timestep).nonzero() index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0: if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1 step_index = len(self.timesteps) - 1
...@@ -637,7 +658,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -637,7 +658,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else: else:
step_index = index_candidates[0].item() step_index = index_candidates[0].item()
self._step_index = step_index return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step( def step(
self, self,
...@@ -736,16 +770,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -736,16 +770,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [] # begin_index is None when the scheduler is used for training
for timestep in timesteps: if self.begin_index is None:
index_candidates = (schedule_timesteps == timestep).nonzero() step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
if len(index_candidates) == 0: else:
step_index = len(schedule_timesteps) - 1 step_indices = [self.begin_index] * timesteps.shape[0]
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -227,6 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -227,6 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order self.model_outputs = [None] * solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
...@@ -236,6 +237,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -236,6 +237,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -311,6 +329,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -311,6 +329,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps # add an index counter for schedulers that allow duplicated timesteps
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...@@ -792,11 +811,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -792,11 +811,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
) )
return x_t return x_t
def _init_step_index(self, timestep): def index_for_timestep(self, timestep, schedule_timesteps=None):
if isinstance(timestep, torch.Tensor): if schedule_timesteps is None:
timestep = timestep.to(self.timesteps.device) schedule_timesteps = self.timesteps
index_candidates = (self.timesteps == timestep).nonzero() index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0: if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1 step_index = len(self.timesteps) - 1
...@@ -809,7 +828,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -809,7 +828,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else: else:
step_index = index_candidates[0].item() step_index = index_candidates[0].item()
self._step_index = step_index return step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step( def step(
self, self,
...@@ -920,16 +951,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -920,16 +951,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [] # begin_index is None when the scheduler is used for training
for timestep in timesteps: if self.begin_index is None:
index_candidates = (schedule_timesteps == timestep).nonzero() step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
if len(index_candidates) == 0: else:
step_index = len(schedule_timesteps) - 1 step_indices = [self.begin_index] * timesteps.shape[0]
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -767,7 +767,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -767,7 +767,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
) )
return x_t return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep): def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor): if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device) timestep = timestep.to(self.timesteps.device)
...@@ -879,7 +878,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -879,7 +878,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import math import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -198,9 +197,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -198,9 +197,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = None self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed self.noise_sampler_seed = noise_sampler_seed
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None): def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None: if schedule_timesteps is None:
schedule_timesteps = self.timesteps schedule_timesteps = self.timesteps
...@@ -211,31 +211,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -211,31 +211,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
# is always the second index (or the last index if there is only 1) # is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in # This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image) # case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(self._index_counter) == 0: pos = 1 if len(indices) > 1 else 0
pos = 1 if len(indices) > 1 else 0
else:
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
pos = self._index_counter[timestep_int]
return indices[pos].item() return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep): def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor): if self.begin_index is None:
timestep = timestep.to(self.timesteps.device) if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero() self._step_index = self.index_for_timestep(timestep)
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else: else:
step_index = index_candidates[0] self._step_index = self._begin_index
self._step_index = step_index.item()
@property @property
def init_noise_sigma(self): def init_noise_sigma(self):
...@@ -252,6 +239,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -252,6 +239,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input( def scale_model_input(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
...@@ -348,13 +353,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -348,13 +353,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.mid_point_sigma = None self.mid_point_sigma = None
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None self.noise_sampler = None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self._index_counter = defaultdict(int)
def _second_order_timesteps(self, sigmas, log_sigmas): def _second_order_timesteps(self, sigmas, log_sigmas):
def sigma_fn(_t): def sigma_fn(_t):
return np.exp(-_t) return np.exp(-_t)
...@@ -444,10 +446,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -444,10 +446,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
if self.step_index is None: if self.step_index is None:
self._init_step_index(timestep) self._init_step_index(timestep)
# advance index counter by 1
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
self._index_counter[timestep_int] += 1
# Create a noise sampler if it hasn't been created yet # Create a noise sampler if it hasn't been created yet
if self.noise_sampler is None: if self.noise_sampler is None:
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
...@@ -527,7 +525,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -527,7 +525,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
...@@ -544,7 +542,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -544,7 +542,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -210,6 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -210,6 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None self.sample = None
self.order_list = self.get_order_list(num_train_timesteps) self.order_list = self.get_order_list(num_train_timesteps)
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def get_order_list(self, num_inference_steps: int) -> List[int]: def get_order_list(self, num_inference_steps: int) -> List[int]:
...@@ -253,6 +254,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -253,6 +254,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -315,6 +334,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -315,6 +334,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps # add an index counter for schedulers that allow duplicated timesteps
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...@@ -813,11 +833,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -813,11 +833,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise ValueError(f"Order must be 1, 2, 3, got {order}") raise ValueError(f"Order must be 1, 2, 3, got {order}")
def _init_step_index(self, timestep): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
if isinstance(timestep, torch.Tensor): def index_for_timestep(self, timestep, schedule_timesteps=None):
timestep = timestep.to(self.timesteps.device) if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (self.timesteps == timestep).nonzero() index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0: if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1 step_index = len(self.timesteps) - 1
...@@ -830,7 +851,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -830,7 +851,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else: else:
step_index = index_candidates[0].item() step_index = index_candidates[0].item()
self._step_index = step_index return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step( def step(
self, self,
...@@ -925,16 +959,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -925,16 +959,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [] # begin_index is None when the scheduler is used for training
for timestep in timesteps: if self.begin_index is None:
index_candidates = (schedule_timesteps == timestep).nonzero() step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
if len(index_candidates) == 0: else:
step_index = len(schedule_timesteps) - 1 step_indices = [self.begin_index] * timesteps.shape[0]
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -216,6 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -216,6 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False self.is_scale_input_called = False
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
...@@ -233,6 +234,24 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -233,6 +234,24 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input( def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -300,25 +319,32 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -300,25 +319,32 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def _init_step_index(self, timestep): def index_for_timestep(self, timestep, schedule_timesteps=None):
if isinstance(timestep, torch.Tensor): if schedule_timesteps is None:
timestep = timestep.to(self.timesteps.device) schedule_timesteps = self.timesteps
index_candidates = (self.timesteps == timestep).nonzero() indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step` # The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1) # is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in # This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image) # case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1: pos = 1 if len(indices) > 1 else 0
step_index = index_candidates[1]
else: return indices[pos].item()
step_index = index_candidates[0]
self._step_index = step_index.item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step( def step(
self, self,
...@@ -440,7 +466,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -440,7 +466,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.use_karras_sigmas = use_karras_sigmas self.use_karras_sigmas = use_karras_sigmas
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property @property
...@@ -255,6 +256,24 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -255,6 +256,24 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
return self._step_index return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_model_input( def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor: ) -> torch.FloatTensor:
...@@ -342,6 +361,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -342,6 +361,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
def _sigma_to_t(self, sigma, log_sigmas): def _sigma_to_t(self, sigma, log_sigmas):
...@@ -393,22 +413,27 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -393,22 +413,27 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas return sigmas
def _init_step_index(self, timestep): def index_for_timestep(self, timestep, schedule_timesteps=None):
if isinstance(timestep, torch.Tensor): if schedule_timesteps is None:
timestep = timestep.to(self.timesteps.device) schedule_timesteps = self.timesteps
index_candidates = (self.timesteps == timestep).nonzero() indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step` # The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1) # is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in # This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image) # case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1: pos = 1 if len(indices) > 1 else 0
step_index = index_candidates[1]
else:
step_index = index_candidates[0]
self._step_index = step_index.item() return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step( def step(
self, self,
...@@ -538,7 +563,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -538,7 +563,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment