Unverified Commit 83ae24ce authored by RuiningLi's avatar RuiningLi Committed by GitHub
Browse files

Added get_velocity function to EulerDiscreteScheduler. (#7733)



* Added get_velocity function to EulerDiscreteScheduler.

* Fix white space on blank lines

* Added copied from statement

* back to the original.

---------
Co-authored-by: default avatarRuining Li <ruining@robots.ox.ac.uk>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 8af793b2
...@@ -576,5 +576,44 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -576,5 +576,44 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma noisy_samples = original_samples + noise * sigma
return noisy_samples return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor
) -> torch.FloatTensor:
if (
isinstance(timesteps, int)
or isinstance(timesteps, torch.IntTensor)
or isinstance(timesteps, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if sample.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
timesteps = timesteps.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
timesteps = timesteps.to(sample.device)
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
alphas_cumprod = self.alphas_cumprod.to(sample)
sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment