Commit 12da0fe1 authored by patil-suraj

Merge branch 'main' into model-tests

parents cf6cd395 9c96682a
@@ -258,10 +258,6 @@ class ConfigMixin:
 class FrozenDict(OrderedDict):
     def __init__(self, *args, **kwargs):
-        # remove `None`
-        args = (a for a in args if a is not None)
-        kwargs = {k: v for k, v in kwargs if v is not None}
         super().__init__(*args, **kwargs)

         for key, value in self.items():
......
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
])
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
Rearrange('batch t -> batch t 1'),
)
self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
if inp_channels != out_channels else nn.Identity()
def forward(self, x, t):
'''
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
'''
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
class TemporalUnet(nn.Module):
def __init__(
self,
horizon,
transition_dim,
cond_dim,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
if not is_last:
horizon = horizon // 2
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
if not is_last:
horizon = horizon * 2
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5),
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, x, cond, time):
'''
x : [ batch x horizon x transition ]
'''
x = einops.rearrange(x, 'b h t -> b t h')
t = self.time_mlp(time)
h = []
for resnet, resnet2, downsample in self.downs:
x = resnet(x, t)
x = resnet2(x, t)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block2(x, t)
for resnet, resnet2, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x
class TemporalValue(nn.Module):
def __init__(
self,
horizon,
transition_dim,
cond_dim,
dim=32,
time_dim=None,
out_dim=1,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = time_dim or dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.blocks = nn.ModuleList([])
print(in_out)
for dim_in, dim_out in in_out:
self.blocks.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out)
]))
horizon = horizon // 2
fc_dim = dims[-1] * max(horizon, 1)
self.final_block = nn.Sequential(
nn.Linear(fc_dim + time_dim, fc_dim // 2),
nn.Mish(),
nn.Linear(fc_dim // 2, out_dim),
)
def forward(self, x, cond, time, *args):
'''
x : [ batch x horizon x transition ]
'''
x = einops.rearrange(x, 'b h t -> b t h')
t = self.time_mlp(time)
for resnet, resnet2, downsample in self.blocks:
x = resnet(x, t)
x = resnet2(x, t)
x = downsample(x)
x = x.view(len(x), -1)
out = self.final_block(torch.cat([x, t], dim=-1))
return out
\ No newline at end of file
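As a quick shape check for the new `TemporalUnet` defined above (shapes follow its `forward()` docstring), here is a minimal, illustrative sketch; the horizon, transition_dim, and batch values are made up for illustration, and `cond` is unused by `forward` at this commit.

```python
import torch

# horizon must survive the three Downsample1d halvings implied by the default
# dim_mults=(1, 2, 4, 8), so it should be divisible by 8.
model = TemporalUnet(horizon=128, transition_dim=14, cond_dim=0)

x = torch.randn(4, 128, 14)         # [ batch x horizon x transition ]
t = torch.randint(0, 1000, (4,))    # one diffusion timestep per sample
out = model(x, cond=None, time=t)   # -> torch.Size([4, 128, 14]), same shape as x
```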
# Pipelines
- Pipelines are a collection of end-to-end diffusion systems that can be used out of the box
- Pipelines should stay as close as possible to their original implementation
- Pipelines can include components of other libraries, such as text encoders.
## API
TODO(Patrick, Anton, Suraj)
## Examples
- DDPM for unconditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- DDIM for unconditional image generation in [pipeline_ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text-to-image generation / conditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Glide for text-to-image generation / conditional image generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text-to-audio generation / conditional audio generation in [pipeline_ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
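Since the API section above is still a TODO, the following is only a hypothetical sketch of what running one of these pipelines end to end could look like; the import path, `from_pretrained` loader, checkpoint id, and call arguments are assumptions for illustration, not the confirmed API at this commit.

```python
import torch
from diffusers import DDIM  # hypothetical import path for the DDIM pipeline

# Hypothetical checkpoint id and call signature, for illustration only.
ddim = DDIM.from_pretrained("fusing/ddim-celeba-hq")
generator = torch.manual_seed(0)
image = ddim(generator=generator, num_inference_steps=50, eta=0.0)
```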
@@ -46,7 +46,7 @@ class LatentDiffusion(DiffusionPipeline):
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         image = self.noise_scheduler.sample_noise(
......
@@ -291,7 +291,7 @@ class BDDM(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

-        timestep_values = self.noise_scheduler.timestep_values
+        timestep_values = self.noise_scheduler.get_timestep_values()
         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
......
@@ -32,7 +32,7 @@ class DDIM(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         self.unet.to(torch_device)
......
@@ -897,7 +897,7 @@ class LatentDiffusion(DiffusionPipeline):
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

-        num_trained_timesteps = self.noise_scheduler.timesteps
+        num_trained_timesteps = self.noise_scheduler.config.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         image = torch.randn(
......
@@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline):
         )
         image = image.to(torch_device)

-        warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
-        for t in tqdm.tqdm(range(len(warmup_time_steps))):
-            t_orig = warmup_time_steps[t]
+        prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(prk_time_steps))):
+            t_orig = prk_time_steps[t]
             residual = self.unet(image, t_orig)

             image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
......
@@ -61,7 +61,6 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
             timesteps=timesteps,
             beta_schedule=beta_schedule,
         )
-        self.timesteps = int(timesteps)

         if beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
@@ -94,4 +93,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
         return torch.randn(shape, generator=generator).to(device)

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
@@ -37,10 +37,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
+            trained_betas=trained_betas,
+            timestep_values=timestep_values,
+            clip_sample=clip_sample,
         )
-        self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_sample = clip_sample

         if beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
@@ -81,6 +81,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # )
         # self.alphas = 1.0 - self.betas
         # self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

+    def get_timestep_values(self):
+        return self.config.timestep_values
+
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -96,7 +98,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def get_orig_t(self, t, num_inference_steps):
         if t < 0:
             return -1
-        return self.timesteps // num_inference_steps * t
+        return self.config.timesteps // num_inference_steps * t

     def get_variance(self, t, num_inference_steps):
         orig_t = self.get_orig_t(t, num_inference_steps)
@@ -137,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 4. Clip "predicted x_0"
-        if self.clip_sample:
+        if self.config.clip_sample:
             pred_original_sample = self.clip(pred_original_sample, -1, 1)

         # 5. compute variance: "sigma_t(η)" -> see formula (16)
@@ -158,4 +160,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
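For reference, the "# 5. compute variance" comment above points to formula (16) of the DDIM paper. Judging from how the full-loop test further down uses it (square root, times η, times noise), `get_variance` appears to return the η-free factor; a hedged restatement, with ᾱ the cumulative alpha product and "prev" the previous inference timestep:

$$
\texttt{get\_variance}(t) = \frac{1-\bar\alpha_{\mathrm{prev}}}{1-\bar\alpha_t}\left(1-\frac{\bar\alpha_t}{\bar\alpha_{\mathrm{prev}}}\right),
\qquad
\sigma_t = \eta\,\sqrt{\texttt{get\_variance}(t)} .
$$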
@@ -43,10 +43,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             variance_type=variance_type,
             clip_sample=clip_sample,
         )
-        self.timesteps = int(timesteps)
-        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_sample = clip_sample
-        self.variance_type = variance_type

         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
@@ -83,6 +79,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         #
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

+    def get_timestep_values(self):
+        return self.config.timestep_values
+
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -105,14 +103,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)

         # hacks - were probs added for training stability
-        if self.variance_type == "fixed_small":
+        if self.config.variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
-        elif self.variance_type == "fixed_large":
+        # for rl-diffuser https://arxiv.org/abs/2205.09991
+        elif self.config.variance_type == "fixed_small_log":
+            variance = self.log(self.clip(variance, min_value=1e-20))
+        elif self.config.variance_type == "fixed_large":
             variance = self.get_beta(t)

         return variance

-    def step(self, residual, sample, t):
+    def step(self, residual, sample, t, predict_epsilon=True):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
@@ -121,10 +122,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        if predict_epsilon:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        else:
+            pred_original_sample = residual

         # 3. Clip "predicted x_0"
-        if self.clip_sample:
+        if self.config.clip_sample:
             pred_original_sample = self.clip(pred_original_sample, -1, 1)

         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
@@ -145,4 +149,4 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_sample

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
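For context on the new `predict_epsilon` flag in `DDPMScheduler.step`: with `predict_epsilon=True` the model output is treated as the noise ε and the "predicted x_0" follows formula (15) of the DDPM paper, exactly as in the code above; with `predict_epsilon=False` the model output is taken to be x_0 directly. In the code's notation (`alpha_prod_t` = ᾱ_t, `beta_prod_t` = 1 − ᾱ_t):

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}
$$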
@@ -30,8 +30,6 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
             beta_start=beta_start,
             beta_end=beta_end,
         )
-        self.timesteps = int(timesteps)

         self.set_format(tensor_format=tensor_format)

     def sample_noise(self, timestep):
@@ -46,4 +44,4 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
         return xt

     def __len__(self):
-        return self.timesteps
+        return len(self.config.timesteps)
@@ -35,7 +35,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             beta_end=beta_end,
             beta_schedule=beta_schedule,
         )
-        self.timesteps = int(timesteps)

         if beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
@@ -57,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
-        # mainly at equations (12) and (13) and the Algorithm 2.
+        # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4

         # running values
         self.cur_residual = 0
         self.cur_sample = None
         self.ets = []
-        self.warmup_time_steps = {}
+        self.prk_time_steps = {}
         self.time_steps = {}

+        self.set_prk_mode()
+
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -78,34 +78,47 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             return self.one
         return self.alphas_cumprod[time_step]

-    def get_warmup_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.warmup_time_steps:
-            return self.warmup_time_steps[num_inference_steps]
-
-        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
-        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
-        )
-        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
-        return self.warmup_time_steps[num_inference_steps]
+    def get_prk_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.prk_time_steps:
+            return self.prk_time_steps[num_inference_steps]
+
+        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
+        )
+        self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        return self.prk_time_steps[num_inference_steps]

     def get_time_steps(self, num_inference_steps):
         if num_inference_steps in self.time_steps:
             return self.time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
         self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
         return self.time_steps[num_inference_steps]

+    def set_prk_mode(self):
+        self.mode = "prk"
+
+    def set_plms_mode(self):
+        self.mode = "plms"
+
+    def step(self, *args, **kwargs):
+        if self.mode == "prk":
+            return self.step_prk(*args, **kwargs)
+        if self.mode == "plms":
+            return self.step_plms(*args, **kwargs)
+        raise ValueError(f"mode {self.mode} does not exist.")
+
     def step_prk(self, residual, sample, t, num_inference_steps):
-        # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
-        warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
-        t_prev = warmup_time_steps[t // 4 * 4]
-        t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
+        prk_time_steps = self.get_prk_time_steps(num_inference_steps)
+        t_orig = prk_time_steps[t // 4 * 4]
+        t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]

         if t % 4 == 0:
             self.cur_residual += 1 / 6 * residual
@@ -119,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = self.cur_residual + 1 / 6 * residual
             self.cur_residual = 0

-        return self.transfer(self.cur_sample, t_prev, t_next, residual)
+        # cur_sample should not be `None`
+        cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+        return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)

     def step_plms(self, residual, sample, t, num_inference_steps):
+        if len(self.ets) < 3:
+            raise ValueError(
+                f"{self.__class__} can only be run AFTER scheduler has been run "
+                "in 'prk' mode for at least 12 iterations "
+                "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+                "for more information."
+            )
+
         timesteps = self.get_time_steps(num_inference_steps)

-        t_prev = timesteps[t]
-        t_next = timesteps[min(t + 1, len(timesteps) - 1)]
+        t_orig = timesteps[t]
+        t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]

         self.ets.append(residual)
         residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

-        return self.transfer(sample, t_prev, t_next, residual)
+        return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)

-    def transfer(self, x, t, t_next, et):
-        # TODO(Patrick): clean up to be compatible with numpy and give better names
-        alphas_cump = self.alphas_cumprod.to(x.device)
-        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
-        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
-
-        x_delta = (at_next - at) * (
-            (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
-            - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
-        )
-
-        x_next = x + x_delta
-        return x_next
+    def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
+        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+        # this function computes x_(t−δ) using the formula of (9)
+        # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>
+        # alpha_prod_t -> α_t
+        # alpha_prod_t_prev -> α_(t−δ)
+        # beta_prod_t -> (1 - α_t)
+        # beta_prod_t_prev -> (1 - α_(t−δ))
+        # sample -> x_t
+        # residual -> e_θ(x_t, t)
+        # prev_sample -> x_(t−δ)
+        alpha_prod_t = self.get_alpha_prod(t_orig + 1)
+        alpha_prod_t_prev = self.get_alpha_prod(t_orig_prev + 1)
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        # corresponds to (α_(t−δ) - α_t) divided by
+        # denominator of x_t in formula (9) and plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =
+        # sqrt(α_(t−δ)) / sqrt(α_t))
+        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+        # corresponds to denominator of e_θ(x_t, t) in formula (9)
+        residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+        ) ** (0.5)
+
+        # full formula (9)
+        prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
+
+        return prev_sample

     def __len__(self):
-        return self.timesteps
+        return self.config.timesteps
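Written out in the notation of the comments above, the update implemented by `get_prev_sample` (formula (9) of the PNDM paper, as reconstructed directly from the code) is:

$$
x_{t-\delta} = \sqrt{\tfrac{\bar\alpha_{t-\delta}}{\bar\alpha_t}}\, x_t
\;-\; \frac{(\bar\alpha_{t-\delta} - \bar\alpha_t)\,\epsilon_\theta(x_t, t)}
{\bar\alpha_t\sqrt{1-\bar\alpha_{t-\delta}} + \sqrt{\bar\alpha_t\,(1-\bar\alpha_t)\,\bar\alpha_{t-\delta}}}
$$

where ᾱ is the cumulative alpha product (`alpha_prod`), x_t the current sample, and ε_θ the model residual.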
@@ -64,3 +64,13 @@ class SchedulerMixin:
             return torch.clamp(tensor, min_value, max_value)

         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def log(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.log(tensor)
+        elif tensor_format == "pt":
+            return torch.log(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
@@ -20,7 +20,7 @@ import unittest
 import numpy as np
 import torch

-from diffusers import DDIMScheduler, DDPMScheduler
+from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler

 torch.backends.cuda.matmul.allow_tf32 = False
@@ -31,37 +31,37 @@ class SchedulerCommonTest(unittest.TestCase):
     forward_default_kwargs = ()

     @property
-    def dummy_image(self):
+    def dummy_sample(self):
         batch_size = 4
         num_channels = 3
         height = 8
         width = 8

-        image = np.random.rand(batch_size, num_channels, height, width)
-
-        return image
+        sample = np.random.rand(batch_size, num_channels, height, width)
+
+        return sample

     @property
-    def dummy_image_deter(self):
+    def dummy_sample_deter(self):
         batch_size = 4
         num_channels = 3
         height = 8
         width = 8

         num_elems = batch_size * num_channels * height * width
-        image = np.arange(num_elems)
-        image = image.reshape(num_channels, height, width, batch_size)
-        image = image / num_elems
-        image = image.transpose(3, 0, 1, 2)
-
-        return image
+        sample = np.arange(num_elems)
+        sample = sample.reshape(num_channels, height, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.transpose(3, 0, 1, 2)
+
+        return sample

     def get_scheduler_config(self):
         raise NotImplementedError

     def dummy_model(self):
-        def model(image, t, *args):
-            return image * t / (t + 1)
+        def model(sample, t, *args):
+            return sample * t / (t + 1)

         return model
@@ -70,8 +70,8 @@ class SchedulerCommonTest(unittest.TestCase):
         for scheduler_class in self.scheduler_classes:
             scheduler_class = self.scheduler_classes[0]
-            image = self.dummy_image
-            residual = 0.1 * image
+            sample = self.dummy_sample
+            residual = 0.1 * sample

             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
@@ -80,8 +80,8 @@ class SchedulerCommonTest(unittest.TestCase):
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)

-            output = scheduler.step(residual, image, time_step, **kwargs)
-            new_output = new_scheduler.step(residual, image, time_step, **kwargs)
+            output = scheduler.step(residual, sample, time_step, **kwargs)
+            new_output = new_scheduler.step(residual, sample, time_step, **kwargs)

             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase):
         kwargs.update(forward_kwargs)

         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
-            image = self.dummy_image
-            residual = 0.1 * image
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_class = self.scheduler_classes[0]

             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
@@ -101,8 +101,8 @@ class SchedulerCommonTest(unittest.TestCase):
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)

-            output = scheduler.step(residual, image, time_step, **kwargs)
-            new_output = new_scheduler.step(residual, image, time_step, **kwargs)
+            output = scheduler.step(residual, sample, time_step, **kwargs)
+            new_output = new_scheduler.step(residual, sample, time_step, **kwargs)

             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -110,8 +110,8 @@ class SchedulerCommonTest(unittest.TestCase):
         kwargs = dict(self.forward_default_kwargs)

         for scheduler_class in self.scheduler_classes:
-            image = self.dummy_image
-            residual = 0.1 * image
+            sample = self.dummy_sample
+            residual = 0.1 * sample

             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
@@ -120,8 +120,8 @@ class SchedulerCommonTest(unittest.TestCase):
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)

-            output = scheduler.step(residual, image, 1, **kwargs)
-            new_output = new_scheduler.step(residual, image, 1, **kwargs)
+            output = scheduler.step(residual, sample, 1, **kwargs)
+            new_output = new_scheduler.step(residual, sample, 1, **kwargs)

             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -132,34 +132,34 @@ class SchedulerCommonTest(unittest.TestCase):
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

-            image = self.dummy_image
-            residual = 0.1 * image
+            sample = self.dummy_sample
+            residual = 0.1 * sample

-            output_0 = scheduler.step(residual, image, 0, **kwargs)
-            output_1 = scheduler.step(residual, image, 1, **kwargs)
+            output_0 = scheduler.step(residual, sample, 0, **kwargs)
+            output_1 = scheduler.step(residual, sample, 1, **kwargs)

-            self.assertEqual(output_0.shape, image.shape)
+            self.assertEqual(output_0.shape, sample.shape)
             self.assertEqual(output_0.shape, output_1.shape)

     def test_pytorch_equal_numpy(self):
         kwargs = dict(self.forward_default_kwargs)

         for scheduler_class in self.scheduler_classes:
-            image = self.dummy_image
-            residual = 0.1 * image
+            sample = self.dummy_sample
+            residual = 0.1 * sample

-            image_pt = torch.tensor(image)
-            residual_pt = 0.1 * image_pt
+            sample_pt = torch.tensor(sample)
+            residual_pt = 0.1 * sample_pt

             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

             scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)

-            output = scheduler.step(residual, image, 1, **kwargs)
-            output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs)
+            output = scheduler.step(residual, sample, 1, **kwargs)
+            output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs)

-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical"
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
 class DDPMSchedulerTest(SchedulerCommonTest):
@@ -194,7 +194,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for variance in ["fixed_small", "fixed_large", "other"]:
             self.check_over_configs(variance_type=variance)

-    def test_clip_image(self):
+    def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
@@ -219,26 +219,26 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         num_trained_timesteps = len(scheduler)

         model = self.dummy_model()
-        image = self.dummy_image_deter
+        sample = self.dummy_sample_deter

         for t in reversed(range(num_trained_timesteps)):
             # 1. predict noise residual
-            residual = model(image, t)
+            residual = model(sample, t)

-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = scheduler.step(residual, image, t)
+            # 2. predict previous mean of sample x_t-1
+            pred_prev_sample = scheduler.step(residual, sample, t)

             if t > 0:
-                noise = self.dummy_image_deter
+                noise = self.dummy_sample_deter
                 variance = scheduler.get_variance(t) ** (0.5) * noise

-            image = pred_prev_image + variance
+            sample = pred_prev_sample + variance

-        result_sum = np.sum(np.abs(image))
-        result_mean = np.mean(np.abs(image))
+        result_sum = np.sum(np.abs(sample))
+        result_mean = np.mean(np.abs(sample))

-        assert result_sum.item() - 732.9947 < 1e-3
-        assert result_mean.item() - 0.9544 < 1e-3
+        assert abs(result_sum.item() - 732.9947) < 1e-2
+        assert abs(result_mean.item() - 0.9544) < 1e-3
 class DDIMSchedulerTest(SchedulerCommonTest):
@@ -269,7 +269,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

-    def test_clip_image(self):
+    def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
@@ -308,22 +308,170 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

         model = self.dummy_model()
-        image = self.dummy_image_deter
+        sample = self.dummy_sample_deter

         for t in reversed(range(num_inference_steps)):
-            residual = model(image, inference_step_times[t])
+            residual = model(sample, inference_step_times[t])

-            pred_prev_image = scheduler.step(residual, image, t, num_inference_steps, eta)
+            pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)

             variance = 0
             if eta > 0:
-                noise = self.dummy_image_deter
+                noise = self.dummy_sample_deter
                 variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise

-            image = pred_prev_image + variance
+            sample = pred_prev_sample + variance

-        result_sum = np.sum(np.abs(image))
-        result_mean = np.mean(np.abs(image))
+        result_sum = np.sum(np.abs(sample))
+        result_mean = np.mean(np.abs(sample))

-        assert result_sum.item() - 270.6214 < 1e-3
-        assert result_mean.item() - 0.3524 < 1e-3
+        assert abs(result_sum.item() - 270.6214) < 1e-2
+        assert abs(result_mean.item() - 0.3524) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (PNDMScheduler,)
forward_default_kwargs = (("num_inference_steps", 50),)
def get_scheduler_config(self, **kwargs):
config = {
"timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
}
config.update(**kwargs)
return config
def check_over_configs_pmls(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_timesteps(self):
for timesteps in [100, 1000]:
self.check_over_configs(timesteps=timesteps)
def test_timesteps_pmls(self):
for timesteps in [100, 1000]:
self.check_over_configs_pmls(timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_betas_pmls(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_schedules_pmls(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self):
for t in [1, 5, 10]:
self.check_over_forward(time_step=t)
def test_time_indices_pmls(self):
for t in [1, 5, 10]:
self.check_over_forward_pmls(time_step=t)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_steps_pmls(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_pmls_no_past_residuals(self):
with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_plms_mode()
scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50)
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
for t in range(len(prk_time_steps)):
t_orig = prk_time_steps[t]
residual = model(sample, t_orig)
sample = scheduler.step_prk(residual, sample, t, num_inference_steps)
timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)):
t_orig = timesteps[t]
residual = model(sample, t_orig)
sample = scheduler.step_plms(residual, sample, t, num_inference_steps)
        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

        assert abs(result_sum.item() - 199.1169) < 1e-2
        assert abs(result_mean.item() - 0.2593) < 1e-3