"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ff182ad6694ada3c01b3514eeae03392b2761b92"
Unverified Commit f45c675d authored by aengusng8's avatar aengusng8 Committed by GitHub
Browse files

[addresses issue #1642] add add_noise to scheduling-sde-ve (#1827)

* add add_noise to scheduling-sde-ve

* run Black formater
parent 1bf4f0da
...@@ -262,5 +262,18 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -262,5 +262,18 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
timesteps = timesteps.to(original_samples.device)
sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
noise = torch.randn_like(original_samples) * sigmas[:, None, None, None]
noisy_samples = noise + original_samples
return noisy_samples
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment