Unverified Commit f448360b authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Finish scheduler API (#91)

* finish

* up
parent 97e1e3ba
...@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel): ...@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, sample, step_value, transformer_out=None): def forward(self, sample, timestep, transformer_out=None):
timesteps = step_value timesteps = timestep
x = sample x = sample
hs = [] hs = []
emb = self.time_embed( emb = self.time_embed(
...@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel): ...@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
) )
def forward(self, sample, step_value, low_res=None): def forward(self, sample, timestep, low_res=None):
timesteps = step_value timesteps = timestep
x = sample x = sample
_, _, new_height, new_width = x.shape _, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
......
...@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.all_modules = nn.ModuleList(modules) self.all_modules = nn.ModuleList(modules)
def forward(self, sample, step_value, sigmas=None): def forward(self, sample, timestep, sigmas=None):
timesteps = step_value timesteps = timestep
x = sample x = sample
# timestep/noise_level embedding; only for continuous training # timestep/noise_level embedding; only for continuous training
modules = self.all_modules modules = self.all_modules
......
...@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): ...@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# ====================================== # ======================================
def forward( def forward(
self, sample: torch.FloatTensor, step_value: Union[torch.Tensor, float, int] self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]: ) -> Dict[str, torch.FloatTensor]:
# TODO(PVP) - to delete later at release # TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
...@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin): ...@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.set_weights() self.set_weights()
# ====================================== # ======================================
# 1. time step embeddings # 1. time step embeddings -> make correct tensor
timesteps = step_value timesteps = timestep
if not torch.is_tensor(timesteps): if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = get_timestep_embedding( t_emb = get_timestep_embedding(
timesteps, timesteps,
......
...@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline ...@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline
class DDIMPipeline(DiffusionPipeline): class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, scheduler=scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device) self.unet.to(torch_device)
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
...@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline):
) )
image = image.to(torch_device) image = image.to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # set step values
# Ideally, read DDIM paper in-detail understanding self.scheduler.set_timesteps(num_inference_steps)
# Notation (<variable name> -> <name in paper> for t in tqdm.tqdm(self.scheduler.timesteps):
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual # 1. predict noise residual
with torch.no_grad(): with torch.no_grad():
residual = self.unet(image, inference_step_times[t]) residual = self.unet(image, t)
if isinstance(residual, dict): if isinstance(residual, dict):
residual = residual["sample"] residual = residual["sample"]
# 2. predict previous mean of image x_t-1 # 2. predict previous mean of image x_t-1 and add variance depending on eta
pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta) # do x_t -> x_t-1
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
return image return {"sample": image}
...@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline ...@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class DDPMPipeline(DiffusionPipeline): class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, scheduler=scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None): def __call__(self, batch_size=1, generator=None, torch_device=None):
if torch_device is None: if torch_device is None:
...@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline):
) )
image = image.to(torch_device) image = image.to(torch_device)
num_prediction_steps = len(self.noise_scheduler) num_prediction_steps = len(self.scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual # 1. predict noise residual
with torch.no_grad(): with torch.no_grad():
...@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline):
residual = residual["sample"] residual = residual["sample"]
# 2. predict previous mean of image x_t-1 # 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(residual, image, t) pred_prev_image = self.scheduler.step(residual, t, image)["prev_sample"]
# 3. optionally sample variance # 3. optionally sample variance
variance = 0 variance = 0
if t > 0: if t > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device) noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t).sqrt() * noise variance = self.scheduler.get_variance(t).sqrt() * noise
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance image = pred_prev_image + variance
......
...@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline): ...@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline):
def __init__( def __init__(
self, self,
text_unet: GlideTextToImageUNetModel, text_unet: GlideTextToImageUNetModel,
text_noise_scheduler: DDPMScheduler, text_scheduler: DDPMScheduler,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer, tokenizer: GPT2Tokenizer,
upscale_unet: GlideSuperResUNetModel, upscale_unet: GlideSuperResUNetModel,
upscale_noise_scheduler: DDIMScheduler, upscale_scheduler: DDIMScheduler,
): ):
super().__init__() super().__init__()
self.register_modules( self.register_modules(
text_unet=text_unet, text_unet=text_unet,
text_noise_scheduler=text_noise_scheduler, text_scheduler=text_scheduler,
text_encoder=text_encoder, text_encoder=text_encoder,
tokenizer=tokenizer, tokenizer=tokenizer,
upscale_unet=upscale_unet, upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler, upscale_scheduler=upscale_scheduler,
) )
@torch.no_grad() @torch.no_grad()
...@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline): ...@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline):
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
# 3. Run the text2image generation step # 3. Run the text2image generation step
num_prediction_steps = len(self.text_noise_scheduler) num_prediction_steps = len(self.text_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
with torch.no_grad(): with torch.no_grad():
time_input = torch.tensor([t] * image.shape[0], device=torch_device) time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = text_model_fn(image, time_input, transformer_out) model_output = text_model_fn(image, time_input, transformer_out)
noise_residual, model_var_values = torch.split(model_output, 3, dim=1) noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log") min_log = self.text_scheduler.get_variance(t, "fixed_small_log")
max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log") max_log = self.text_scheduler.get_variance(t, "fixed_large_log")
# The model_var_values is [-1, 1] for [min_var, max_var]. # The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2 frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log model_log_variance = frac * max_log + (1 - frac) * min_log
pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t) pred_prev_image = self.text_scheduler.step(noise_residual, image, t)
noise = torch.randn(image.shape, generator=generator).to(torch_device) noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = torch.exp(0.5 * model_log_variance) * noise variance = torch.exp(0.5 * model_log_variance) * noise
...@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline): ...@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline):
).to(torch_device) ).to(torch_device)
image = image * upsample_temp image = image * upsample_temp
num_trained_timesteps = self.upscale_noise_scheduler.timesteps num_trained_timesteps = self.upscale_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale): for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
...@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline): ...@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline):
noise_residual, pred_variance = torch.split(model_output, 3, dim=1) noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# 2. predict previous mean of image x_t-1 # 2. predict previous mean of image x_t-1
pred_prev_image = self.upscale_noise_scheduler.step( pred_prev_image = self.upscale_scheduler.step(
noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
) )
...@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline): ...@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline):
variance = 0 variance = 0
if eta > 0: if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(torch_device) noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = ( variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
)
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance image = pred_prev_image + variance
......
...@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel): ...@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel):
class LatentDiffusionPipeline(DiffusionPipeline): class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler) self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
...@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline): ...@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids) text_embedding = self.bert(text_input.input_ids)
num_trained_timesteps = self.noise_scheduler.config.timesteps num_trained_timesteps = self.scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn( image = torch.randn(
...@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline): ...@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline):
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
# 2. predict previous mean of image x_t-1 # 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
# 3. optionally sample variance # 3. optionally sample variance
variance = 0 variance = 0
if eta > 0: if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device) noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1 # 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance image = pred_prev_image + variance
......
...@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline ...@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline
class LatentDiffusionUncondPipeline(DiffusionPipeline): class LatentDiffusionUncondPipeline(DiffusionPipeline):
def __init__(self, vqvae, unet, noise_scheduler): def __init__(self, vqvae, unet, scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler) self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
...@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.unet.to(torch_device) self.unet.to(torch_device)
self.vqvae.to(torch_device) self.vqvae.to(torch_device)
num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn( image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator, generator=generator,
).to(torch_device) ).to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf self.scheduler.set_timesteps(num_inference_steps)
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
pred_noise_t = self.unet(image, timesteps)
if isinstance(pred_noise_t, dict):
pred_noise_t = pred_noise_t["sample"]
# 2. predict previous mean of image x_t-1 for t in tqdm.tqdm(self.scheduler.timesteps):
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) residual = self.unet(image, t)
# 3. optionally sample variance if isinstance(residual, dict):
variance = 0 residual = residual["sample"]
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1 # 2. predict previous mean of image x_t-1 and add variance depending on eta
image = pred_prev_image + variance # do x_t -> x_t-1
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
# decode image with vae # decode image with vae
image = self.vqvae.decode(image) image = self.vqvae.decode(image)
return image return {"sample": image}
...@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline ...@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class PNDMPipeline(DiffusionPipeline): class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
noise_scheduler = noise_scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, scheduler=scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
# For more information on the sampling method you can take a look at Algorithm 2 of # For more information on the sampling method you can take a look at Algorithm 2 of
...@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline):
) )
image = image.to(torch_device) image = image.to(torch_device)
prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps) prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(prk_time_steps))): for t in tqdm.tqdm(range(len(prk_time_steps))):
t_orig = prk_time_steps[t] t_orig = prk_time_steps[t]
residual = self.unet(image, t_orig) residual = self.unet(image, t_orig)
...@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline):
if isinstance(residual, dict): if isinstance(residual, dict):
residual = residual["sample"] residual = residual["sample"]
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps) image = self.scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"]
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps) timesteps = self.scheduler.get_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(timesteps))): for t in tqdm.tqdm(range(len(timesteps))):
t_orig = timesteps[t] t_orig = timesteps[t]
residual = self.unet(image, t_orig) residual = self.unet(image, t_orig)
...@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline):
if isinstance(residual, dict): if isinstance(residual, dict):
residual = residual["sample"] residual = residual["sample"]
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps) image = self.scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"]
return image return image
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
# and https://github.com/hojonathanho/diffusion # and https://github.com/hojonathanho/diffusion
import math import math
from typing import Union
import numpy as np import numpy as np
import torch
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self.alphas_cumprod = np.cumprod(self.alphas, axis=0) self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0) self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format) # setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
def get_variance(self, t, num_inference_steps): self.tensor_format = tensor_format
orig_t = self.config.timesteps // num_inference_steps * t self.set_format(tensor_format=tensor_format)
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
alpha_prod_t = self.alphas_cumprod[orig_t] def _get_variance(self, timestep, prev_timestep):
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = 1 - alpha_prod_t_prev
...@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False): def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.config.timesteps, self.config.timesteps // self.num_inference_steps)[
::-1
].copy()
self.set_format(tensor_format=self.tensor_format)
def step(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
eta,
use_clipped_residual=False,
generator=None,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding # Ideally, read DDIM paper in-detail understanding
...@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointingc to x_t" # - pred_sample_direction -> "direction pointingc to x_t"
# - pred_prev_sample -> "x_t-1" # - pred_prev_sample -> "x_t-1"
# 1. get actual t and t-1 # 1. get previous step value (=t-1)
orig_t = self.config.timesteps // num_inference_steps * t prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
# 2. compute alphas, betas # 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[orig_t] alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called # 3. compute predicted original sample from predicted noise also called
...@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 5. compute variance: "sigma_t(η)" -> see formula (16) # 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self.get_variance(t, num_inference_steps) variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5) std_dev_t = eta * variance ** (0.5)
if use_clipped_residual: if use_clipped_residual:
...@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
device = residual.device if torch.is_tensor(residual) else "cpu"
noise = torch.randn(residual.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if not torch.is_tensor(residual):
variance = variance.numpy()
prev_sample = prev_sample + variance
return pred_prev_sample return {"prev_sample": prev_sample}
def add_noise(self, original_samples, noise, timesteps): def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
from typing import Union
import numpy as np import numpy as np
import torch
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
def step(self, residual, sample, t, predict_epsilon=True): def step(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
):
t = timestep
# 1. compute alphas, betas # 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
...@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
return pred_prev_sample return {"prev_sample": pred_prev_sample}
def add_noise(self, original_samples, noise, timesteps): def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
from typing import Union
import numpy as np import numpy as np
import torch
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"mode {self.mode} does not exist.") raise ValueError(f"mode {self.mode} does not exist.")
def step_prk(self, residual, sample, t, num_inference_steps): def step_prk(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps) prk_time_steps = self.get_prk_time_steps(num_inference_steps)
t_orig = prk_time_steps[t // 4 * 4] t_orig = prk_time_steps[t // 4 * 4]
...@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# cur_sample should not be `None` # cur_sample should not be `None`
cur_sample = self.cur_sample if self.cur_sample is not None else sample cur_sample = self.cur_sample if self.cur_sample is not None else sample
return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual) return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)}
def step_plms(self, residual, sample, t, num_inference_steps): def step_plms(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
t = timestep
if len(self.ets) < 3: if len(self.ets) < 3:
raise ValueError( raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run " f"{self.__class__} can only be run AFTER scheduler has been run "
...@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return self.get_prev_sample(sample, t_orig, t_orig_prev, residual) return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, residual)}
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual): def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
......
...@@ -165,7 +165,7 @@ class ModelTesterMixin: ...@@ -165,7 +165,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic # signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()] arg_names = [*signature.parameters.keys()]
expected_arg_names = ["sample", "step_value"] expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names) self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self): def test_model_from_config(self):
...@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device) time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "step_value": time_step} return {"sample": noise, "timestep": time_step}
@property @property
def input_shape(self): def input_shape(self):
...@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"sample": noise, "step_value": time_step, "low_res": low_res} return {"sample": noise, "timestep": time_step, "low_res": low_res}
@property @property
def input_shape(self): def input_shape(self):
...@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"sample": noise, "step_value": time_step, "transformer_out": emb} return {"sample": noise, "timestep": time_step, "transformer_out": emb}
@property @property
def input_shape(self): def input_shape(self):
...@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device) time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "step_value": time_step} return {"sample": noise, "timestep": time_step}
@property @property
def input_shape(self): def input_shape(self):
...@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device) time_step = torch.tensor(batch_size * [10]).to(torch_device)
return {"sample": noise, "step_value": time_step} return {"sample": noise, "timestep": time_step}
@property @property
def input_shape(self): def input_shape(self):
...@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDPMPipeline.from_pretrained(model_path) ddpm = DDPMPipeline.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
ddpm.noise_scheduler.num_timesteps = 10 ddpm.scheduler.num_timesteps = 10
ddpm_from_hub.noise_scheduler.num_timesteps = 10 ddpm_from_hub.scheduler.num_timesteps = 10
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
...@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDPMScheduler.from_config(model_id) scheduler = DDPMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt") scheduler = scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator) image = ddpm(generator=generator)
...@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-lsun-bedroom-ema" model_id = "fusing/ddpm-lsun-bedroom-ema"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDIMScheduler.from_config(model_id) scheduler = DDIMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator) image = ddpm(generator=generator)["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
...@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDIMScheduler(tensor_format="pt") scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0) image = ddim(generator=generator, eta=0.0)["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
...@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = PNDMScheduler(tensor_format="pt") scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = pndm(generator=generator) image = pndm(generator=generator)
...@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True) ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5) image = ldm(generator=generator, num_inference_steps=5)["sample"]
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
......
...@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
def check_over_configs(self, time_step=0, **config): def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample sample = self.dummy_sample
...@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, sample, time_step, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
new_output = new_scheduler.step(residual, sample, time_step, **kwargs) scheduler.set_timesteps(num_inference_steps)
new_scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs) kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
...@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, sample, time_step, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
new_output = new_scheduler.step(residual, sample, time_step, **kwargs) scheduler.set_timesteps(num_inference_steps)
new_scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
torch.manual_seed(0)
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self): def test_from_pretrained_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
...@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, sample, 1, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
new_output = new_scheduler.step(residual, sample, 1, **kwargs) scheduler.set_timesteps(num_inference_steps)
new_scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_step_shape(self): def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
...@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
output_0 = scheduler.step(residual, sample, 0, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
output_1 = scheduler.step(residual, sample, 1, **kwargs) scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
...@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
def test_pytorch_equal_numpy(self): def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
...@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
output = scheduler.step(residual, sample, 1, **kwargs) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs) scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
...@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t) residual = model(sample, t)
# 2. predict previous mean of sample x_t-1 # 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, sample, t) pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
if t > 0: if t > 0:
noise = self.dummy_sample_deter noise = self.dummy_sample_deter
...@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): ...@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
class DDIMSchedulerTest(SchedulerCommonTest): class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDIMScheduler,) scheduler_classes = (DDIMScheduler,)
forward_default_kwargs = (("num_inference_steps", 50), ("eta", 0.0)) forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
def get_scheduler_config(self, **kwargs): def get_scheduler_config(self, **kwargs):
config = { config = {
...@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
return config return config
def test_timesteps(self): def test_timesteps(self):
for timesteps in [1, 5, 100, 1000]: for timesteps in [100, 500, 1000]:
self.check_over_configs(timesteps=timesteps) self.check_over_configs(timesteps=timesteps)
def test_betas(self): def test_betas(self):
...@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self): def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) self.check_over_forward(num_inference_steps=num_inference_steps)
def test_eta(self): def test_eta(self):
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
...@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5 assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5 assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5 assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5 assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5 assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5 assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.1 num_inference_steps, eta = 10, 0.0
num_trained_timesteps = len(scheduler)
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter sample = self.dummy_sample_deter
for t in reversed(range(num_inference_steps)): scheduler.set_timesteps(num_inference_steps)
residual = model(sample, inference_step_times[t]) for t in scheduler.timesteps:
residual = model(sample, t)
pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
variance = 0
if eta > 0:
noise = self.dummy_sample_deter
variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
sample = pred_prev_sample + variance sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
result_sum = np.sum(np.abs(sample)) result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample)) result_mean = np.mean(np.abs(sample))
assert abs(result_sum.item() - 270.6214) < 1e-2 assert abs(result_sum.item() - 172.0067) < 1e-2
assert abs(result_mean.item() - 0.3524) < 1e-3 assert abs(result_mean.item() - 0.223967) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest): class PNDMSchedulerTest(SchedulerCommonTest):
...@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode() new_scheduler.set_plms_mode()
output = scheduler.step(residual, sample, time_step, **kwargs) output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, sample, time_step, **kwargs) new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode() new_scheduler.set_plms_mode()
output = scheduler.step(residual, sample, time_step, **kwargs) output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, sample, time_step, **kwargs) new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.set_plms_mode() scheduler.set_plms_mode()
scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50) scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
t_orig = prk_time_steps[t] t_orig = prk_time_steps[t]
residual = model(sample, t_orig) residual = model(sample, t_orig)
sample = scheduler.step_prk(residual, sample, t, num_inference_steps) sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
timesteps = scheduler.get_time_steps(num_inference_steps) timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)): for t in range(len(timesteps)):
t_orig = timesteps[t] t_orig = timesteps[t]
residual = model(sample, t_orig) residual = model(sample, t_orig)
sample = scheduler.step_plms(residual, sample, t, num_inference_steps) sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
result_sum = np.sum(np.abs(sample)) result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample)) result_mean = np.mean(np.abs(sample))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment