Commit 21ceda3f authored by anton-l's avatar anton-l
Browse files

Remove duplicate add_noise

parent 5321f3e2
...@@ -130,14 +130,5 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -130,14 +130,5 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
return noisy_samples return noisy_samples
def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
...@@ -28,12 +28,12 @@ from .import_utils import ( ...@@ -28,12 +28,12 @@ from .import_utils import (
DummyObject, DummyObject,
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_modelcards_available,
is_scipy_available, is_scipy_available,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
is_unidecode_available, is_unidecode_available,
is_modelcards_available,
requires_backends, requires_backends,
) )
from .logging import get_logger from .logging import get_logger
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment