Unverified Commit df9c0701 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add back-compatibility to LMS timesteps (#750)

* Add back-compatibility to LMS timesteps

* style
parent c119dc4c
...@@ -202,11 +202,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -202,11 +202,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
When returning a tuple, the first element is the sample tensor. When returning a tuple, the first element is the sample tensor.
""" """
if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
warnings.warn(
f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
"Make sure to pass one of the `scheduler.timesteps`"
)
if not self.is_scale_input_called: if not self.is_scale_input_called:
warnings.warn( warnings.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. " "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
...@@ -215,7 +210,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -215,7 +210,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if isinstance(timestep, torch.Tensor): if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device) timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item() if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
warnings.warn(
"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
" 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
)
step_index = timestep
else:
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index] sigma = self.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
...@@ -250,7 +256,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -250,7 +256,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = self.sigmas.to(original_samples.device) sigmas = self.sigmas.to(original_samples.device)
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
warnings.warn(
"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
" version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
)
step_indices = timesteps
else:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment