Unverified Commit bd8df2da authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Pytorch] Pytorch only schedulers (#534)



* pytorch only schedulers

* fix style

* remove match_shape

* pytorch only ddpm

* remove SchedulerMixin

* remove numpy from karras_ve

* fix types

* remove numpy from lms_discrete

* remove numpy from pndm

* fix typo

* remove mixin and numpy from sde_vp and ve

* remove remaining tensor_format

* fix style

* sigmas has to be torch tensor

* removed set_format in readme

* remove set format from docs

* remove set_format from pipelines

* update tests

* fix typo

* continue to use mixin

* fix imports

* removed unused imports

* match shape instead of assuming image shapes

* remove import typo

* update call to add_noise

* use math instead of numpy

* fix t_index

* removed commented out numpy tests

* timesteps needs to be discrete

* cast timesteps to int in flax scheduler too

* fix device mismatch issue

* small fix

* Update src/diffusers/schedulers/scheduling_pndm.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 3b747de8
...@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that: ...@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
The core API for any new scheduler must follow a limited structure. The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively. - Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task. - Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch - Schedulers should be framework-specific.
with a `set_format(...)` method.
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers. The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
......
...@@ -274,7 +274,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ...@@ -274,7 +274,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# # predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform classifier free guidance # perform classifier free guidance
......
...@@ -424,7 +424,10 @@ def main(): ...@@ -424,7 +424,10 @@ def main():
# TODO (patil-suraj): load scheduler using args # TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler( noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
) )
train_dataset = TextualInversionDataset( train_dataset = TextualInversionDataset(
......
...@@ -59,7 +59,7 @@ def main(args): ...@@ -59,7 +59,7 @@ def main(args):
"UpBlock2D", "UpBlock2D",
), ),
) )
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
model.parameters(), model.parameters(),
lr=args.learning_rate, lr=args.learning_rate,
......
...@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
...@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
...@@ -45,7 +45,6 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -45,7 +45,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
...@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline): ...@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
...@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
...@@ -57,7 +57,6 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -57,7 +57,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn( warnings.warn(
......
...@@ -69,7 +69,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -69,7 +69,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn( warnings.warn(
......
...@@ -83,7 +83,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -83,7 +83,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
...@@ -320,11 +319,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -320,11 +319,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if isinstance(self.scheduler, LMSDiscreteScheduler): if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
# masking # masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index)) init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
else: else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking # masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask)) latents = (init_latents_proper * mask) + (latents * (1 - mask))
......
...@@ -35,7 +35,6 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): ...@@ -35,7 +35,6 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPFeatureExtractor,
): ):
super().__init__() super().__init__()
scheduler = scheduler.set_format("np")
self.register_modules( self.register_modules(
vae_decoder=vae_decoder, vae_decoder=vae_decoder,
text_encoder=text_encoder, text_encoder=text_encoder,
......
...@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline): ...@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__() super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during - Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
the forward pass. the forward pass.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch - Schedulers should be framework specific.
with a `set_format(...)` method.
## Examples ## Examples
......
...@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput): ...@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1]. (1-beta) over time from t = [0,1].
...@@ -72,7 +72,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -72,7 +72,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32) return torch.tensor(betas)
class DDIMScheduler(SchedulerMixin, ConfigMixin): class DDIMScheduler(SchedulerMixin, ConfigMixin):
...@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion. stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" """
...@@ -121,15 +120,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -121,15 +120,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True, clip_sample: bool = True,
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
tensor_format: str = "pt",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = np.asarray(trained_betas) self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps) self.betas = betas_for_alpha_bar(num_train_timesteps)
...@@ -137,20 +137,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -137,20 +137,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod # At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0 # For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or # `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one. # whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.timesteps = np.arange(0, num_train_timesteps)[::-1]
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def _get_variance(self, timestep, prev_timestep): def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t = self.alphas_cumprod[timestep]
...@@ -186,15 +183,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -186,15 +183,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
step_ratio = self.config.num_train_timesteps // self.num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio # creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3 # casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
self.timesteps += offset self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: torch.FloatTensor,
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: torch.FloatTensor,
eta: float = 0.0, eta: float = 0.0,
use_clipped_model_output: bool = False, use_clipped_model_output: bool = False,
generator=None, generator=None,
...@@ -205,9 +201,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -205,9 +201,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step. eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO use_clipped_model_output (`bool`): TODO
...@@ -251,7 +247,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -251,7 +247,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 4. Clip "predicted x_0" # 4. Clip "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1) pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16) # 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
...@@ -273,9 +269,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -273,9 +269,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise = torch.randn(model_output.shape, generator=generator).to(device) noise = torch.randn(model_output.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if not torch.is_tensor(model_output):
variance = variance.numpy()
prev_sample = prev_sample + variance prev_sample = prev_sample + variance
if not return_dict: if not return_dict:
...@@ -285,16 +278,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -285,16 +278,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def add_noise( def add_noise(
self, self,
original_samples: Union[torch.FloatTensor, np.ndarray], original_samples: torch.FloatTensor,
noise: Union[torch.FloatTensor, np.ndarray], noise: torch.FloatTensor,
timesteps: Union[torch.IntTensor, np.ndarray], timesteps: torch.IntTensor,
) -> Union[torch.FloatTensor, np.ndarray]: ) -> torch.FloatTensor:
if self.tensor_format == "pt": timesteps = timesteps.to(self.alphas_cumprod.device)
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples return noisy_samples
......
...@@ -70,7 +70,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -70,7 +70,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32) return torch.tensor(betas, dtype=torch.float32)
class DDPMScheduler(SchedulerMixin, ConfigMixin): class DDPMScheduler(SchedulerMixin, ConfigMixin):
...@@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample between -1 and 1 for numerical stability.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" """
...@@ -113,15 +112,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -113,15 +112,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
tensor_format: str = "pt",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = np.asarray(trained_betas) self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear": elif beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps) self.betas = betas_for_alpha_bar(num_train_timesteps)
...@@ -129,15 +129,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -129,15 +129,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = np.array(1.0) self.one = torch.tensor(1.0)
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.timesteps = np.arange(0, num_train_timesteps)[::-1]
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.variance_type = variance_type self.variance_type = variance_type
...@@ -153,8 +150,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -153,8 +150,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.arange( self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy() )[::-1]
self.set_format(tensor_format=self.tensor_format)
def _get_variance(self, t, predicted_variance=None, variance_type=None): def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t = self.alphas_cumprod[t]
...@@ -170,15 +166,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -170,15 +166,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# hacks - were probably added for training stability # hacks - were probably added for training stability
if variance_type == "fixed_small": if variance_type == "fixed_small":
variance = self.clip(variance, min_value=1e-20) variance = torch.clamp(variance, min=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991 # for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log": elif variance_type == "fixed_small_log":
variance = self.log(self.clip(variance, min_value=1e-20)) variance = torch.log(torch.clamp(variance, min=1e-20))
elif variance_type == "fixed_large": elif variance_type == "fixed_large":
variance = self.betas[t] variance = self.betas[t]
elif variance_type == "fixed_large_log": elif variance_type == "fixed_large_log":
# Glide max_log # Glide max_log
variance = self.log(self.betas[t]) variance = torch.log(self.betas[t])
elif variance_type == "learned": elif variance_type == "learned":
return predicted_variance return predicted_variance
elif variance_type == "learned_range": elif variance_type == "learned_range":
...@@ -191,9 +187,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -191,9 +187,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: torch.FloatTensor,
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: torch.FloatTensor,
predict_epsilon=True, predict_epsilon=True,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
...@@ -203,9 +199,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -203,9 +199,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
predict_epsilon (`bool`): predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon. optional flag to use when model predicts the samples directly instead of the noise, epsilon.
...@@ -240,7 +236,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -240,7 +236,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 3. Clip "predicted x_0" # 3. Clip "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1) pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
...@@ -254,7 +250,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -254,7 +250,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 6. Add noise # 6. Add noise
variance = 0 variance = 0
if t > 0: if t > 0:
noise = self.randn_like(model_output, generator=generator) noise = torch.randn(
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
).to(model_output.device)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
...@@ -266,16 +264,21 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -266,16 +264,21 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def add_noise( def add_noise(
self, self,
original_samples: Union[torch.FloatTensor, np.ndarray], original_samples: torch.FloatTensor,
noise: Union[torch.FloatTensor, np.ndarray], noise: torch.FloatTensor,
timesteps: Union[torch.IntTensor, np.ndarray], timesteps: torch.IntTensor,
) -> Union[torch.FloatTensor, np.ndarray]: ) -> torch.FloatTensor:
if self.tensor_format == "pt": timesteps = timesteps.to(self.alphas_cumprod.device)
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples return noisy_samples
......
...@@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
A reasonable range is [0, 10]. A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise. s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80]. A reasonable range is [0.2, 80].
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" """
...@@ -87,15 +86,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -87,15 +86,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn: float = 80, s_churn: float = 80,
s_min: float = 0.05, s_min: float = 0.05,
s_max: float = 50, s_max: float = 50,
tensor_format: str = "pt",
): ):
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps: int = None
self.timesteps = None self.timesteps: np.ndarray = None
self.schedule = None # sigma(t_i) self.schedule: torch.FloatTensor = None # sigma(t_i)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int):
    """
    Sets the discrete timesteps and the continuous sigma noise schedule used
    for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a
            pre-trained model.
    """
    self.num_inference_steps = num_inference_steps
    # Descending integer indices [n-1, ..., 0]; `.copy()` materializes the
    # negative-stride view so downstream consumers get a contiguous array.
    self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
    # Geometric interpolation of sigma^2 from sigma_max^2 (i = n-1) down to
    # sigma_min^2 (i = 0), following Karras et al. (2022).
    # Guard the single-step case: i / (num_inference_steps - 1) would divide
    # by zero, so the one-element schedule is just sigma_max^2.
    if num_inference_steps > 1:
        schedule = [
            (
                self.config.sigma_max**2
                * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
            )
            for i in self.timesteps
        ]
    else:
        schedule = [self.config.sigma_max**2]
    self.schedule = torch.tensor(schedule, dtype=torch.float32)
def add_noise_to_input( def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]: ) -> Tuple[torch.FloatTensor, float]:
""" """
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
...@@ -142,10 +135,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -142,10 +135,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: torch.FloatTensor,
sigma_hat: float, sigma_hat: float,
sigma_prev: float, sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray], sample_hat: torch.FloatTensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]: ) -> Union[KarrasVeOutput, Tuple]:
""" """
...@@ -153,10 +146,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -153,10 +146,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO sigma_hat (`float`): TODO
sigma_prev (`float`): TODO sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_hat (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
...@@ -180,24 +173,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -180,24 +173,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
def step_correct( def step_correct(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: torch.FloatTensor,
sigma_hat: float, sigma_hat: float,
sigma_prev: float, sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray], sample_hat: torch.FloatTensor,
sample_prev: Union[torch.FloatTensor, np.ndarray], sample_prev: torch.FloatTensor,
derivative: Union[torch.FloatTensor, np.ndarray], derivative: torch.FloatTensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]: ) -> Union[KarrasVeOutput, Tuple]:
""" """
Correct the predicted sample based on the output model_output of the network. TODO complete description Correct the predicted sample based on the output model_output of the network. TODO complete description
Args: Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO sigma_hat (`float`): TODO
sigma_prev (`float`): TODO sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_hat (`torch.FloatTensor`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.FloatTensor`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns: Returns:
......
...@@ -63,7 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -63,7 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`. `linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional): trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
""" """
...@@ -75,31 +74,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -75,31 +74,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02, beta_end: float = 0.02,
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None, trained_betas: Optional[np.ndarray] = None,
tensor_format: str = "pt",
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = np.asarray(trained_betas) self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear": elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model. # this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
self.derivatives = [] self.derivatives = []
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def get_lms_coefficient(self, order, t, current_order): def get_lms_coefficient(self, order, t, current_order):
""" """
Compute a linear multistep coefficient. Compute a linear multistep coefficient.
...@@ -131,24 +128,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -131,24 +128,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
""" """
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
low_idx = np.floor(self.timesteps).astype(int) low_idx = np.floor(timesteps).astype(int)
high_idx = np.ceil(self.timesteps).astype(int) high_idx = np.ceil(timesteps).astype(int)
frac = np.mod(self.timesteps, 1.0) frac = np.mod(timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = timesteps.astype(int)
self.derivatives = [] self.derivatives = []
self.set_format(tensor_format=self.tensor_format)
def step( def step(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: torch.FloatTensor,
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: torch.FloatTensor,
order: int = 4, order: int = 4,
return_dict: bool = True, return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]: ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
...@@ -157,9 +154,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -157,9 +154,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
Args: Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`): sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
order: coefficient for multi-step inference. order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
...@@ -197,15 +194,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -197,15 +194,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Scales `noise` by the sigma associated with each timestep and adds it to
    `original_samples` (x_noisy = x + sigma_t * noise).

    Args:
        original_samples (`torch.FloatTensor`): the clean samples.
        noise (`torch.FloatTensor`): gaussian noise, broadcastable against
            `original_samples`.
        timesteps (`torch.IntTensor`): per-sample indices into `self.sigmas`.
            NOTE(review): values are used directly as indices — presumably the
            discrete timesteps produced by `set_timesteps`; confirm callers
            pass those and not raw training-range timesteps.

    Returns:
        `torch.FloatTensor`: the noised samples.
    """
    # Move both the schedule and the indices onto the samples' device before
    # indexing, so the gather happens device-locally.
    device = original_samples.device
    per_sample_sigma = self.sigmas.to(device)[timesteps.to(device)]

    # Broadcast sigma from shape (batch,) up to (batch, 1, ..., 1) so it can
    # multiply sample tensors of arbitrary rank.
    per_sample_sigma = per_sample_sigma.flatten()
    while per_sample_sigma.dim() < original_samples.dim():
        per_sample_sigma = per_sample_sigma.unsqueeze(-1)

    return original_samples + noise * per_sample_sigma
def __len__(self): def __len__(self):
......
...@@ -132,7 +132,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -132,7 +132,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return state.replace( return state.replace(
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
timesteps=timesteps, timesteps=timesteps.astype(int),
derivatives=jnp.array([]), derivatives=jnp.array([]),
sigmas=sigmas, sigmas=sigmas,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment