Unverified Commit bd8df2da authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Pytorch] Pytorch only schedulers (#534)



* pytorch only schedulers

* fix style

* remove match_shape

* pytorch only ddpm

* remove SchedulerMixin

* remove numpy from karras_ve

* fix types

* remove numpy from lms_discrete

* remove numpy from pndm

* fix typo

* remove mixin and numpy from sde_vp and ve

* remove remaining tensor_format

* fix style

* sigmas has to be torch tensor

* removed set_format in readme

* remove set format from docs

* remove set_format from pipelines

* update tests

* fix typo

* continue to use mixin

* fix imports

* removed unsed imports

* match shape instead of assuming image shapes

* remove import typo

* update call to add_noise

* use math instead of numpy

* fix t_index

* removed commented out numpy tests

* timesteps needs to be discrete

* cast timesteps to int in flax scheduler too

* fix device mismatch issue

* small fix

* Update src/diffusers/schedulers/scheduling_pndm.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 3b747de8
......@@ -44,8 +44,7 @@ To this end, the design of schedulers is such that:
The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
- Schedulers should be framework-specific.
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
......
......@@ -274,7 +274,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# # predict the noise residual
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform classifier free guidance
......
......@@ -424,7 +424,10 @@ def main():
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
)
train_dataset = TextualInversionDataset(
......
......@@ -59,7 +59,7 @@ def main(args):
"UpBlock2D",
),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
......
......@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
......
......@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
......
......@@ -45,7 +45,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
......
......@@ -23,7 +23,6 @@ class LDMPipeline(DiffusionPipeline):
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad()
......
......@@ -39,7 +39,6 @@ class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
......
......@@ -57,7 +57,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
......
......@@ -69,7 +69,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
......
......@@ -83,7 +83,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
......@@ -320,11 +319,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
......
......@@ -35,7 +35,6 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.register_modules(
vae_decoder=vae_decoder,
text_encoder=text_encoder,
......
......@@ -29,7 +29,6 @@ class KarrasVePipeline(DiffusionPipeline):
def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
......
......@@ -8,8 +8,7 @@
- Schedulers should provide one or more `def step(...)` functions that should be called iteratively to unroll the diffusion loop during
the forward pass.
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
- Schedulers should be framework specific.
## Examples
......
......@@ -46,7 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
......@@ -72,7 +72,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
return torch.tensor(betas)
class DDIMScheduler(SchedulerMixin, ConfigMixin):
......@@ -106,7 +106,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
......@@ -121,15 +120,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
......@@ -137,20 +137,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
......@@ -186,15 +183,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
......@@ -205,9 +201,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
......@@ -251,7 +247,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
......@@ -273,9 +269,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise = torch.randn(model_output.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if not torch.is_tensor(model_output):
variance = variance.numpy()
prev_sample = prev_sample + variance
if not return_dict:
......@@ -285,16 +278,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
......
......@@ -70,7 +70,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
return torch.tensor(betas, dtype=torch.float32)
class DDPMScheduler(SchedulerMixin, ConfigMixin):
......@@ -99,7 +99,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
......@@ -113,15 +112,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
......@@ -129,15 +129,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.timesteps = np.arange(0, num_train_timesteps)[::-1]
self.variance_type = variance_type
......@@ -153,8 +150,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy()
self.set_format(tensor_format=self.tensor_format)
)[::-1]
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
......@@ -170,15 +166,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = self.clip(variance, min_value=1e-20)
variance = torch.clamp(variance, min=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = self.log(self.clip(variance, min_value=1e-20))
variance = torch.log(torch.clamp(variance, min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = self.log(self.betas[t])
variance = torch.log(self.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
......@@ -191,9 +187,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
predict_epsilon=True,
generator=None,
return_dict: bool = True,
......@@ -203,9 +199,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
......@@ -240,7 +236,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
......@@ -254,7 +250,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 6. Add noise
variance = 0
if t > 0:
noise = self.randn_like(model_output, generator=generator)
noise = torch.randn(
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
).to(model_output.device)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
......@@ -266,16 +264,21 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
......
......@@ -74,7 +74,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
......@@ -87,15 +86,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
tensor_format: str = "pt",
):
# setable values
self.num_inference_steps = None
self.timesteps = None
self.schedule = None # sigma(t_i)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.num_inference_steps: int = None
self.timesteps: np.ndarray = None
self.schedule: torch.FloatTensor = None # sigma(t_i)
def set_timesteps(self, num_inference_steps: int):
"""
......@@ -108,20 +103,18 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
schedule = [
(
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps
]
self.schedule = np.array(self.schedule, dtype=np.float32)
self.set_format(tensor_format=self.tensor_format)
self.schedule = torch.tensor(schedule, dtype=torch.float32)
def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[torch.FloatTensor, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
......@@ -142,10 +135,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_hat: torch.FloatTensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
......@@ -153,10 +146,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_hat (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
......@@ -180,24 +173,24 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
sample_hat: torch.FloatTensor,
sample_prev: torch.FloatTensor,
derivative: torch.FloatTensor,
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. TODO complete description
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
sample_hat (`torch.FloatTensor`): TODO
sample_prev (`torch.FloatTensor`): TODO
derivative (`torch.FloatTensor`): TODO
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns:
......
......@@ -63,7 +63,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
......@@ -75,31 +74,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
self.derivatives = []
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def get_lms_coefficient(self, order, t, current_order):
"""
Compute a linear multistep coefficient.
......@@ -131,24 +128,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
low_idx = np.floor(self.timesteps).astype(int)
high_idx = np.ceil(self.timesteps).astype(int)
frac = np.mod(self.timesteps, 1.0)
low_idx = np.floor(timesteps).astype(int)
high_idx = np.ceil(timesteps).astype(int)
frac = np.mod(timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = timesteps.astype(int)
self.derivatives = []
self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
model_output: torch.FloatTensor,
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
sample: torch.FloatTensor,
order: int = 4,
return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
......@@ -157,9 +154,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
......@@ -197,15 +194,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.sigmas.device)
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
sigmas = self.sigmas.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
sigma = sigmas[timesteps].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
......
......@@ -132,7 +132,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
timesteps=timesteps.astype(int),
derivatives=jnp.array([]),
sigmas=sigmas,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment